diff --git a/training/dashboard.py b/training/dashboard.py index b4c795c..55e8bb9 100644 --- a/training/dashboard.py +++ b/training/dashboard.py @@ -166,6 +166,8 @@ def generate_text(W, max_tokens=64, temperature=0.8): k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)] v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)] + res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS) + for step in range(max_tokens): seq_len = len(tokens) if seq_len > SEQ: @@ -206,8 +208,8 @@ def generate_text(W, max_tokens=64, temperature=0.8): attn = softmax(scores) o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h] - # Residual + output projection - x2 = x + W[f'Wo{L}'] @ o + # Residual + output projection (scaled residual, matches training) + x2 = x + res_alpha * (W[f'Wo{L}'] @ o) # FFN x2n = rmsnorm(x2, W[f'rms2_{L}']) @@ -217,7 +219,7 @@ def generate_text(W, max_tokens=64, temperature=0.8): h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3 ffn_out = W[f'W2_{L}'] @ h1 - x = x2 + ffn_out + x = x2 + res_alpha * ffn_out x = rmsnorm(x, W['rms_final']) diff --git a/training/train_large_ane.m b/training/train_large_ane.m index 25e9160..6a47a3f 100644 --- a/training/train_large_ane.m +++ b/training/train_large_ane.m @@ -254,15 +254,16 @@ int main(int argc, char *argv[]) { printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS); if (ane_extras) printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n"); else printf("ANE extras DISABLED (classifier/softmax/rmsnorm_bwd on CPU)\n"); - if (!load_pretrained(lw, rms_final, embed, model_path)) { - printf("Pretrained load failed, using random init\n"); + { + printf(" Training from scratch (random init)\n"); srand48(42); float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN); + float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); // LLaMA-style output proj scaling for (int L=0; L v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ]; [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD]; - // Q @ K^T - [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ]; + // RoPE: q_rope = q * cos + rotate_half(q) * sin, same for k + int pairs = SEQ * HD / 2; + [m appendFormat:@" tensor rope_cos = const()[name=string(\"rc\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rope_cos.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD]; + [m appendFormat:@" tensor rope_sin = const()[name=string(\"rs\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rope_sin.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD]; + [m appendFormat:@" tensor rp_sh = const()[name=string(\"rp_sh\"), val=tensor([1,%d,%d,2])];\n", HEADS, pairs]; + [m appendFormat:@" tensor rp_s1 = const()[name=string(\"rp_s1\"), val=tensor([1,%d,%d,1])];\n", HEADS, pairs]; + [m appendString:@" tensor rp_b0 = const()[name=string(\"rp_b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendString:@" tensor rp_b1 = const()[name=string(\"rp_b1\"), val=tensor([0,0,0,1])];\n"]; + [m appendFormat:@" tensor rp_bk = const()[name=string(\"rp_bk\"), val=tensor([1,%d,%d,%d])];\n", HEADS, SEQ, HD]; + // rotate_half(q): reshape to pairs, swap+negate, reshape back + [m appendString:@" fp16 neg1 = const()[name=string(\"neg1\"), val=fp16(-1)];\n"]; + [m appendFormat:@" tensor q_p = reshape(shape=rp_sh,x=q)[name=string(\"q_p\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor q_e = slice_by_size(x=q_p,begin=rp_b0,size=rp_s1)[name=string(\"q_e\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor q_o = slice_by_size(x=q_p,begin=rp_b1,size=rp_s1)[name=string(\"q_o\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor nq = mul(x=q_o,y=neg1)[name=string(\"nq\")];\n", HEADS, pairs]; + [m appendString:@" int32 rpax = const()[name=string(\"rpax\"), val=int32(3)];\n"]; + [m appendString:@" bool rpil = const()[name=string(\"rpil\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor qrp = concat(axis=rpax,interleave=rpil,values=(nq,q_e))[name=string(\"qrp\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor q_rot = reshape(shape=rp_bk,x=qrp)[name=string(\"q_rot\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor qc = mul(x=q,y=rope_cos)[name=string(\"qc\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor qrs = mul(x=q_rot,y=rope_sin)[name=string(\"qrs\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor q_rope = add(x=qc,y=qrs)[name=string(\"q_rope\")];\n", HEADS, SEQ, HD]; + // rotate_half(k) + [m appendFormat:@" tensor k_p = reshape(shape=rp_sh,x=k)[name=string(\"k_p\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor k_e = slice_by_size(x=k_p,begin=rp_b0,size=rp_s1)[name=string(\"k_e\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor k_o = slice_by_size(x=k_p,begin=rp_b1,size=rp_s1)[name=string(\"k_o\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor nk = mul(x=k_o,y=neg1)[name=string(\"nk\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor krp = concat(axis=rpax,interleave=rpil,values=(nk,k_e))[name=string(\"krp\")];\n", HEADS, pairs]; + [m appendFormat:@" tensor k_rot = reshape(shape=rp_bk,x=krp)[name=string(\"k_rot\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor kc = mul(x=k,y=rope_cos)[name=string(\"kc\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor krs = mul(x=k_rot,y=rope_sin)[name=string(\"krs\")];\n", HEADS, SEQ, HD]; + [m appendFormat:@" tensor k_rope = add(x=kc,y=krs)[name=string(\"k_rope\")];\n", HEADS, SEQ, HD]; + + // Q_rope @ K_rope^T + [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q_rope,y=k_rope)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ]; [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ]; @@ -162,10 +195,16 @@ static NSString *gen_sdpa_fwd_dynamic(void) { [m appendFormat:@" tensor ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ]; [m appendFormat:@" tensor oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ]; - // Output: concat(o_out, qf, kf, vf, af, xn) — same as original for backward compatibility + // Convert RoPE'd Q,K back to [1,DIM,1,SEQ] for backward pass output + [m appendFormat:@" tensor qrt = transpose(perm=pm,x=q_rope)[name=string(\"qrt\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor qrf = reshape(shape=os,x=qrt)[name=string(\"qrf\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor krt = transpose(perm=pm,x=k_rope)[name=string(\"krt\")];\n", HEADS, HD, SEQ]; + [m appendFormat:@" tensor krf = reshape(shape=os,x=krt)[name=string(\"krf\")];\n", DIM, SEQ]; + + // Output: concat(o_out, Q_rope, K_rope, V, attn_out, xnorm) for backward [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qrf,krf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ]; [m appendString:@" } -> (out);\n}\n"]; return m; } @@ -331,8 +370,11 @@ static NSString *gen_ffn_fused_dynamic(void) { [m appendFormat:@" tensor rd2 = const()[name=string(\"rd2\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; [m appendFormat:@" tensor ffn_out = reshape(shape=rd2,x=ft)[name=string(\"ffn_out\")];\n", DIM, SEQ]; - // Residual: x_next = x2 + ffn_out - [m appendFormat:@" tensor x_next = add(x=x2,y=ffn_out)[name=string(\"x_next\")];\n", DIM, SEQ]; + // Residual: x_next = x2 + alpha * ffn_out (residual scaling) + float alpha = 1.0f / sqrtf(2.0f * NLAYERS); + [m appendFormat:@" fp16 res_alpha = const()[name=string(\"res_alpha\"), val=fp16(%g)];\n", alpha]; + [m appendFormat:@" tensor ffn_scaled = mul(x=ffn_out,y=res_alpha)[name=string(\"ffn_sc\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor x_next = add(x=x2,y=ffn_scaled)[name=string(\"x_next\")];\n", DIM, SEQ]; // Output: concat(x_next, h1, h3, gate) — gate=silu*h3 needed for dW2 [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; @@ -665,3 +707,39 @@ static NSData *get_mask_blob(void) { } return g_mask_blob; } + +// RoPE cos/sin blobs [1, 1, SEQ, HD] — rotary position encodings +static NSData *g_rope_cos_blob = nil; +static NSData *g_rope_sin_blob = nil; + +static NSData *get_rope_cos_blob(void) { + if (!g_rope_cos_blob) { + _Float16 *buf = (_Float16*)calloc(SEQ * HD, sizeof(_Float16)); + for (int p = 0; p < SEQ; p++) + for (int i = 0; i < HD/2; i++) { + float theta = p / powf(10000.0f, 2.0f * i / (float)HD); + _Float16 cv = (_Float16)cosf(theta); + buf[p * HD + 2*i] = cv; + buf[p * HD + 2*i + 1] = cv; + } + g_rope_cos_blob = build_blob_fp16(buf, SEQ * HD); + free(buf); + } + return g_rope_cos_blob; +} + +static NSData *get_rope_sin_blob(void) { + if (!g_rope_sin_blob) { + _Float16 *buf = (_Float16*)calloc(SEQ * HD, sizeof(_Float16)); + for (int p = 0; p < SEQ; p++) + for (int i = 0; i < HD/2; i++) { + float theta = p / powf(10000.0f, 2.0f * i / (float)HD); + _Float16 sv = (_Float16)sinf(theta); + buf[p * HD + 2*i] = sv; + buf[p * HD + 2*i + 1] = sv; + } + g_rope_sin_blob = build_blob_fp16(buf, SEQ * HD); + free(buf); + } + return g_rope_sin_blob; +} diff --git a/training/training_dynamic/train.m b/training/training_dynamic/train.m index a95591e..685e075 100644 --- a/training/training_dynamic/train.m +++ b/training/training_dynamic/train.m @@ -58,10 +58,15 @@ static void transpose_weight(float *dst, const float *src, int rows, int cols) { // ===== Compile all dynamic kernels (ONCE) ===== static bool compile_dynamic_kernels(DynLayerKernels *dk) { NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}}; + NSDictionary *sdpa_fwd_w = @{ + @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}, + @"@model_path/weights/rope_cos.bin": @{@"offset":@0, @"data":get_rope_cos_blob()}, + @"@model_path/weights/rope_sin.bin": @{@"offset":@0, @"data":get_rope_sin_blob()} + }; // SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp16 → [1, 6*DIM, 1, SEQ] fp16 printf(" Compiling sdpaFwd...\n"); - dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w, + dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), sdpa_fwd_w, DIM*(SEQ+4*DIM)*2, 6*DIM*SEQ*2); if (!dk->sdpaFwd) return false; @@ -213,7 +218,7 @@ int main(int argc, char *argv[]) { int warmup_steps = 100; float grad_clip = 1.0f; float loss_scale = 256.0f; // fp16 loss scaling for ANE backward - float act_clip = 20.0f; + float res_alpha = 1.0f / sqrtf(2.0f * NLAYERS); // residual scaling (DeepNet-style) float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1 bool do_resume = false, from_scratch = false; @@ -273,11 +278,12 @@ int main(int argc, char *argv[]) { else printf(" Pretrained load failed, using random init\n"); srand48(42); float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN); + float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); // LLaMA-style output proj scaling for (int L=0; LioOut, kIOSurfaceLockReadOnly, NULL); t_io_fwd += tb_ms(mach_absolute_time() - t0); - // CPU: residual + RMSNorm (ANE can't fuse RMS with 3 matmuls) + // CPU: scaled residual + RMSNorm (ANE can't fuse RMS with 3 matmuls) t0 = mach_absolute_time(); - vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM)); + // x2 = x_cur + alpha * o_out (residual scaling keeps activations bounded) + vDSP_vsma(ac->o_out, 1, &res_alpha, x_cur, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM)); rmsnorm(ac->x2norm, ac->x2, lw[L].rms_ffn, DIM, SEQ); t_rms += tb_ms(mach_absolute_time() - t0); @@ -484,14 +491,8 @@ int main(int argc, char *argv[]) { IOSurfaceUnlock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL); t_io_fwd += tb_ms(mach_absolute_time() - t0); - // Scale down residual stream if max magnitude exceeds threshold - { - float amx; vDSP_maxmgv(x_cur, 1, &amx, (vDSP_Length)(SEQ*DIM)); - if (amx > act_clip) { - float sc = act_clip / amx; - vDSP_vsmul(x_cur, 1, &sc, x_cur, 1, (vDSP_Length)(SEQ*DIM)); - } - } + // (act_clip removed — was causing gradient explosion without backward, + // vanishing gradients with backward. RMSNorm keeps activations bounded.) } // Final RMSNorm + classifier + loss (CPU) @@ -533,7 +534,9 @@ int main(int argc, char *argv[]) { for (int L=NLAYERS-1; L>=0; L--) { LayerActs *ac = &acts[L]; LayerGrads *gr = &grads[L]; - memcpy(dffn, dy, SEQ*DIM*4); + + // dffn = alpha * dy (gradient into FFN branch scaled by residual alpha) + vDSP_vsmul(dy, 1, &res_alpha, dffn, 1, (vDSP_Length)(SEQ*DIM)); // FFN backward: dffn @ pre-staged W2^T → dsilu_raw t0 = mach_absolute_time(); @@ -609,9 +612,12 @@ int main(int argc, char *argv[]) { for(int i=0;iioOut, da_buf, DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); - // dWo async + // dWo async (uses alpha-scaled dx2) t0 = mach_absolute_time(); - float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2, SEQ*DIM*4); + float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2_scaled, SEQ*DIM*4); + free(dx2_scaled); float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4); t_dw_copy += tb_ms(mach_absolute_time() - t0); dispatch_group_async(dw_grp, dw_q, ^{ @@ -673,6 +680,11 @@ int main(int argc, char *argv[]) { io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ); t_io_bwd += tb_ms(mach_absolute_time() - t0); + // RoPE backward: dq, dk are grads w.r.t. Q_rope, K_rope + // Inverse rotation to get grads w.r.t. pre-RoPE Q, K + rope_backward_inplace(dq, SEQ, DIM, HD); + rope_backward_inplace(dk_buf, SEQ, DIM, HD); + // Debug: check SDPA backward output magnitudes if (L == 0 && step % 10 == 0) { float dqmx, dkmx, dvmx; @@ -853,10 +865,10 @@ int main(int argc, char *argv[]) { // Re-stage weights into per-layer IOSurfaces stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]); stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2); - stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, W2t_buf[L]); - stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, W1t_buf[L], W3t_buf[L]); - stage_wot_bwd_weights(pls[L].wotBwd_in, Wot_buf[L]); - stage_qkv_bwd_weights(pls[L].qkvBwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]); + stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, lw[L].W2); + stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, lw[L].W1, lw[L].W3); + stage_wot_bwd_weights(pls[L].wotBwd_in, lw[L].Wo); + stage_qkv_bwd_weights(pls[L].qkvBwd_in, lw[L].Wq, lw[L].Wk, lw[L].Wv); } adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f); adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);