mirror of https://github.com/maderix/ANE.git
Fixed the dynamic pipeline logit generation
This commit is contained in:
parent
06535fc5be
commit
c3c5094865
|
|
@ -166,6 +166,8 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
k_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||
v_cache = [[np.zeros((0, HD), dtype=np.float32) for _ in range(HEADS)] for _ in range(NLAYERS)]
|
||||
|
||||
res_alpha = 1.0 / math.sqrt(2.0 * NLAYERS)
|
||||
|
||||
for step in range(max_tokens):
|
||||
seq_len = len(tokens)
|
||||
if seq_len > SEQ:
|
||||
|
|
@ -206,8 +208,8 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
attn = softmax(scores)
|
||||
o[h * HD:(h + 1) * HD] = attn @ v_cache[L][h]
|
||||
|
||||
# Residual + output projection
|
||||
x2 = x + W[f'Wo{L}'] @ o
|
||||
# Residual + output projection (scaled residual, matches training)
|
||||
x2 = x + res_alpha * (W[f'Wo{L}'] @ o)
|
||||
|
||||
# FFN
|
||||
x2n = rmsnorm(x2, W[f'rms2_{L}'])
|
||||
|
|
@ -217,7 +219,7 @@ def generate_text(W, max_tokens=64, temperature=0.8):
|
|||
h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
|
||||
ffn_out = W[f'W2_{L}'] @ h1
|
||||
|
||||
x = x2 + ffn_out
|
||||
x = x2 + res_alpha * ffn_out
|
||||
|
||||
x = rmsnorm(x, W['rms_final'])
|
||||
|
||||
|
|
|
|||
|
|
@ -254,15 +254,16 @@ int main(int argc, char *argv[]) {
|
|||
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
|
||||
if (ane_extras) printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n");
|
||||
else printf("ANE extras DISABLED (classifier/softmax/rmsnorm_bwd on CPU)\n");
|
||||
if (!load_pretrained(lw, rms_final, embed, model_path)) {
|
||||
printf("Pretrained load failed, using random init\n");
|
||||
{
|
||||
printf(" Training from scratch (random init)\n");
|
||||
srand48(42);
|
||||
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
||||
float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); // LLaMA-style output proj scaling
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*res_scale*(2*drand48()-1);}
|
||||
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
|
||||
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
|
||||
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*res_scale*(2*drand48()-1);
|
||||
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
|
||||
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -162,3 +162,23 @@ static void embed_backward(float *d_embed, const float *dx, const uint16_t *toke
|
|||
d_embed[tok*dim + d] += dx[d*seq + t];
|
||||
}
|
||||
}
|
||||
|
||||
// RoPE backward (in-place): inverse rotation on dQ/dK gradients
|
||||
// Data layout: [DIM, SEQ] channel-first, DIM = nheads * hd
|
||||
static void rope_backward_inplace(float *dx, int seq, int dim, int hd) {
|
||||
int nheads = dim / hd;
|
||||
for (int h = 0; h < nheads; h++) {
|
||||
for (int i = 0; i < hd/2; i++) {
|
||||
float freq = 1.0f / powf(10000.0f, 2.0f * i / (float)hd);
|
||||
for (int p = 0; p < seq; p++) {
|
||||
float theta = p * freq;
|
||||
float cos_t = cosf(theta), sin_t = sinf(theta);
|
||||
int idx0 = (h * hd + 2 * i) * seq + p;
|
||||
int idx1 = (h * hd + 2 * i + 1) * seq + p;
|
||||
float v0 = dx[idx0], v1 = dx[idx1];
|
||||
dx[idx0] = v0 * cos_t + v1 * sin_t;
|
||||
dx[idx1] = -v0 * sin_t + v1 * cos_t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,8 +135,41 @@ static NSString *gen_sdpa_fwd_dynamic(void) {
|
|||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD];
|
||||
|
||||
// Q @ K^T
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
|
||||
// RoPE: q_rope = q * cos + rotate_half(q) * sin, same for k
|
||||
int pairs = SEQ * HD / 2;
|
||||
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> rope_cos = const()[name=string(\"rc\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/rope_cos.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD];
|
||||
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> rope_sin = const()[name=string(\"rs\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/rope_sin.bin\"), offset=uint64(64)))];\n", SEQ, HD, SEQ, HD];
|
||||
[m appendFormat:@" tensor<int32, [4]> rp_sh = const()[name=string(\"rp_sh\"), val=tensor<int32, [4]>([1,%d,%d,2])];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<int32, [4]> rp_s1 = const()[name=string(\"rp_s1\"), val=tensor<int32, [4]>([1,%d,%d,1])];\n", HEADS, pairs];
|
||||
[m appendString:@" tensor<int32, [4]> rp_b0 = const()[name=string(\"rp_b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
|
||||
[m appendString:@" tensor<int32, [4]> rp_b1 = const()[name=string(\"rp_b1\"), val=tensor<int32, [4]>([0,0,0,1])];\n"];
|
||||
[m appendFormat:@" tensor<int32, [4]> rp_bk = const()[name=string(\"rp_bk\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, HD];
|
||||
// rotate_half(q): reshape to pairs, swap+negate, reshape back
|
||||
[m appendString:@" fp16 neg1 = const()[name=string(\"neg1\"), val=fp16(-1)];\n"];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> q_p = reshape(shape=rp_sh,x=q)[name=string(\"q_p\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_e = slice_by_size(x=q_p,begin=rp_b0,size=rp_s1)[name=string(\"q_e\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> q_o = slice_by_size(x=q_p,begin=rp_b1,size=rp_s1)[name=string(\"q_o\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nq = mul(x=q_o,y=neg1)[name=string(\"nq\")];\n", HEADS, pairs];
|
||||
[m appendString:@" int32 rpax = const()[name=string(\"rpax\"), val=int32(3)];\n"];
|
||||
[m appendString:@" bool rpil = const()[name=string(\"rpil\"), val=bool(false)];\n"];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> qrp = concat(axis=rpax,interleave=rpil,values=(nq,q_e))[name=string(\"qrp\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q_rot = reshape(shape=rp_bk,x=qrp)[name=string(\"q_rot\")];\n", HEADS, SEQ, HD];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qc = mul(x=q,y=rope_cos)[name=string(\"qc\")];\n", HEADS, SEQ, HD];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qrs = mul(x=q_rot,y=rope_sin)[name=string(\"qrs\")];\n", HEADS, SEQ, HD];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q_rope = add(x=qc,y=qrs)[name=string(\"q_rope\")];\n", HEADS, SEQ, HD];
|
||||
// rotate_half(k)
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> k_p = reshape(shape=rp_sh,x=k)[name=string(\"k_p\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_e = slice_by_size(x=k_p,begin=rp_b0,size=rp_s1)[name=string(\"k_e\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> k_o = slice_by_size(x=k_p,begin=rp_b1,size=rp_s1)[name=string(\"k_o\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> nk = mul(x=k_o,y=neg1)[name=string(\"nk\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,2]> krp = concat(axis=rpax,interleave=rpil,values=(nk,k_e))[name=string(\"krp\")];\n", HEADS, pairs];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rot = reshape(shape=rp_bk,x=krp)[name=string(\"k_rot\")];\n", HEADS, SEQ, HD];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kc = mul(x=k,y=rope_cos)[name=string(\"kc\")];\n", HEADS, SEQ, HD];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krs = mul(x=k_rot,y=rope_sin)[name=string(\"krs\")];\n", HEADS, SEQ, HD];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k_rope = add(x=kc,y=krs)[name=string(\"k_rope\")];\n", HEADS, SEQ, HD];
|
||||
|
||||
// Q_rope @ K_rope^T
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q_rope,y=k_rope)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
|
||||
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];
|
||||
|
||||
|
|
@ -162,10 +195,16 @@ static NSString *gen_sdpa_fwd_dynamic(void) {
|
|||
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ];
|
||||
|
||||
// Output: concat(o_out, qf, kf, vf, af, xn) — same as original for backward compatibility
|
||||
// Convert RoPE'd Q,K back to [1,DIM,1,SEQ] for backward pass output
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qrt = transpose(perm=pm,x=q_rope)[name=string(\"qrt\")];\n", HEADS, HD, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qrf = reshape(shape=os,x=qrt)[name=string(\"qrf\")];\n", DIM, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> krt = transpose(perm=pm,x=k_rope)[name=string(\"krt\")];\n", HEADS, HD, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> krf = reshape(shape=os,x=krt)[name=string(\"krf\")];\n", DIM, SEQ];
|
||||
|
||||
// Output: concat(o_out, Q_rope, K_rope, V, attn_out, xnorm) for backward
|
||||
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
|
||||
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qrf,krf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ];
|
||||
[m appendString:@" } -> (out);\n}\n"];
|
||||
return m;
|
||||
}
|
||||
|
|
@ -331,8 +370,11 @@ static NSString *gen_ffn_fused_dynamic(void) {
|
|||
[m appendFormat:@" tensor<int32, [4]> rd2 = const()[name=string(\"rd2\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> ffn_out = reshape(shape=rd2,x=ft)[name=string(\"ffn_out\")];\n", DIM, SEQ];
|
||||
|
||||
// Residual: x_next = x2 + ffn_out
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x_next = add(x=x2,y=ffn_out)[name=string(\"x_next\")];\n", DIM, SEQ];
|
||||
// Residual: x_next = x2 + alpha * ffn_out (residual scaling)
|
||||
float alpha = 1.0f / sqrtf(2.0f * NLAYERS);
|
||||
[m appendFormat:@" fp16 res_alpha = const()[name=string(\"res_alpha\"), val=fp16(%g)];\n", alpha];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> ffn_scaled = mul(x=ffn_out,y=res_alpha)[name=string(\"ffn_sc\")];\n", DIM, SEQ];
|
||||
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> x_next = add(x=x2,y=ffn_scaled)[name=string(\"x_next\")];\n", DIM, SEQ];
|
||||
|
||||
// Output: concat(x_next, h1, h3, gate) — gate=silu*h3 needed for dW2
|
||||
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
|
||||
|
|
@ -665,3 +707,39 @@ static NSData *get_mask_blob(void) {
|
|||
}
|
||||
return g_mask_blob;
|
||||
}
|
||||
|
||||
// RoPE cos/sin blobs [1, 1, SEQ, HD] — rotary position encodings.
// Each table is built on first request and then kept in the file-scope
// variable below so later calls return the same object.
// NOTE(review): the lazy initialisation is unsynchronised — confirm these
// getters are only ever called from a single thread.
static NSData *g_rope_cos_blob = nil;
static NSData *g_rope_sin_blob = nil;

// Builds one RoPE table of SEQ * HD fp16 values and wraps it with
// build_blob_fp16. For position p and pair index i, entries (p, 2i) and
// (p, 2i+1) both hold trig(p / 10000^(2i/HD)), where trig is sinf when
// want_sin is YES and cosf otherwise. With an odd HD the last entry of each
// row stays zero (the buffer comes from calloc).
// Returns nil if the scratch buffer cannot be allocated.
static NSData *build_rope_table_blob(BOOL want_sin) {
    _Float16 *table = (_Float16 *)calloc(SEQ * HD, sizeof(_Float16));
    if (!table) return nil;

    for (int i = 0; i < HD / 2; i++) {
        // The divisor depends only on the pair index, so compute it once per pair.
        const float denom = powf(10000.0f, 2.0f * i / (float)HD);
        for (int p = 0; p < SEQ; p++) {
            const float theta = p / denom;
            const _Float16 val = (_Float16)(want_sin ? sinf(theta) : cosf(theta));
            table[p * HD + 2 * i] = val;
            table[p * HD + 2 * i + 1] = val;
        }
    }

    NSData *blob = build_blob_fp16(table, SEQ * HD);
    free(table);
    return blob;
}

// Returns the cached cosine table blob, building it on first use.
// Returns nil (and leaves the cache unset) if the table could not be allocated.
static NSData *get_rope_cos_blob(void) {
    if (!g_rope_cos_blob) {
        g_rope_cos_blob = build_rope_table_blob(NO);
    }
    return g_rope_cos_blob;
}

// Returns the cached sine table blob, building it on first use.
// Returns nil (and leaves the cache unset) if the table could not be allocated.
static NSData *get_rope_sin_blob(void) {
    if (!g_rope_sin_blob) {
        g_rope_sin_blob = build_rope_table_blob(YES);
    }
    return g_rope_sin_blob;
}
|
||||
|
|
|
|||
|
|
@ -58,10 +58,15 @@ static void transpose_weight(float *dst, const float *src, int rows, int cols) {
|
|||
// ===== Compile all dynamic kernels (ONCE) =====
|
||||
static bool compile_dynamic_kernels(DynLayerKernels *dk) {
|
||||
NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}};
|
||||
NSDictionary *sdpa_fwd_w = @{
|
||||
@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
|
||||
@"@model_path/weights/rope_cos.bin": @{@"offset":@0, @"data":get_rope_cos_blob()},
|
||||
@"@model_path/weights/rope_sin.bin": @{@"offset":@0, @"data":get_rope_sin_blob()}
|
||||
};
|
||||
|
||||
// SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp16 → [1, 6*DIM, 1, SEQ] fp16
|
||||
printf(" Compiling sdpaFwd...\n");
|
||||
dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w,
|
||||
dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), sdpa_fwd_w,
|
||||
DIM*(SEQ+4*DIM)*2, 6*DIM*SEQ*2);
|
||||
if (!dk->sdpaFwd) return false;
|
||||
|
||||
|
|
@ -213,7 +218,7 @@ int main(int argc, char *argv[]) {
|
|||
int warmup_steps = 100;
|
||||
float grad_clip = 1.0f;
|
||||
float loss_scale = 256.0f; // fp16 loss scaling for ANE backward
|
||||
float act_clip = 20.0f;
|
||||
float res_alpha = 1.0f / sqrtf(2.0f * NLAYERS); // residual scaling (DeepNet-style)
|
||||
float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1
|
||||
|
||||
bool do_resume = false, from_scratch = false;
|
||||
|
|
@ -273,11 +278,12 @@ int main(int argc, char *argv[]) {
|
|||
else printf(" Pretrained load failed, using random init\n");
|
||||
srand48(42);
|
||||
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
||||
float res_scale = 1.0f/sqrtf(2.0f*NLAYERS); // LLaMA-style output proj scaling
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*res_scale*(2*drand48()-1);}
|
||||
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
|
||||
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
|
||||
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*res_scale*(2*drand48()-1);
|
||||
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
|
||||
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
|
||||
}
|
||||
|
|
@ -362,10 +368,10 @@ int main(int argc, char *argv[]) {
|
|||
for (int L = 0; L < NLAYERS; L++) {
|
||||
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
|
||||
stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2);
|
||||
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, W2t_buf[L]);
|
||||
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, W1t_buf[L], W3t_buf[L]);
|
||||
stage_wot_bwd_weights(pls[L].wotBwd_in, Wot_buf[L]);
|
||||
stage_qkv_bwd_weights(pls[L].qkvBwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]);
|
||||
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, lw[L].W2);
|
||||
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, lw[L].W1, lw[L].W3);
|
||||
stage_wot_bwd_weights(pls[L].wotBwd_in, lw[L].Wo);
|
||||
stage_qkv_bwd_weights(pls[L].qkvBwd_in, lw[L].Wq, lw[L].Wk, lw[L].Wv);
|
||||
}
|
||||
printf("Per-layer weight staging complete\n\n");
|
||||
|
||||
|
|
@ -455,9 +461,10 @@ int main(int argc, char *argv[]) {
|
|||
IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
|
||||
t_io_fwd += tb_ms(mach_absolute_time() - t0);
|
||||
|
||||
// CPU: residual + RMSNorm (ANE can't fuse RMS with 3 matmuls)
|
||||
// CPU: scaled residual + RMSNorm (ANE can't fuse RMS with 3 matmuls)
|
||||
t0 = mach_absolute_time();
|
||||
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
|
||||
// x2 = x_cur + alpha * o_out (residual scaling keeps activations bounded)
|
||||
vDSP_vsma(ac->o_out, 1, &res_alpha, x_cur, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
|
||||
rmsnorm(ac->x2norm, ac->x2, lw[L].rms_ffn, DIM, SEQ);
|
||||
t_rms += tb_ms(mach_absolute_time() - t0);
|
||||
|
||||
|
|
@ -484,14 +491,8 @@ int main(int argc, char *argv[]) {
|
|||
IOSurfaceUnlock(dk.ffnFused->ioOut, kIOSurfaceLockReadOnly, NULL);
|
||||
t_io_fwd += tb_ms(mach_absolute_time() - t0);
|
||||
|
||||
// Scale down residual stream if max magnitude exceeds threshold
|
||||
{
|
||||
float amx; vDSP_maxmgv(x_cur, 1, &amx, (vDSP_Length)(SEQ*DIM));
|
||||
if (amx > act_clip) {
|
||||
float sc = act_clip / amx;
|
||||
vDSP_vsmul(x_cur, 1, &sc, x_cur, 1, (vDSP_Length)(SEQ*DIM));
|
||||
}
|
||||
}
|
||||
// (act_clip removed — was causing gradient explosion without backward,
|
||||
// vanishing gradients with backward. RMSNorm keeps activations bounded.)
|
||||
}
|
||||
|
||||
// Final RMSNorm + classifier + loss (CPU)
|
||||
|
|
@ -533,7 +534,9 @@ int main(int argc, char *argv[]) {
|
|||
for (int L=NLAYERS-1; L>=0; L--) {
|
||||
LayerActs *ac = &acts[L];
|
||||
LayerGrads *gr = &grads[L];
|
||||
memcpy(dffn, dy, SEQ*DIM*4);
|
||||
|
||||
// dffn = alpha * dy (gradient into FFN branch scaled by residual alpha)
|
||||
vDSP_vsmul(dy, 1, &res_alpha, dffn, 1, (vDSP_Length)(SEQ*DIM));
|
||||
|
||||
// FFN backward: dffn @ pre-staged W2^T → dsilu_raw
|
||||
t0 = mach_absolute_time();
|
||||
|
|
@ -609,9 +612,12 @@ int main(int argc, char *argv[]) {
|
|||
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
|
||||
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
|
||||
|
||||
// Wo^T backward (ANE): dx2 @ pre-staged Wo^T → da
|
||||
// Wo^T backward (ANE): alpha*dx2 @ pre-staged Wo^T → da
|
||||
// Scale dx2 by alpha for the attention branch (residual scaling backward)
|
||||
float *dx2_scaled = (float*)malloc(SEQ*DIM*4);
|
||||
vDSP_vsmul(dx2, 1, &res_alpha, dx2_scaled, 1, (vDSP_Length)(SEQ*DIM));
|
||||
t0 = mach_absolute_time();
|
||||
write_wot_bwd_acts(pls[L].wotBwd_in, dx2);
|
||||
write_wot_bwd_acts(pls[L].wotBwd_in, dx2_scaled);
|
||||
t_io_bwd += tb_ms(mach_absolute_time() - t0);
|
||||
t0 = mach_absolute_time();
|
||||
ane_eval_req(dk.wotBwd, plr[L].wotBwd);
|
||||
|
|
@ -621,9 +627,10 @@ int main(int argc, char *argv[]) {
|
|||
io_read_dyn(dk.wotBwd->ioOut, da_buf, DIM, SEQ);
|
||||
t_io_bwd += tb_ms(mach_absolute_time() - t0);
|
||||
|
||||
// dWo async
|
||||
// dWo async (uses alpha-scaled dx2)
|
||||
t0 = mach_absolute_time();
|
||||
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2, SEQ*DIM*4);
|
||||
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2_scaled, SEQ*DIM*4);
|
||||
free(dx2_scaled);
|
||||
float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
|
||||
t_dw_copy += tb_ms(mach_absolute_time() - t0);
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
|
|
@ -673,6 +680,11 @@ int main(int argc, char *argv[]) {
|
|||
io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
|
||||
t_io_bwd += tb_ms(mach_absolute_time() - t0);
|
||||
|
||||
// RoPE backward: dq, dk are grads w.r.t. Q_rope, K_rope
|
||||
// Inverse rotation to get grads w.r.t. pre-RoPE Q, K
|
||||
rope_backward_inplace(dq, SEQ, DIM, HD);
|
||||
rope_backward_inplace(dk_buf, SEQ, DIM, HD);
|
||||
|
||||
// Debug: check SDPA backward output magnitudes
|
||||
if (L == 0 && step % 10 == 0) {
|
||||
float dqmx, dkmx, dvmx;
|
||||
|
|
@ -853,10 +865,10 @@ int main(int argc, char *argv[]) {
|
|||
// Re-stage weights into per-layer IOSurfaces
|
||||
stage_sdpa_fwd_weights(pls[L].sdpaFwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
|
||||
stage_ffn_fused_weights(pls[L].ffnFused_in, W1t_buf[L], W3t_buf[L], lw[L].W2);
|
||||
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, W2t_buf[L]);
|
||||
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, W1t_buf[L], W3t_buf[L]);
|
||||
stage_wot_bwd_weights(pls[L].wotBwd_in, Wot_buf[L]);
|
||||
stage_qkv_bwd_weights(pls[L].qkvBwd_in, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L]);
|
||||
stage_ffn_bwd_w2t_weights(pls[L].ffnBwdW2t_in, lw[L].W2);
|
||||
stage_ffn_bwd_w13t_weights(pls[L].ffnBwdW13t_in, lw[L].W1, lw[L].W3);
|
||||
stage_wot_bwd_weights(pls[L].wotBwd_in, lw[L].Wo);
|
||||
stage_qkv_bwd_weights(pls[L].qkvBwd_in, lw[L].Wq, lw[L].Wk, lw[L].Wv);
|
||||
}
|
||||
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps, 0.0f);
|
||||
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps, wd);
|
||||
|
|
|
|||
Loading…
Reference in New Issue