Fix backward pass: global loss scaling, weight transpose, AdamW, activation clipping

Three changes were needed to get the loss to converge below 5.5 (the unigram plateau); the first two are bug fixes:

1. FP16 underflow in ANE backward matmuls: gradient (~8e-5) × weight (~0.036)
   products flushed to zero in fp16. Fixed with global loss scaling (256×)
   applied once to dlogits, divided out before Adam update.

2. Backward weight staging used raw weights instead of transposed — all 4
   backward kernels (wotBwd, qkvBwd, ffnBwdW2t, ffnBwdW13t) now use
   pre-transposed buffers (Wot_buf, Wqt_buf, etc.).

3. Added AdamW (decoupled weight decay, wd=0.1 for weights, 0.0 for norms),
   activation clipping (act_clip=20), gradient clipping, cosine LR schedule,
   per-layer IOSurface weight pre-staging, and vocab compaction.

Loss now drops 9.14 → 5.74 in 500 steps from random init (87ms/step).
This commit is contained in:
maderix 2026-03-05 07:23:08 -08:00
parent efcf193075
commit 926f977b40
5 changed files with 515 additions and 201 deletions

View File

@ -62,6 +62,18 @@ typedef struct {
// ANE kernel handle // ANE kernel handle
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern; typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
// Input IOSurfaces kept per layer so weights can be staged ahead of time.
// One surface per kernel input; field names mirror the kernel names used
// elsewhere in this project (sdpaFwd, ffnFused, ffnBwdW2t, ...).
typedef struct {
    IOSurfaceRef sdpaFwd_in;
    IOSurfaceRef ffnFused_in;
    IOSurfaceRef ffnBwdW2t_in;
    IOSurfaceRef ffnBwdW13t_in;
    IOSurfaceRef wotBwd_in;
    IOSurfaceRef qkvBwd_in;
} PerLayerSurfaces;
// ANE request handles kept per layer, each bound to that layer's input
// IOSurface. Stored as opaque void* pointers; one field per kernel.
typedef struct {
    void *sdpaFwd;
    void *ffnFused;
    void *ffnBwdW2t;
    void *ffnBwdW13t;
    void *wotBwd;
    void *qkvBwd;
} PerLayerRequests;
// Checkpoint header // Checkpoint header
typedef struct { typedef struct {
int magic, version, step, total_steps; int magic, version, step, total_steps;

View File

@ -53,13 +53,13 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c
free(ss); free(rrms); free(dot); free(ss); free(rrms); free(dot);
} }
// Adam / AdamW parameter update over s->n elements (the new revision adds the trailing `wd` argument).
//   w   : parameters, updated in place
//   g   : gradient, indexed the same way as w
//   s   : optimizer state; s->m and s->v hold the running first and second moment estimates
//   t   : step count used for bias correction (bc = 1 - beta^t); presumably 1-based,
//         since t = 0 would make bc1 and bc2 zero — confirm with the caller
//   lr, b1, b2, eps : learning rate, moment decay rates, and denominator epsilon
//   wd  : weight-decay coefficient; it is added to the step as wd * w[i] rather than
//         folded into the gradient, so the moments m and v never see it
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) { static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps, float wd) {
float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t); float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
for (size_t i=0; i<s->n; i++) { for (size_t i=0; i<s->n; i++) {
s->m[i] = b1*s->m[i] + (1-b1)*g[i]; s->m[i] = b1*s->m[i] + (1-b1)*g[i];
s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i]; s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i];
float mh = s->m[i]/bc1, vh = s->v[i]/bc2; float mh = s->m[i]/bc1, vh = s->v[i]/bc2;
w[i] -= lr * mh / (sqrtf(vh) + eps); w[i] -= lr * (mh / (sqrtf(vh) + eps) + wd * w[i]);
} }
} }

View File

@ -74,17 +74,17 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int
IOSurfaceUnlock(s, 0, NULL); IOSurfaceUnlock(s, 0, NULL);
} }
// fp32 IOSurface I/O (for dynamic matmul kernels that use fp32 input/output) // fp16 IOSurface I/O (for dynamic matmul kernels with fp16 input/output)
// Layout: [1, IC, 1, SP] where SP = SEQ + OC // Layout: [1, IC, 1, SP] where SP = SEQ + OC
// Write activations at sp[0:SEQ] and weights at sp[SEQ:SEQ+OC] // Write activations at sp[0:SEQ] and weights at sp[SEQ:SEQ+OC]
// `act` is read as ic rows of seq floats and `W` as ic rows of oc floats. Row d of the
// surface (row stride sp = seq + oc elements) receives act row d followed by W row d.
// The old revision copied raw fp32 bytes with memcpy; the new revision converts each
// row to fp16 with cvt_f32_f16 and addresses the surface as _Float16.
// The surface is locked for the duration of the write and unlocked before returning.
static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq, static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq,
const float *W, int oc) { const float *W, int oc) {
int sp = seq + oc; int sp = seq + oc;
IOSurfaceLock(s, 0, NULL); IOSurfaceLock(s, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(s); _Float16 *buf = (_Float16*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < ic; d++) { for (int d = 0; d < ic; d++) {
memcpy(buf + d*sp, act + d*seq, seq*4); cvt_f32_f16(buf + d*sp, act + d*seq, seq);
memcpy(buf + d*sp + seq, W + d*oc, oc*4); cvt_f32_f16(buf + d*sp + seq, W + d*oc, oc);
} }
IOSurfaceUnlock(s, 0, NULL); IOSurfaceUnlock(s, 0, NULL);
} }
@ -92,7 +92,7 @@ static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq,
// Read output from dynamic matmul kernel: [1, OC, 1, SEQ] // Read output from dynamic matmul kernel: [1, OC, 1, SEQ]
// Copies oc*seq elements from the start of the surface into `out` (fp32).
// The old revision memcpy'd fp32 bytes; the new revision reads the surface as
// _Float16 and widens with cvt_f16_f32. The surface is held under a read-only lock.
static void io_read_dyn(IOSurfaceRef s, float *out, int oc, int seq) { static void io_read_dyn(IOSurfaceRef s, float *out, int oc, int seq) {
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL); IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
memcpy(out, (float*)IOSurfaceGetBaseAddress(s), oc * seq * 4); cvt_f16_f32(out, (_Float16*)IOSurfaceGetBaseAddress(s), oc * seq);
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL); IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
} }
@ -145,3 +145,201 @@ static void ane_eval(Kern *k) {
id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil; id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e); ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
} }
// Evaluate kernel k's model with a caller-supplied request instead of
// k->request (e.g. a per-layer request that binds a different input surface
// to the same model).
//   k       : kernel handle; only k->model is used here
//   request : opaque request pointer, bridged to an Objective-C object
// The evaluate call reports failure through its BOOL result and NSError.
// A failed evaluation is logged, because otherwise the caller would go on to
// read whatever the output surface already contained without any indication
// that the kernel did not run.
static void ane_eval_req(Kern *k, void *request) {
    id model = (__bridge id)k->model;
    id req = (__bridge id)request;
    NSError *err = nil;
    BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
        model, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &err);
    if (!ok) {
        NSLog(@"ane_eval_req: ANE evaluation failed: %@", err);
    }
}
// Build an ANE request that pairs a caller-chosen input surface with kernel
// k's existing output surface (k->ioOut).
// Both surfaces are wrapped via g_AIO objectWithIOSurface:, then handed to
// g_AR as single-element input/output arrays at index 0.
// Returns the request as a retained opaque pointer (CFBridgingRetain); the
// caller owns that reference and must balance it with CFRelease.
static void *make_request(Kern *k, IOSurfaceRef ioIn) {
    typedef id (*WrapSurfaceFn)(Class, SEL, IOSurfaceRef);
    typedef id (*BuildRequestFn)(Class, SEL, id, id, id, id, id, id, id);

    WrapSurfaceFn wrap = (WrapSurfaceFn)objc_msgSend;
    id inputObj = wrap(g_AIO, @selector(objectWithIOSurface:), ioIn);
    id outputObj = wrap(g_AIO, @selector(objectWithIOSurface:), k->ioOut);

    BuildRequestFn build = (BuildRequestFn)objc_msgSend;
    id request = build(g_AR,
        @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
        @[inputObj], @[@0], @[outputObj], @[@0], nil, nil, @0);

    return (void*)CFBridgingRetain(request);
}
// ===== Per-layer weight staging (written once, reused across steps) =====
// Every surface is fp16; staging converts the fp32 weights while copying.
// sdpaFwd input: [1, DIM, 1, SEQ+4*DIM] fp16. In each of the DIM rows,
// columns [0, SEQ) are left untouched here (activations) and the four weight
// strips Wq, Wk, Wv, Wo follow, DIM columns each. Every weight pointer is
// read as DIM rows of DIM floats.
static void stage_sdpa_fwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk,
                                   const float *Wv, const float *Wo) {
    const int stride = SEQ + 4*DIM;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *wdst = base + row*stride + SEQ;   // first weight column of this row
        const int src = row*DIM;
        cvt_f32_f16(wdst,         Wq + src, DIM);
        cvt_f32_f16(wdst + DIM,   Wk + src, DIM);
        cvt_f32_f16(wdst + 2*DIM, Wv + src, DIM);
        cvt_f32_f16(wdst + 3*DIM, Wo + src, DIM);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the activation strip of an sdpaFwd input surface: xnorm (DIM rows of
// SEQ floats) is converted to fp16 into columns [0, SEQ) of each row. The
// weight columns after SEQ are not touched.
static void write_sdpa_fwd_acts(IOSurfaceRef s, const float *xnorm) {
    const int stride = SEQ + 4*DIM;   // must match the sdpaFwd surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *dst = base + row*stride;
        cvt_f32_f16(dst, xnorm + row*SEQ, SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// ffnFused input: [1, DIM, 1, 2*SEQ+3*HIDDEN] fp16.
// Stages the three weight strips after the two activation strips: W1t, W3t
// and W2_orig, HIDDEN columns each, starting at column 2*SEQ. Each weight
// pointer is read as DIM rows of HIDDEN floats.
static void stage_ffn_fused_weights(IOSurfaceRef s,
                                    const float *W1t, const float *W3t, const float *W2_orig) {
    const int stride = 2*SEQ + 3*HIDDEN;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *wdst = base + row*stride + 2*SEQ;   // first weight column
        const int src = row*HIDDEN;
        cvt_f32_f16(wdst,            W1t + src,     HIDDEN);
        cvt_f32_f16(wdst + HIDDEN,   W3t + src,     HIDDEN);
        cvt_f32_f16(wdst + 2*HIDDEN, W2_orig + src, HIDDEN);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the two activation strips of an ffnFused input surface: x2norm into
// columns [0, SEQ) and x2 into columns [SEQ, 2*SEQ) of each of the DIM rows.
// Both inputs are read as DIM rows of SEQ floats and converted to fp16.
static void write_ffn_fused_acts(IOSurfaceRef s, const float *x2norm, const float *x2) {
    const int stride = 2*SEQ + 3*HIDDEN;   // must match the ffnFused surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *dst = base + row*stride;
        const int src = row*SEQ;
        cvt_f32_f16(dst,       x2norm + src, SEQ);
        cvt_f32_f16(dst + SEQ, x2 + src,     SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// ffnW13 input: [1, DIM, 1, SEQ+2*HIDDEN] fp16.
// Stages W1 then W3 (each read as DIM rows of HIDDEN floats) into the columns
// following the SEQ activation columns of every row.
static void stage_ffn_w13_weights(IOSurfaceRef s, const float *W1, const float *W3) {
    const int stride = SEQ + 2*HIDDEN;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *wdst = base + row*stride + SEQ;
        const int src = row*HIDDEN;
        cvt_f32_f16(wdst,          W1 + src, HIDDEN);
        cvt_f32_f16(wdst + HIDDEN, W3 + src, HIDDEN);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the activation strip of an ffnW13 input surface: xnorm (DIM rows of
// SEQ floats) goes, as fp16, into columns [0, SEQ) of each row.
static void write_ffn_w13_acts(IOSurfaceRef s, const float *xnorm) {
    const int stride = SEQ + 2*HIDDEN;   // must match the ffnW13 surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *dst = base + row*stride;
        cvt_f32_f16(dst, xnorm + row*SEQ, SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// ffnW2 input: [1, HIDDEN, 1, SEQ+DIM] fp16.
// Stages W2 (read as HIDDEN rows of DIM floats) into the DIM columns that
// follow the SEQ activation columns of every row.
static void stage_ffn_w2_weights(IOSurfaceRef s, const float *W2) {
    const int stride = SEQ + DIM;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < HIDDEN; row++) {
        _Float16 *wdst = base + row*stride + SEQ;
        cvt_f32_f16(wdst, W2 + row*DIM, DIM);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the activation strip of an ffnW2 input surface: gate (HIDDEN rows of
// SEQ floats) goes, as fp16, into columns [0, SEQ) of each row.
static void write_ffn_w2_acts(IOSurfaceRef s, const float *gate) {
    const int stride = SEQ + DIM;   // must match the ffnW2 surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < HIDDEN; row++) {
        _Float16 *dst = base + row*stride;
        cvt_f32_f16(dst, gate + row*SEQ, SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// ffnBwdW2t input: [1, DIM, 1, SEQ+HIDDEN] fp16.
// Stages the weight strip after the SEQ activation columns. The buffer passed
// as W2 is read as DIM rows of HIDDEN floats, so the caller must supply it in
// that orientation.
static void stage_ffn_bwd_w2t_weights(IOSurfaceRef s, const float *W2) {
    const int stride = SEQ + HIDDEN;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *wdst = base + row*stride + SEQ;
        cvt_f32_f16(wdst, W2 + row*HIDDEN, HIDDEN);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the activation strip of an ffnBwdW2t input surface: dffn (DIM rows of
// SEQ floats) goes, as fp16, into columns [0, SEQ) of each row.
static void write_ffn_bwd_w2t_acts(IOSurfaceRef s, const float *dffn) {
    const int stride = SEQ + HIDDEN;   // must match the ffnBwdW2t surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *dst = base + row*stride;
        cvt_f32_f16(dst, dffn + row*SEQ, SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// ffnBwdW13t input: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16.
// Stages two weight strips after the two SEQ-wide activation strips. The
// buffers passed as W1 and W3 are each read as HIDDEN rows of DIM floats, so
// the caller must supply them in that orientation.
static void stage_ffn_bwd_w13t_weights(IOSurfaceRef s, const float *W1, const float *W3) {
    const int stride = 2*SEQ + 2*DIM;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < HIDDEN; row++) {
        _Float16 *wdst = base + row*stride + 2*SEQ;
        const int src = row*DIM;
        cvt_f32_f16(wdst,       W1 + src, DIM);
        cvt_f32_f16(wdst + DIM, W3 + src, DIM);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the two activation strips of an ffnBwdW13t input surface: dh1 into
// columns [0, SEQ) and dh3 into columns [SEQ, 2*SEQ) of each of the HIDDEN
// rows. Both inputs are read as HIDDEN rows of SEQ floats, converted to fp16.
static void write_ffn_bwd_w13t_acts(IOSurfaceRef s, const float *dh1, const float *dh3) {
    const int stride = 2*SEQ + 2*DIM;   // must match the ffnBwdW13t surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < HIDDEN; row++) {
        _Float16 *dst = base + row*stride;
        const int src = row*SEQ;
        cvt_f32_f16(dst,       dh1 + src, SEQ);
        cvt_f32_f16(dst + SEQ, dh3 + src, SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// wotBwd input: [1, DIM, 1, SEQ+DIM] fp16.
// Stages the weight strip after the SEQ activation columns. The buffer passed
// as Wo is read as DIM rows of DIM floats and copied row-for-row; any
// transposition must already have been done by the caller.
static void stage_wot_bwd_weights(IOSurfaceRef s, const float *Wo) {
    const int stride = SEQ + DIM;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *wdst = base + row*stride + SEQ;
        cvt_f32_f16(wdst, Wo + row*DIM, DIM);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the activation strip of a wotBwd input surface: dx2 (DIM rows of SEQ
// floats) goes, as fp16, into columns [0, SEQ) of each row.
static void write_wot_bwd_acts(IOSurfaceRef s, const float *dx2) {
    const int stride = SEQ + DIM;   // must match the wotBwd surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *dst = base + row*stride;
        cvt_f32_f16(dst, dx2 + row*SEQ, SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// qkvBwd input: [1, DIM, 1, 3*SEQ+3*DIM] fp16.
// Stages three weight strips (Wq, Wk, Wv; DIM columns each) after the three
// SEQ-wide activation strips. Each buffer is read as DIM rows of DIM floats
// and copied row-for-row; any transposition must already have been done by
// the caller.
static void stage_qkv_bwd_weights(IOSurfaceRef s, const float *Wq, const float *Wk, const float *Wv) {
    const int stride = 3*SEQ + 3*DIM;   // fp16 elements per surface row
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *wdst = base + row*stride + 3*SEQ;
        const int src = row*DIM;
        cvt_f32_f16(wdst,         Wq + src, DIM);
        cvt_f32_f16(wdst + DIM,   Wk + src, DIM);
        cvt_f32_f16(wdst + 2*DIM, Wv + src, DIM);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Write the three activation strips of a qkvBwd input surface: dq, dk and dv
// into columns [0, SEQ), [SEQ, 2*SEQ) and [2*SEQ, 3*SEQ) of each of the DIM
// rows. Each input is read as DIM rows of SEQ floats and converted to fp16.
static void write_qkv_bwd_acts(IOSurfaceRef s, const float *dq, const float *dk, const float *dv) {
    const int stride = 3*SEQ + 3*DIM;   // must match the qkvBwd surface row width
    IOSurfaceLock(s, 0, NULL);
    _Float16 *base = (_Float16*)IOSurfaceGetBaseAddress(s);
    for (int row = 0; row < DIM; row++) {
        _Float16 *dst = base + row*stride;
        const int src = row*SEQ;
        cvt_f32_f16(dst,         dq + src, SEQ);
        cvt_f32_f16(dst + SEQ,   dk + src, SEQ);
        cvt_f32_f16(dst + 2*SEQ, dv + src, SEQ);
    }
    IOSurfaceUnlock(s, 0, NULL);
}
// Release every per-layer input surface and request for all NLAYERS layers.
//   pls : array of NLAYERS PerLayerSurfaces, or NULL to skip surfaces
//   plr : array of NLAYERS PerLayerRequests, or NULL to skip requests
// CFRelease must not be called with NULL, so each handle is checked first;
// this keeps teardown safe when setup failed part-way and left some handles
// unset. Each handle is cleared after release so a repeated call is harmless.
// Within a layer, surfaces are released before requests, as before.
static void free_per_layer(PerLayerSurfaces *pls, PerLayerRequests *plr) {
    for (int L = 0; L < NLAYERS; L++) {
        if (pls) {
            IOSurfaceRef *surfaces[] = {
                &pls[L].sdpaFwd_in,   &pls[L].ffnFused_in,
                &pls[L].ffnBwdW2t_in, &pls[L].ffnBwdW13t_in,
                &pls[L].wotBwd_in,    &pls[L].qkvBwd_in,
            };
            for (size_t i = 0; i < sizeof surfaces / sizeof surfaces[0]; i++) {
                if (*surfaces[i]) {
                    CFRelease(*surfaces[i]);
                    *surfaces[i] = NULL;
                }
            }
        }
        if (plr) {
            void **requests[] = {
                &plr[L].sdpaFwd,   &plr[L].ffnFused,
                &plr[L].ffnBwdW2t, &plr[L].ffnBwdW13t,
                &plr[L].wotBwd,    &plr[L].qkvBwd,
            };
            for (size_t i = 0; i < sizeof requests / sizeof requests[0]; i++) {
                if (*requests[i]) {
                    CFRelease(*requests[i]);
                    *requests[i] = NULL;
                }
            }
        }
    }
}

View File

@ -45,25 +45,21 @@ static void gen_dyn_matmul(NSMutableString *m, const char *prefix,
} }
// ===== Dynamic matmul kernel: y = x @ W ===== // ===== Dynamic matmul kernel: y = x @ W =====
// Input: [1, IC, 1, SEQ+OC] fp32 — act[0:SEQ] + W[SEQ:SEQ+OC] // Input: [1, IC, 1, SEQ+OC] fp16 — act[0:SEQ] + W[SEQ:SEQ+OC]
// Output: [1, OC, 1, SEQ] fp32 // Output: [1, OC, 1, SEQ] fp16
// Returns the MIL program text: MIL_HDR, a `main` taking one tensor of width sp = seq + oc,
// the body emitted by gen_dyn_matmul (prefix "mm"), and the function's return statement.
// Old revision: fp32 in/out, wrapped in cast ops named "cin" (to fp16) and "cout" (to fp32).
// New revision: main takes fp16 and returns mm_y directly, with no cast ops, so the host
// must write and read the surfaces as fp16.
static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) { static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) {
NSMutableString *m = [NSMutableString string]; NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR]; [m appendString:MIL_HDR];
int sp = seq + oc; int sp = seq + oc;
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, sp]; [m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", ic, sp];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"]; gen_dyn_matmul(m, "mm", ic, oc, seq, 0, seq, "x");
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", ic, sp]; [m appendString:@" } -> (mm_y);\n}\n"];
gen_dyn_matmul(m, "mm", ic, oc, seq, 0, seq, "xh");
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=mm_y)[name=string(\"cout\")];\n", oc, seq];
[m appendString:@" } -> (y);\n}\n"];
return m; return m;
} }
// ===== SDPA forward (dynamic weights) ===== // ===== SDPA forward (dynamic weights) =====
// Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul // Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul
// Input: [1, DIM, 1, SEQ + 4*DIM] fp32 // Input: [1, DIM, 1, SEQ + 4*DIM] fp16
// sp[0:SEQ] = xnorm (rmsnorm output, DIM channels) // sp[0:SEQ] = xnorm (rmsnorm output, DIM channels)
// sp[SEQ:SEQ+DIM] = Wq[DIM,DIM] // sp[SEQ:SEQ+DIM] = Wq[DIM,DIM]
// sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM] // sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM]
@ -77,32 +73,29 @@ static NSString *gen_sdpa_fwd_dynamic(void) {
int sp_in = SEQ + w_total; int sp_in = SEQ + w_total;
NSMutableString *m = [NSMutableString string]; NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR]; [m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in]; [m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
// Cast to fp16
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice xnorm [1,DIM,1,SEQ] // Slice xnorm [1,DIM,1,SEQ]
[m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"]; [m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=x,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
// Slice Wq [1,DIM,1,DIM] // Slice Wq [1,DIM,1,DIM]
[m appendFormat:@" tensor<int32, [4]> bq = const()[name=string(\"bq\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ]; [m appendFormat:@" tensor<int32, [4]> bq = const()[name=string(\"bq\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM]; [m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=xh,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=x,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM];
// Slice Wk // Slice Wk
[m appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+DIM]; [m appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=xh,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=x,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM];
// Slice Wv // Slice Wv
[m appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+2*DIM]; [m appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=xh,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=x,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM];
// Slice Wo // Slice Wo
[m appendFormat:@" tensor<int32, [4]> bo = const()[name=string(\"bo\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+3*DIM]; [m appendFormat:@" tensor<int32, [4]> bo = const()[name=string(\"bo\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wo = slice_by_size(x=xh,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wo = slice_by_size(x=x,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM];
// Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D] // Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D]
[m appendFormat:@" tensor<int32, [4]> r2 = const()[name=string(\"r2\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor<int32, [4]> r2 = const()[name=string(\"r2\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
@ -173,10 +166,7 @@ static NSString *gen_sdpa_fwd_dynamic(void) {
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ];
// Cast to fp32 [m appendString:@" } -> (out);\n}\n"];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 6*DIM, SEQ];
[m appendString:@" } -> (out32);\n}\n"];
return m; return m;
} }
@ -257,6 +247,101 @@ static NSString *gen_ffn_w13_dynamic(void) {
return m; return m;
} }
// ===== Fused FFN forward: W1,W3 + SiLU + W2 + residual =====
// RMSNorm stays on CPU (ANE can't handle RMS + 3 matmuls without BNNS fallback)
// Replaces: ffnW13 + CPU gate read + ffnW2 + CPU residual
// Input: [1, DIM, 1, 2*SEQ + 3*HIDDEN] fp16
// sp[0:SEQ] = x2norm (RMSNorm output, from CPU)
// sp[SEQ:2*SEQ] = x2 (residual, for x_next = x2 + ffn_out)
// sp[2*SEQ : 2*SEQ+HIDDEN] = W1t[DIM,HIDDEN]
// sp[2*SEQ+HIDDEN : 2*SEQ+2*HIDDEN] = W3t[DIM,HIDDEN]
// sp[2*SEQ+2*HIDDEN : 2*SEQ+3*HIDDEN] = W2_orig[DIM,HIDDEN] (transposed inside kernel)
// Output: [1, DIM + 3*HIDDEN, 1, SEQ] fp16
// = concat(x_next[DIM], h1[HIDDEN], h3[HIDDEN], gate[HIDDEN]) where gate = silu(h1) * h3
// Returns the generated MIL program text; nothing is compiled or evaluated here.
static NSString *gen_ffn_fused_dynamic(void) {
int sp_in = 2*SEQ + 3*HIDDEN;
int out_ch = DIM + 3*HIDDEN;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
// Shared constants: pm swaps the last two axes in every transpose below;
// bF (false) is passed as transpose_x / transpose_y to every matmul.
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Slice x2norm [DIM, SEQ] — RMSNorm output (computed on CPU)
[m appendString:@" tensor<int32, [4]> b_xn = const()[name=string(\"b_xn\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> s_ds = const()[name=string(\"s_ds\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2norm = slice_by_size(x=x,begin=b_xn,size=s_ds)[name=string(\"x2norm\")];\n", DIM, SEQ];
// Slice x2 [DIM, SEQ] — for residual: x_next = x2 + ffn_out
[m appendFormat:@" tensor<int32, [4]> b_x2 = const()[name=string(\"b_x2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x2 = slice_by_size(x=x,begin=b_x2,size=s_ds)[name=string(\"x2\")];\n", DIM, SEQ];
// Slice W1 [DIM, HIDDEN]
[m appendFormat:@" tensor<int32, [4]> b_w1 = const()[name=string(\"b_w1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> s_wh = const()[name=string(\"s_wh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=x,begin=b_w1,size=s_wh)[name=string(\"W1\")];\n", DIM, HIDDEN];
// Slice W3 [DIM, HIDDEN]
[m appendFormat:@" tensor<int32, [4]> b_w3 = const()[name=string(\"b_w3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=x,begin=b_w3,size=s_wh)[name=string(\"W3\")];\n", DIM, HIDDEN];
// Slice W2_orig [DIM, HIDDEN] (transposed inside kernel)
[m appendFormat:@" tensor<int32, [4]> b_w2 = const()[name=string(\"b_w2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+2*HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2r = slice_by_size(x=x,begin=b_w2,size=s_wh)[name=string(\"W2r\")];\n", DIM, HIDDEN];
// Reshape for matmul: x2norm [1,DIM,1,SEQ] → [1,1,DIM,SEQ] → [1,1,SEQ,DIM]
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=x2norm)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
// Reshape weights: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN]
// (the shape constant rw is reused further down for W2r)
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN];
// h1 = x2norm_t @ W1, h3 = x2norm_t @ W3 [SEQ,DIM] @ [DIM,HIDDEN] → [SEQ,HIDDEN]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN];
// Reshape back: [1,1,SEQ,HIDDEN] → [1,1,HIDDEN,SEQ] → [1,HIDDEN,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> rh = const()[name=string(\"rh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = reshape(shape=rh,x=h1t)[name=string(\"h1\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = reshape(shape=rh,x=h3t)[name=string(\"h3\")];\n", HIDDEN, SEQ];
// SiLU + gate: gate = silu(h1) * h3
// (silu is built from sigmoid and mul ops rather than a single silu op)
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];
// gate @ W2: reshape gate [1,HIDDEN,1,SEQ] → [1,1,HIDDEN,SEQ] → [1,1,SEQ,HIDDEN]
// Note: the MIL variable `gt` below (op name "gtt") is the transposed gate; it is a
// different thing from the op *name* "gt" attached to the gate multiply above.
[m appendFormat:@" tensor<int32, [4]> rg = const()[name=string(\"rg\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> g2 = reshape(shape=rg,x=gate)[name=string(\"g2\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> gt = transpose(perm=pm,x=g2)[name=string(\"gtt\")];\n", SEQ, HIDDEN];
// W2: [1,DIM,1,HIDDEN] → [1,1,DIM,HIDDEN] → transpose → [1,1,HIDDEN,DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2r)[name=string(\"W22\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W2t = transpose(perm=pm,x=W22)[name=string(\"W2t\")];\n", HIDDEN, DIM];
// matmul: [1,1,SEQ,HIDDEN] @ [1,1,HIDDEN,DIM] → [1,1,SEQ,DIM]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> fm = matmul(transpose_x=bF,transpose_y=bF,x=gt,y=W2t)[name=string(\"fm\")];\n", SEQ, DIM];
// Reshape: [1,1,SEQ,DIM] → [1,1,DIM,SEQ] → [1,DIM,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ft = transpose(perm=pm,x=fm)[name=string(\"ft\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> rd2 = const()[name=string(\"rd2\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> ffn_out = reshape(shape=rd2,x=ft)[name=string(\"ffn_out\")];\n", DIM, SEQ];
// Residual: x_next = x2 + ffn_out
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x_next = add(x=x2,y=ffn_out)[name=string(\"x_next\")];\n", DIM, SEQ];
// Output: concat(x_next, h1, h3, gate) — gate=silu*h3 needed for dW2
// Concatenation is along axis 1, giving out_ch = DIM + 3*HIDDEN channels.
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(x_next,h1,h3,gate))[name=string(\"cat\")];\n", out_ch, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// FFN part 2: gate @ W2 (HIDDEN→DIM) // FFN part 2: gate @ W2 (HIDDEN→DIM)
// Input: [1, HIDDEN, 1, SEQ + DIM] fp32 // Input: [1, HIDDEN, 1, SEQ + DIM] fp32
// sp[0:SEQ] = gate [HIDDEN,SEQ] // sp[0:SEQ] = gate [HIDDEN,SEQ]
@ -305,7 +390,7 @@ static NSString *gen_ffn_w2_dynamic(void) {
// Input: [1, DIM, 1, SEQ + HIDDEN] fp32 // Input: [1, DIM, 1, SEQ + HIDDEN] fp32
// sp[0:SEQ] = dffn [DIM, SEQ] // sp[0:SEQ] = dffn [DIM, SEQ]
// sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN] // sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN]
// Output: [1, HIDDEN, 1, SEQ] fp32 = dsilu_raw // Output: [1, HIDDEN, 1, SEQ] fp16 = dsilu_raw
// Thin wrapper: this kernel is the generic dynamic matmul program from
// gen_dyn_matmul_mil with ic = DIM, oc = HIDDEN, seq = SEQ.
static NSString *gen_ffn_bwd_w2t_dynamic(void) { static NSString *gen_ffn_bwd_w2t_dynamic(void) {
return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ); return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ);
} }
@ -320,32 +405,30 @@ static NSString *gen_ffn_bwd_w2t_dynamic(void) {
// sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ] // sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ]
// sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM] // sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM]
// sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM] // sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dx1 + dx3 // Output: [1, DIM, 1, SEQ] fp16 = dx1 + dx3
static NSString *gen_ffn_bwd_w13t_dynamic(void) { static NSString *gen_ffn_bwd_w13t_dynamic(void) {
int sp_in = 2*SEQ + 2*DIM; int sp_in = 2*SEQ + 2*DIM;
NSMutableString *m = [NSMutableString string]; NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR]; [m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in]; [m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in];
// Slice dh1 [HIDDEN, SEQ] // Slice dh1 [HIDDEN, SEQ]
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"]; [m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = slice_by_size(x=xh,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = slice_by_size(x=x,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
// Slice dh3 // Slice dh3
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ]; [m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = slice_by_size(x=xh,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = slice_by_size(x=x,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];
// Slice W1^T [HIDDEN, DIM] // Slice W1^T [HIDDEN, DIM]
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ]; [m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM]; [m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1t = slice_by_size(x=xh,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1t = slice_by_size(x=x,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM];
// Slice W3^T // Slice W3^T
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+DIM]; [m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3t = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3t = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"]; [m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
@ -370,9 +453,7 @@ static NSString *gen_ffn_bwd_w13t_dynamic(void) {
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; [m appendString:@" } -> (dx);\n}\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m; return m;
} }
@ -516,27 +597,25 @@ static NSString *gen_qkvb_dynamic(void) {
int sp_in = 3*SEQ + 3*DIM; int sp_in = 3*SEQ + 3*DIM;
NSMutableString *m = [NSMutableString string]; NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR]; [m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in]; [m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice dq, dk, dv // Slice dq, dk, dv
[m appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"]; [m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=xh,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=x,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ]; [m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=xh,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=x,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ]; [m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=xh,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=x,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ];
// Slice Wq^T, Wk^T, Wv^T // Slice Wq^T, Wk^T, Wv^T
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM]; [m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ]; [m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wqt = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wqt = slice_by_size(x=x,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b4 = const()[name=string(\"b4\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+DIM]; [m appendFormat:@" tensor<int32, [4]> b4 = const()[name=string(\"b4\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=xh,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=x,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b5 = const()[name=string(\"b5\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+2*DIM]; [m appendFormat:@" tensor<int32, [4]> b5 = const()[name=string(\"b5\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=xh,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=x,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"]; [m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
@ -570,9 +649,7 @@ static NSString *gen_qkvb_dynamic(void) {
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ]; [m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"]; [m appendString:@" } -> (dx);\n}\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m; return m;
} }

View File

@ -11,8 +11,7 @@
// Dynamic kernel set per layer // Dynamic kernel set per layer
typedef struct { typedef struct {
Kern *sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface) Kern *sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface)
Kern *ffnW13; // W1,W3 matmul (dynamic) Kern *ffnFused; // residual + RMSNorm + W1,W3 + SiLU + W2 + residual (fused)
Kern *ffnW2; // W2 matmul (dynamic)
Kern *ffnBwdW2t; // dffn @ W2^T (dynamic) Kern *ffnBwdW2t; // dffn @ W2^T (dynamic)
Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic) Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic)
Kern *wotBwd; // dx2 @ Wo^T (dynamic) Kern *wotBwd; // dx2 @ Wo^T (dynamic)
@ -60,40 +59,36 @@ static void transpose_weight(float *dst, const float *src, int rows, int cols) {
static bool compile_dynamic_kernels(DynLayerKernels *dk) { static bool compile_dynamic_kernels(DynLayerKernels *dk) {
NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}}; NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}};
// SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp32 [1, 6*DIM, 1, SEQ] fp32 // SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp16 [1, 6*DIM, 1, SEQ] fp16
printf(" Compiling sdpaFwd...\n"); printf(" Compiling sdpaFwd...\n");
dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w, dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w,
DIM*(SEQ+4*DIM)*4, 6*DIM*SEQ*4); DIM*(SEQ+4*DIM)*2, 6*DIM*SEQ*2);
if (!dk->sdpaFwd) return false; if (!dk->sdpaFwd) return false;
// FFN W1+W3: [1, DIM, 1, SEQ+2*HIDDEN] fp32 [1, 3*HIDDEN, 1, SEQ] fp32 // Fused FFN: W1,W3 + SiLU + W2 + residual (RMSNorm on CPU)
printf(" Compiling ffnW13...\n"); printf(" Compiling ffnFused...\n");
dk->ffnW13 = compile_kern_mil_w(gen_ffn_w13_dynamic(), @{}, int ffn_fused_sp = 2*SEQ + 3*HIDDEN;
DIM*(SEQ+2*HIDDEN)*4, 3*HIDDEN*SEQ*4); int ffn_fused_och = DIM + 3*HIDDEN;
if (!dk->ffnW13) return false; dk->ffnFused = compile_kern_mil_w(gen_ffn_fused_dynamic(), @{},
DIM*ffn_fused_sp*2, ffn_fused_och*SEQ*2);
if (!dk->ffnFused) return false;
// FFN W2: [1, HIDDEN, 1, SEQ+DIM] fp32 [1, DIM, 1, SEQ] fp32 // FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp16 [1, HIDDEN, 1, SEQ] fp16
printf(" Compiling ffnW2...\n");
dk->ffnW2 = compile_kern_mil_w(gen_ffn_w2_dynamic(), @{},
HIDDEN*(SEQ+DIM)*4, DIM*SEQ*4);
if (!dk->ffnW2) return false;
// FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp32 [1, HIDDEN, 1, SEQ] fp32
printf(" Compiling ffnBwdW2t...\n"); printf(" Compiling ffnBwdW2t...\n");
dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{}, dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{},
DIM*(SEQ+HIDDEN)*4, HIDDEN*SEQ*4); DIM*(SEQ+HIDDEN)*2, HIDDEN*SEQ*2);
if (!dk->ffnBwdW2t) return false; if (!dk->ffnBwdW2t) return false;
// FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp32 [1, DIM, 1, SEQ] fp32 // FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp16 [1, DIM, 1, SEQ] fp16
printf(" Compiling ffnBwdW13t...\n"); printf(" Compiling ffnBwdW13t...\n");
dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{}, dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{},
HIDDEN*(2*SEQ+2*DIM)*4, DIM*SEQ*4); HIDDEN*(2*SEQ+2*DIM)*2, DIM*SEQ*2);
if (!dk->ffnBwdW13t) return false; if (!dk->ffnBwdW13t) return false;
// Wo^T backward: [1, DIM, 1, SEQ+DIM] fp32 [1, DIM, 1, SEQ] fp32 // Wo^T backward: [1, DIM, 1, SEQ+DIM] fp16 [1, DIM, 1, SEQ] fp16
printf(" Compiling wotBwd...\n"); printf(" Compiling wotBwd...\n");
dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{}, dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{},
DIM*(SEQ+DIM)*4, DIM*SEQ*4); DIM*(SEQ+DIM)*2, DIM*SEQ*2);
if (!dk->wotBwd) return false; if (!dk->wotBwd) return false;
// SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 [1, DIM+2*SCORE_CH, 1, SEQ] fp16 // SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 [1, DIM+2*SCORE_CH, 1, SEQ] fp16
@ -108,10 +103,10 @@ static bool compile_dynamic_kernels(DynLayerKernels *dk) {
(2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2); (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
if (!dk->sdpaBwd2) return false; if (!dk->sdpaBwd2) return false;
// QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp32 [1, DIM, 1, SEQ] fp32 // QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp16 [1, DIM, 1, SEQ] fp16
printf(" Compiling qkvBwd...\n"); printf(" Compiling qkvBwd...\n");
dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{}, dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{},
DIM*(3*SEQ+3*DIM)*4, DIM*SEQ*4); DIM*(3*SEQ+3*DIM)*2, DIM*SEQ*2);
if (!dk->qkvBwd) return false; if (!dk->qkvBwd) return false;
return true; return true;
@ -134,32 +129,6 @@ static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm,
IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL); IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL);
} }
// Fill the ffnW13 kernel's input surface, laid out as [1, DIM, 1, SEQ+2*HIDDEN].
// Each of the DIM rows is packed as:
//   [0 : SEQ)                      activations (xnorm row)
//   [SEQ : SEQ+HIDDEN)             W1 row
//   [SEQ+HIDDEN : SEQ+2*HIDDEN)    W3 row
// All values are copied as 4-byte floats; the surface is locked for the
// duration of the copy.
static void write_ffn_w13_input(DynLayerKernels *dk, const float *xnorm,
                                const float *W1, const float *W3) {
    const int row_len = SEQ + 2*HIDDEN;   // floats per row of the surface
    IOSurfaceRef surf = dk->ffnW13->ioIn;

    IOSurfaceLock(surf, 0, NULL);
    float *row = (float*)IOSurfaceGetBaseAddress(surf);
    for (int ch = 0; ch < DIM; ch++, row += row_len) {
        memcpy(row,                xnorm + ch*SEQ,  SEQ    * sizeof *row);
        memcpy(row + SEQ,          W1 + ch*HIDDEN,  HIDDEN * sizeof *row);
        memcpy(row + SEQ + HIDDEN, W3 + ch*HIDDEN,  HIDDEN * sizeof *row);
    }
    IOSurfaceUnlock(surf, 0, NULL);
}
// Fill the ffnW2 kernel's input surface, laid out as [1, HIDDEN, 1, SEQ+DIM].
// Each of the HIDDEN rows is packed as:
//   [0 : SEQ)          gate activations for that row
//   [SEQ : SEQ+DIM)    W2 row
// All values are copied as 4-byte floats; the surface is locked for the
// duration of the copy.
static void write_ffn_w2_input(DynLayerKernels *dk, const float *gate, const float *W2) {
    const int row_len = SEQ + DIM;        // floats per row of the surface
    IOSurfaceRef surf = dk->ffnW2->ioIn;

    IOSurfaceLock(surf, 0, NULL);
    float *row = (float*)IOSurfaceGetBaseAddress(surf);
    for (int ch = 0; ch < HIDDEN; ch++, row += row_len) {
        memcpy(row,       gate + ch*SEQ, SEQ * sizeof *row);
        memcpy(row + SEQ, W2 + ch*DIM,   DIM * sizeof *row);
    }
    IOSurfaceUnlock(surf, 0, NULL);
}
// ===== Checkpoint ===== // ===== Checkpoint =====
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss, static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
double ct, double cw, int cs, int adam_t, double ct, double cw, int cs, int adam_t,
@ -238,11 +207,13 @@ int main(int argc, char *argv[]) {
int total_steps = 10000; int total_steps = 10000;
float max_lr = 3e-4f; float max_lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; float adam_b1=0.9f, adam_b2=0.95f, adam_eps=1e-8f, wd=0.1f;
int adam_t = 0, start_step = 0; int adam_t = 0, start_step = 0;
int accum_steps = 10; int accum_steps = 10;
int warmup_steps = 100; int warmup_steps = 100;
float grad_clip = 1.0f; float grad_clip = 1.0f;
float loss_scale = 256.0f; // fp16 loss scaling for ANE backward
float act_clip = 20.0f;
float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1 float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1
bool do_resume = false, from_scratch = false; bool do_resume = false, from_scratch = false;
@ -287,7 +258,7 @@ int main(int argc, char *argv[]) {
double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6; double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6;
double embed_m = (double)VOCAB*DIM / 1e6; double embed_m = (double)VOCAB*DIM / 1e6;
printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m); printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m);
printf("Kernels: 9 compiled, 9 weight-bearing\n"); printf("Kernels: 8 compiled (ffnFused replaces ffnW13+ffnW2, RMSNorm on CPU)\n");
printf("Accum %d steps, LR=%g\n", accum_steps, max_lr); printf("Accum %d steps, LR=%g\n", accum_steps, max_lr);
// FLOPs estimate: 6*N*B*T for transformer (forward+backward 3x forward) // FLOPs estimate: 6*N*B*T for transformer (forward+backward 3x forward)
double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ; double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ;
@ -353,14 +324,48 @@ int main(int argc, char *argv[]) {
AdamState acembed = adam_alloc((size_t)CV*DIM); AdamState acembed = adam_alloc((size_t)CV*DIM);
// ===== Compile all kernels ONCE ===== // ===== Compile all kernels ONCE =====
printf("Compiling %d dynamic kernels (one-time)...\n", 9); printf("Compiling %d dynamic kernels (one-time)...\n", 8);
uint64_t tc = mach_absolute_time(); uint64_t tc = mach_absolute_time();
DynLayerKernels dk; DynLayerKernels dk;
if (!compile_dynamic_kernels(&dk)) { if (!compile_dynamic_kernels(&dk)) {
printf("Compilation failed!\n"); return 1; printf("Compilation failed!\n"); return 1;
} }
double compile_ms = tb_ms(mach_absolute_time() - tc); double compile_ms = tb_ms(mach_absolute_time() - tc);
printf("Compiled 9 kernels in %.0fms (shared across all %d layers)\n\n", compile_ms, NLAYERS); printf("Compiled 9 kernels in %.0fms (shared across all %d layers)\n", compile_ms, NLAYERS);
// Allocate per-layer IOSurfaces + requests (pre-stage weights)
int per_layer_bytes = (DIM*(SEQ+4*DIM) + DIM*(2*SEQ+3*HIDDEN) +
DIM*(SEQ+HIDDEN) + HIDDEN*(2*SEQ+2*DIM) + DIM*(SEQ+DIM) + DIM*(3*SEQ+3*DIM)) * 2;
int total_surf_mb = (int)((long)per_layer_bytes * NLAYERS / (1024*1024));
printf("Allocating per-layer IOSurfaces (%d surfaces, ~%dMB fp16)...\n", NLAYERS*6, total_surf_mb);
PerLayerSurfaces pls[NLAYERS];
PerLayerRequests plr[NLAYERS];
for (int L = 0; L < NLAYERS; L++) {
pls[L].sdpaFwd_in = make_surface(DIM*(SEQ+4*DIM)*2);
pls[L].ffnFused_in = make_surface(DIM*(2*SEQ+3*HIDDEN)*2);
pls[L].ffnBwdW2t_in = make_surface(DIM*(SEQ+HIDDEN)*2);
pls[L].ffnBwdW13t_in= make_surface(HIDDEN*(2*SEQ+2*DIM)*2);
pls[L].wotBwd_in = make_surface(DIM*(SEQ+DIM)*2);
pls[L].qkvBwd_in = make_surface(DIM*(3*SEQ+3*DIM)*2);
plr[L].sdpaFwd = make_request(dk.sdpaFwd, pls[L].sdpaFwd_in);
plr[L].ffnFused = make_request(dk.ffnFused, pls[L].ffnFused_in);
plr[L].ffnBwdW2t = make_request(dk.ffnBwdW2t, pls[L].ffnBwdW2t_in);
plr[L].ffnBwdW13t= make_request(dk.ffnBwdW13t,pls[L].ffnBwdW13t_in);
plr[L].wotBwd = make_request(dk.wotBwd, pls[L].wotBwd_in);
plr[L].qkvBwd = make_request(dk.qkvBwd, pls[L].qkvBwd_in);
}
// Stage weights into per-layer surfaces
for (int L = 0; L < NLAYERS; L++) {
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2);
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, W2t_buf[L]);
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, W1t_buf[L], W3t_buf[L]);
stage_wot_bwd_weights(pls[L].wotBwd_in, Wot_buf[L]);
stage_qkv_bwd_weights(pls[L].qkvBwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]);
}
printf("Per-layer weight staging complete\n\n");
// Gradient + work buffers // Gradient + work buffers
float *dy = (float*)malloc(SEQ*DIM*4); float *dy = (float*)malloc(SEQ*DIM*4);
@ -428,70 +433,63 @@ int main(int argc, char *argv[]) {
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER); dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
t_cblas_wait += tb_ms(mach_absolute_time() - t0); t_cblas_wait += tb_ms(mach_absolute_time() - t0);
// SDPA forward (ANE): xnorm + Wq,Wk,Wv,Wo o_out,Q,K,V,attn_out,xnorm // SDPA forward (ANE): xnorm + pre-staged Wq,Wk,Wv,Wo o_out,Q,K,V,attn_out,xnorm
t0 = mach_absolute_time(); t0 = mach_absolute_time();
write_sdpa_fwd_input(&dk, xnorm_buf, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]); write_sdpa_fwd_acts(pls[L].sdpaFwd_in, xnorm_buf);
t_io_fwd += tb_ms(mach_absolute_time() - t0); t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
ane_eval(dk.sdpaFwd); ane_eval_req(dk.sdpaFwd, plr[L].sdpaFwd);
t_ane_fwd += tb_ms(mach_absolute_time() - t0); t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read output: [1, 6*DIM, 1, SEQ] fp32 // Read output: [1, 6*DIM, 1, SEQ] fp16
t0 = mach_absolute_time(); t0 = mach_absolute_time();
IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL); IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
float *fwd_out = (float*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut); _Float16 *fwd_out = (_Float16*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut);
memcpy(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ*4); cvt_f16_f32(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ);
memcpy(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ*4); cvt_f16_f32(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ);
memcpy(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ*4); cvt_f16_f32(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ);
memcpy(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ*4); cvt_f16_f32(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ);
memcpy(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ*4); cvt_f16_f32(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ);
IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL); IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0); t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x2 = x_cur + o_out // CPU: residual + RMSNorm (ANE can't fuse RMS with 3 matmuls)
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
// RMSNorm2 (CPU)
t0 = mach_absolute_time(); t0 = mach_absolute_time();
rmsnorm(xnorm_buf, ac->x2, lw[L].rms_ffn, DIM, SEQ); vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
memcpy(ac->x2norm, xnorm_buf, SEQ*DIM*4); rmsnorm(ac->x2norm, ac->x2, lw[L].rms_ffn, DIM, SEQ);
t_rms += tb_ms(mach_absolute_time() - t0); t_rms += tb_ms(mach_absolute_time() - t0);
// FFN W1+W3 (ANE): xnorm h1, h3, gate // Fused FFN (ANE): W1,W3 + SiLU + W2 + residual
// Input: x2norm + x2 (acts), W1t + W3t + W2 (pre-staged weights)
// Output: x_next, h1, h3, silu_out
t0 = mach_absolute_time(); t0 = mach_absolute_time();
write_ffn_w13_input(&dk, xnorm_buf, W1t_buf[L], W3t_buf[L]); write_ffn_fused_acts(pls[L].ffnFused_in, ac->x2norm, ac->x2);
t_io_fwd += tb_ms(mach_absolute_time() - t0); t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
ane_eval(dk.ffnW13); ane_eval_req(dk.ffnFused, plr[L].ffnFused);
t_ane_fwd += tb_ms(mach_absolute_time() - t0); t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read h1, h3, gate from output [1, 3*HIDDEN, 1, SEQ] // Read fused output: [1, DIM+3*HIDDEN, 1, SEQ] fp16
// Layout: x_next[DIM], h1[HIDDEN], h3[HIDDEN], silu_out[HIDDEN]
t0 = mach_absolute_time(); t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL); IOSurfaceLock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL);
float *ffn13_out = (float*)IOSurfaceGetBaseAddress(dk.ffnW13->ioOut); _Float16 *ffn_out = (_Float16*)IOSurfaceGetBaseAddress(dk.ffnFused->ioOut);
memcpy(ac->h1, ffn13_out, HIDDEN*SEQ*4); int off = 0;
memcpy(ac->h3, ffn13_out + HIDDEN*SEQ, HIDDEN*SEQ*4); cvt_f16_f32(x_cur, ffn_out + off, DIM*SEQ); off += DIM*SEQ;
memcpy(gate_buf, ffn13_out + 2*HIDDEN*SEQ, HIDDEN*SEQ*4); cvt_f16_f32(ac->h1, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ;
memcpy(ac->silu_out, gate_buf, HIDDEN*SEQ*4); cvt_f16_f32(ac->h3, ffn_out + off, HIDDEN*SEQ); off += HIDDEN*SEQ;
IOSurfaceUnlock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL); cvt_f16_f32(ac->silu_out,ffn_out + off, HIDDEN*SEQ);
IOSurfaceUnlock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0); t_io_fwd += tb_ms(mach_absolute_time() - t0);
// FFN W2 (ANE): gate @ W2 ffn_out // Scale down residual stream if max magnitude exceeds threshold
t0 = mach_absolute_time(); {
write_ffn_w2_input(&dk, gate_buf, W2t_buf[L]); float amx; vDSP_maxmgv(x_cur, 1, &amx, (vDSP_Length)(SEQ*DIM));
t_io_fwd += tb_ms(mach_absolute_time() - t0); if (amx > act_clip) {
t0 = mach_absolute_time(); float sc = act_clip / amx;
ane_eval(dk.ffnW2); vDSP_vsmul(x_cur, 1, &sc, x_cur, 1, (vDSP_Length)(SEQ*DIM));
t_ane_fwd += tb_ms(mach_absolute_time() - t0); }
}
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
memcpy(ac->ffn_out, (float*)IOSurfaceGetBaseAddress(dk.ffnW2->ioOut), DIM*SEQ*4);
IOSurfaceUnlock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x_cur = x2 + ffn_out
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
} }
// Final RMSNorm + classifier + loss (CPU) // Final RMSNorm + classifier + loss (CPU)
@ -507,6 +505,10 @@ int main(int argc, char *argv[]) {
last_loss = loss; last_loss = loss;
// ===== BACKWARD ===== // ===== BACKWARD =====
// Loss scaling: scale dlogits to prevent fp16 underflow in ANE backward kernels
// All gradients flow scaled; weight grads divided by loss_scale before Adam
vDSP_vsmul(dlogits, 1, &loss_scale, dlogits, 1, (vDSP_Length)(SEQ*CV));
// Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ] // Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ]
t0 = mach_absolute_time(); t0 = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
@ -531,12 +533,12 @@ int main(int argc, char *argv[]) {
LayerGrads *gr = &grads[L]; LayerGrads *gr = &grads[L];
memcpy(dffn, dy, SEQ*DIM*4); memcpy(dffn, dy, SEQ*DIM*4);
// FFN backward: dffn @ W2^T dsilu_raw // FFN backward: dffn @ pre-staged W2^T dsilu_raw
t0 = mach_absolute_time(); t0 = mach_absolute_time();
io_write_dyn(dk.ffnBwdW2t->ioIn, dffn, DIM, SEQ, lw[L].W2, HIDDEN); write_ffn_bwd_w2t_acts(pls[L].ffnBwdW2t_in, dffn);
t_io_bwd += tb_ms(mach_absolute_time() - t0); t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW2t); ane_eval_req(dk.ffnBwdW2t, plr[L].ffnBwdW2t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0); t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ); io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ);
@ -569,23 +571,12 @@ int main(int argc, char *argv[]) {
} }
t_silu += tb_ms(mach_absolute_time() - t0); t_silu += tb_ms(mach_absolute_time() - t0);
// dh1@W1^T + dh3@W3^T dx_ffn (ANE) // dh1@W1^T + dh3@W3^T dx_ffn (ANE, pre-staged weights)
t0 = mach_absolute_time(); t0 = mach_absolute_time();
{ write_ffn_bwd_w13t_acts(pls[L].ffnBwdW13t_in, dh1, dh3);
IOSurfaceLock(dk.ffnBwdW13t->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.ffnBwdW13t->ioIn);
int sp = 2*SEQ + 2*DIM;
for (int d = 0; d < HIDDEN; d++) {
memcpy(buf + d*sp, dh1 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, lw[L].W1 + d*DIM, DIM*4);
memcpy(buf + d*sp + 2*SEQ + DIM, lw[L].W3 + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.ffnBwdW13t->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0); t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW13t); ane_eval_req(dk.ffnBwdW13t, plr[L].ffnBwdW13t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0); t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ); io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ);
@ -616,12 +607,12 @@ int main(int argc, char *argv[]) {
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i]; for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
t_rms_bwd += tb_ms(mach_absolute_time() - t0); t_rms_bwd += tb_ms(mach_absolute_time() - t0);
// Wo^T backward (ANE): dx2 @ Wo^T da // Wo^T backward (ANE): dx2 @ pre-staged Wo^T da
t0 = mach_absolute_time(); t0 = mach_absolute_time();
io_write_dyn(dk.wotBwd->ioIn, dx2, DIM, SEQ, lw[L].Wo, DIM); write_wot_bwd_acts(pls[L].wotBwd_in, dx2);
t_io_bwd += tb_ms(mach_absolute_time() - t0); t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
ane_eval(dk.wotBwd); ane_eval_req(dk.wotBwd, plr[L].wotBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0); t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
float *da_buf = (float*)malloc(SEQ*DIM*4); float *da_buf = (float*)malloc(SEQ*DIM*4);
@ -639,6 +630,19 @@ int main(int argc, char *argv[]) {
free(capt_do); free(capt_attn); free(capt_do); free(capt_attn);
}); });
if (L == 0 && step % 10 == 0) {
float damx, dx2mx, dx2mean;
vDSP_maxmgv(da_buf, 1, &damx, (vDSP_Length)(SEQ*DIM));
vDSP_maxmgv(dx2, 1, &dx2mx, (vDSP_Length)(SEQ*DIM));
vDSP_meamgv(dx2, 1, &dx2mean, (vDSP_Length)(SEQ*DIM));
// Count how many dx2 values survive fp16 conversion
int nz = 0;
for (int i=0; i<SEQ*DIM && i<1000; i++) {
_Float16 h = (_Float16)dx2[i];
if (h != 0) nz++;
}
printf(" L0 wot_bwd: |da|=%.2e |dx2| max=%.2e mean=%.2e fp16_nz=%d/1000\n", damx, dx2mx, dx2mean, nz);
}
// SDPA backward part 1 (ANE, fp16): Q,K,V,da dV,probs,dp // SDPA backward part 1 (ANE, fp16): Q,K,V,da dV,probs,dp
t0 = mach_absolute_time(); t0 = mach_absolute_time();
io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, DIM, SEQ); io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, DIM, SEQ);
@ -667,6 +671,15 @@ int main(int argc, char *argv[]) {
io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ); io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0); t_io_bwd += tb_ms(mach_absolute_time() - t0);
// Debug: check SDPA backward output magnitudes
if (L == 0 && step % 10 == 0) {
float dqmx, dkmx, dvmx;
vDSP_maxmgv(dq, 1, &dqmx, (vDSP_Length)(SEQ*DIM));
vDSP_maxmgv(dk_buf, 1, &dkmx, (vDSP_Length)(SEQ*DIM));
vDSP_maxmgv(dv, 1, &dvmx, (vDSP_Length)(SEQ*DIM));
printf(" L0 sdpa_bwd: |dq|=%.6f |dk|=%.6f |dv|=%.6f\n", dqmx, dkmx, dvmx);
}
// dWq/dWk/dWv async // dWq/dWk/dWv async
t0 = mach_absolute_time(); t0 = mach_absolute_time();
float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4); float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
@ -684,25 +697,12 @@ int main(int argc, char *argv[]) {
free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn); free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
}); });
// QKV backward (ANE): dq,dk,dv @ Wq^T,Wk^T,Wv^T dx_attn // QKV backward (ANE): dq,dk,dv @ pre-staged Wq^T,Wk^T,Wv^T dx_attn
t0 = mach_absolute_time(); t0 = mach_absolute_time();
{ write_qkv_bwd_acts(pls[L].qkvBwd_in, dq, dk_buf, dv);
IOSurfaceLock(dk.qkvBwd->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.qkvBwd->ioIn);
int sp = 3*SEQ + 3*DIM;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, dq + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dk_buf + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 3*SEQ, lw[L].Wq + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+DIM, lw[L].Wk + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+2*DIM, lw[L].Wv + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.qkvBwd->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0); t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
ane_eval(dk.qkvBwd); ane_eval_req(dk.qkvBwd, plr[L].qkvBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0); t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time(); t0 = mach_absolute_time();
io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ); io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ);
@ -741,10 +741,10 @@ int main(int argc, char *argv[]) {
// Adam update every accum_steps // Adam update every accum_steps
if ((step+1) % accum_steps == 0 || step == total_steps-1) { if ((step+1) % accum_steps == 0 || step == total_steps-1) {
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER); dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
float gsc = 1.0f / accum_steps; float gsc = 1.0f / (accum_steps * loss_scale);
adam_t++; adam_t++;
// Scale gradients by 1/accum_steps // Scale gradients by 1/(accum_steps * loss_scale)
for (int L=0; L<NLAYERS; L++) { for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L]; LayerGrads *g = &grads[L];
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;} for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
@ -778,7 +778,25 @@ int main(int argc, char *argv[]) {
vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); grad_norm_sq+=s; vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); grad_norm_sq+=s;
} }
float grad_norm = sqrtf(grad_norm_sq); float grad_norm = sqrtf(grad_norm_sq);
if ((step+1) % 10 == 0) printf(" grad_norm=%.4f\n", grad_norm); if ((step+1) % 10 == 0) {
// Per-component gradient norms for diagnostics
float attn_sq=0, ffn_sq=0, embed_sq=0;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L]; float s;
vDSP_dotpr(g->Wq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); attn_sq+=s;
vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); attn_sq+=s;
vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); ffn_sq+=s;
vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); ffn_sq+=s;
vDSP_dotpr(g->W3,1,g->W3,1,&s,(vDSP_Length)W3_SZ); ffn_sq+=s;
}
{ float s;
vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); embed_sq=s;
}
printf(" grad_norm=%.4f attn=%.4f ffn=%.4f embed=%.4f\n",
grad_norm, sqrtf(attn_sq), sqrtf(ffn_sq), sqrtf(embed_sq));
}
// Gradient clipping // Gradient clipping
if (grad_clip > 0 && grad_norm > grad_clip) { if (grad_clip > 0 && grad_norm > grad_clip) {
@ -811,15 +829,15 @@ int main(int argc, char *argv[]) {
// Adam update // Adam update
for (int L=0; L<NLAYERS; L++) { for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L]; LayerGrads *g = &grads[L];
adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f);
adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f);
// Update transposed weight buffers // Update transposed weight buffers
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM); transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
@ -829,9 +847,17 @@ int main(int argc, char *argv[]) {
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM); transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN); transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM); transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
// Re-stage weights into per-layer IOSurfaces
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2);
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, W2t_buf[L]);
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, W1t_buf[L], W3t_buf[L]);
stage_wot_bwd_weights(pls[L].wotBwd_in, Wot_buf[L]);
stage_qkv_bwd_weights(pls[L].qkvBwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]);
} }
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f);
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps); adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
// Re-extract compact embed from updated full embed // Re-extract compact embed from updated full embed
free(cembed); free(cembed);
cembed = vocab_compact_embed(embed, &vm, DIM); cembed = vocab_compact_embed(embed, &vm, DIM);
@ -867,7 +893,8 @@ int main(int argc, char *argv[]) {
free(Wqt_buf[L]); free(Wkt_buf[L]); free(Wvt_buf[L]); free(Wot_buf[L]); free(Wqt_buf[L]); free(Wkt_buf[L]); free(Wvt_buf[L]); free(Wot_buf[L]);
free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]); free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]);
} }
free_kern(dk.sdpaFwd); free_kern(dk.ffnW13); free_kern(dk.ffnW2); free_per_layer(pls, plr);
free_kern(dk.sdpaFwd); free_kern(dk.ffnFused);
free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd); free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd);
free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd); free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd);
munmap(token_data, data_len); close(data_fd); munmap(token_data, data_len); close(data_fd);