// backward.h — Backward pass using CPU matmul (correct gradients) + ANE optional #pragma once #include "model.h" #include "forward.h" #include #include // dW += dy @ x^T — dy: [S, out_dim], x: [S, in_dim], dW: [out_dim, in_dim] static void cpu_accum_dW(float *dW, const float *dy, const float *x, int S, int out_dim, int in_dim) { for (int t = 0; t < S; t++) for (int i = 0; i < out_dim; i++) for (int j = 0; j < in_dim; j++) dW[i*in_dim+j] += dy[t*out_dim+i] * x[t*in_dim+j]; } // dx = W^T @ dy — W: [out_dim, in_dim], dy: [S, out_dim] → dx: [S, in_dim] static void cpu_matmul_backward_dx(const float *W, const float *dy, float *dx, int S, int out_dim, int in_dim) { for (int t = 0; t < S; t++) for (int j = 0; j < in_dim; j++) { float sum = 0; for (int i = 0; i < out_dim; i++) sum += W[i*in_dim+j] * dy[t*out_dim+i]; dx[t*in_dim+j] = sum; } } static void cpu_rmsnorm_backward(float *dx, const float *dy, const float *x, const float *w, int S, int D) { for (int t = 0; t < S; t++) { float ss = 0; for (int i = 0; i < D; i++) ss += x[t*D+i] * x[t*D+i]; float rms = sqrtf(ss / D + 1e-5f); float inv_rms = 1.0f / rms; float dot = 0; for (int i = 0; i < D; i++) dot += dy[t*D+i] * w[i] * x[t*D+i]; dot /= (D * rms * rms); for (int i = 0; i < D; i++) dx[t*D+i] = dy[t*D+i] * w[i] * inv_rms - x[t*D+i] * dot; } } static inline float silu_backward(float x) { float s = 1.0f / (1.0f + expf(-x)); return s * (1.0f + x * (1.0f - s)); } static void cpu_attention_backward(float *dq, float *dk, float *dv, const float *d_out, const float *q, const float *k, const float *v, int S, int n_heads, int head_dim) { float scale = 1.0f / sqrtf((float)head_dim); int D = n_heads * head_dim; float *scores = (float*)malloc(S * sizeof(float)); float *dscores = (float*)malloc(S * sizeof(float)); memset(dq, 0, S * D * sizeof(float)); memset(dk, 0, S * D * sizeof(float)); memset(dv, 0, S * D * sizeof(float)); for (int h = 0; h < n_heads; h++) { for (int t = 0; t < S; t++) { // Recompute softmax for this row float mx = -1e9f; for (int s = 0; s <= t; s++) { float dot = 0; for (int i = 0; i < head_dim; i++) dot += q[t*D + h*head_dim + i] * k[s*D + h*head_dim + i]; scores[s] = dot * scale; if (scores[s] > mx) mx = scores[s]; } float sm = 0; for (int s = 0; s <= t; s++) { scores[s] = expf(scores[s] - mx); sm += scores[s]; } for (int s = 0; s <= t; s++) scores[s] /= sm; // dscores = d_out · v float ds_sum = 0; for (int s = 0; s <= t; s++) { float dot = 0; for (int i = 0; i < head_dim; i++) dot += d_out[t*D + h*head_dim + i] * v[s*D + h*head_dim + i]; dscores[s] = dot; ds_sum += scores[s] * dot; } // Softmax backward + scale for (int s = 0; s <= t; s++) { float ds = scores[s] * (dscores[s] - ds_sum) * scale; // dq[t] += ds * k[s] for (int i = 0; i < head_dim; i++) dq[t*D + h*head_dim + i] += ds * k[s*D + h*head_dim + i]; // dk[s] += ds * q[t] for (int i = 0; i < head_dim; i++) dk[s*D + h*head_dim + i] += ds * q[t*D + h*head_dim + i]; // dv[s] += scores[t,s] * d_out[t] for (int i = 0; i < head_dim; i++) dv[s*D + h*head_dim + i] += scores[s] * d_out[t*D + h*head_dim + i]; } } } free(scores); free(dscores); } static void cpu_rope_backward(float *dq, float *dk, int S, int n_heads, int head_dim) { for (int t = 0; t < S; t++) for (int h = 0; h < n_heads; h++) for (int i = 0; i < head_dim; i += 2) { float freq = 1.0f / powf(10000.0f, (float)i / head_dim); float val = t * freq; float cos_v = cosf(val), sin_v = sinf(val); int off = t * n_heads * head_dim + h * head_dim + i; float dq0 = dq[off], dq1 = dq[off+1]; dq[off] = dq0 * cos_v + dq1 * sin_v; dq[off+1] = -dq0 * sin_v + dq1 * cos_v; float dk0 = dk[off], dk1 = dk[off+1]; dk[off] = dk0 * cos_v + dk1 * sin_v; dk[off+1] = -dk0 * sin_v + dk1 * cos_v; } } static void model_clip_gradients(Model *m, float max_norm) { int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size; double total_norm_sq = 0; #define ACCUM_NORM(grad, size) do { \ for (size_t _i = 0; _i < (size_t)(size); _i++) total_norm_sq += (double)(grad)[_i] * (grad)[_i]; \ } while(0) for (int l = 0; l < N_LAYERS; l++) { ACCUM_NORM(m->grad_wq[l], d*d); ACCUM_NORM(m->grad_wk[l], d*d); ACCUM_NORM(m->grad_wv[l], d*d); ACCUM_NORM(m->grad_wo[l], d*d); ACCUM_NORM(m->grad_w1[l], hd*d); ACCUM_NORM(m->grad_w2[l], d*hd); ACCUM_NORM(m->grad_w3[l], hd*d); } ACCUM_NORM(m->grad_wcls, vs*d); ACCUM_NORM(m->grad_emb, vs*d); #undef ACCUM_NORM float total_norm = sqrtf((float)total_norm_sq); if (total_norm > max_norm) { float scale = max_norm / total_norm; #define SCALE_GRAD(grad, size) do { \ for (size_t _i = 0; _i < (size_t)(size); _i++) (grad)[_i] *= scale; \ } while(0) for (int l = 0; l < N_LAYERS; l++) { SCALE_GRAD(m->grad_wq[l], d*d); SCALE_GRAD(m->grad_wk[l], d*d); SCALE_GRAD(m->grad_wv[l], d*d); SCALE_GRAD(m->grad_wo[l], d*d); SCALE_GRAD(m->grad_w1[l], hd*d); SCALE_GRAD(m->grad_w2[l], d*hd); SCALE_GRAD(m->grad_w3[l], hd*d); } SCALE_GRAD(m->grad_wcls, vs*d); SCALE_GRAD(m->grad_emb, vs*d); #undef SCALE_GRAD } } static void model_backward(Model *m, const int *tokens) { int S = m->seq_len, d = m->cfg.dim, hd = m->cfg.hidden_dim; int nh = m->cfg.n_heads, hdim = HEAD_DIM, vs = m->cfg.vocab_size; // Zero gradients for (int l = 0; l < N_LAYERS; l++) { memset(m->grad_wq[l], 0, d*d*sizeof(float)); memset(m->grad_wk[l], 0, d*d*sizeof(float)); memset(m->grad_wv[l], 0, d*d*sizeof(float)); memset(m->grad_wo[l], 0, d*d*sizeof(float)); memset(m->grad_w1[l], 0, hd*d*sizeof(float)); memset(m->grad_w2[l], 0, d*hd*sizeof(float)); memset(m->grad_w3[l], 0, hd*d*sizeof(float)); } memset(m->grad_wcls, 0, (size_t)vs*d*sizeof(float)); memset(m->grad_emb, 0, (size_t)vs*d*sizeof(float)); // dLogits from cross-entropy float *dlogits = (float*)calloc(S * vs, sizeof(float)); for (int t = 0; t < S - 1; t++) { float mx = -1e9f; for (int i = 0; i < vs; i++) if (m->logits[t*vs+i] > mx) mx = m->logits[t*vs+i]; float sm = 0; for (int i = 0; i < vs; i++) sm += expf(m->logits[t*vs+i] - mx); for (int i = 0; i < vs; i++) dlogits[t*vs+i] = expf(m->logits[t*vs+i] - mx) / sm; dlogits[t*vs + tokens[t+1]] -= 1.0f; for (int i = 0; i < vs; i++) dlogits[t*vs+i] /= (S - 1); } // Classifier backward cpu_accum_dW(m->grad_wcls, dlogits, m->act_final, S, vs, d); float *dx = (float*)calloc(S * d, sizeof(float)); cpu_matmul_backward_dx(m->wcls, dlogits, dx, S, vs, d); free(dlogits); // Final RMSNorm backward float *dx_norm = (float*)malloc(S * d * sizeof(float)); cpu_rmsnorm_backward(dx_norm, dx, m->act_pre_final, m->rms_final_w, S, d); memcpy(dx, dx_norm, S * d * sizeof(float)); free(dx_norm); // Layers in reverse for (int l = N_LAYERS - 1; l >= 0; l--) { // FFN down backward float *d_silu = (float*)calloc(S * hd, sizeof(float)); cpu_matmul_backward_dx(m->w2[l], dx, d_silu, S, d, hd); cpu_accum_dW(m->grad_w2[l], dx, m->act_silu[l], S, d, hd); // SiLU backward float *d_h1 = (float*)malloc(S * hd * sizeof(float)); float *d_h3 = (float*)malloc(S * hd * sizeof(float)); for (int t = 0; t < S; t++) for (int i = 0; i < hd; i++) { d_h1[t*hd+i] = d_silu[t*hd+i] * m->act_h3[l][t*hd+i] * silu_backward(m->act_h1[l][t*hd+i]); d_h3[t*hd+i] = d_silu[t*hd+i] * silu_f(m->act_h1[l][t*hd+i]); } free(d_silu); // FFN up backward cpu_accum_dW(m->grad_w1[l], d_h1, m->act_ffn_in[l], S, hd, d); cpu_accum_dW(m->grad_w3[l], d_h3, m->act_ffn_in[l], S, hd, d); float *dx_ffn_in = (float*)calloc(S * d, sizeof(float)); float *dx_w1 = (float*)malloc(S * d * sizeof(float)); float *dx_w3 = (float*)malloc(S * d * sizeof(float)); cpu_matmul_backward_dx(m->w1[l], d_h1, dx_w1, S, hd, d); cpu_matmul_backward_dx(m->w3[l], d_h3, dx_w3, S, hd, d); for (int i = 0; i < S * d; i++) dx_ffn_in[i] = dx_w1[i] + dx_w3[i]; free(d_h1); free(d_h3); free(dx_w1); free(dx_w3); // RMSNorm FFN backward float *dx_ffn_norm = (float*)malloc(S * d * sizeof(float)); // The input to FFN rmsnorm was the residual after attention = act_x[l] + attn_residual // We saved act_x[l] but the actual input to ffn_rmsnorm is x after attention residual // For a proper implementation we'd save this. Approximate with act_x[l]. cpu_rmsnorm_backward(dx_ffn_norm, dx_ffn_in, m->act_x[l], m->rms_ffn_w[l], S, d); for (int i = 0; i < S * d; i++) dx[i] += dx_ffn_norm[i]; free(dx_ffn_in); free(dx_ffn_norm); // O projection backward float *d_attn_out = (float*)calloc(S * d, sizeof(float)); cpu_matmul_backward_dx(m->wo[l], dx, d_attn_out, S, d, d); cpu_accum_dW(m->grad_wo[l], dx, m->act_attn_out[l], S, d, d); // Attention backward float *dq = (float*)calloc(S * d, sizeof(float)); float *dk = (float*)calloc(S * d, sizeof(float)); float *dv = (float*)calloc(S * d, sizeof(float)); cpu_attention_backward(dq, dk, dv, d_attn_out, m->act_q[l], m->act_k[l], m->act_v[l], S, nh, hdim); free(d_attn_out); cpu_rope_backward(dq, dk, S, nh, hdim); // QKV backward cpu_accum_dW(m->grad_wq[l], dq, m->act_xnorm[l], S, d, d); cpu_accum_dW(m->grad_wk[l], dk, m->act_xnorm[l], S, d, d); cpu_accum_dW(m->grad_wv[l], dv, m->act_xnorm[l], S, d, d); float *dx_qkv = (float*)calloc(S * d, sizeof(float)); float *tmp = (float*)malloc(S * d * sizeof(float)); cpu_matmul_backward_dx(m->wq[l], dq, tmp, S, d, d); for (int i = 0; i < S*d; i++) dx_qkv[i] += tmp[i]; cpu_matmul_backward_dx(m->wk[l], dk, tmp, S, d, d); for (int i = 0; i < S*d; i++) dx_qkv[i] += tmp[i]; cpu_matmul_backward_dx(m->wv[l], dv, tmp, S, d, d); for (int i = 0; i < S*d; i++) dx_qkv[i] += tmp[i]; free(tmp); free(dq); free(dk); free(dv); // RMSNorm attention backward float *dx_att_norm = (float*)malloc(S * d * sizeof(float)); cpu_rmsnorm_backward(dx_att_norm, dx_qkv, m->act_x[l], m->rms_att_w[l], S, d); for (int i = 0; i < S * d; i++) dx[i] += dx_att_norm[i]; free(dx_qkv); free(dx_att_norm); } // Embedding gradient for (int t = 0; t < S; t++) for (int i = 0; i < d; i++) m->grad_emb[tokens[t]*d + i] += dx[t*d + i]; free(dx); } static void model_adam_step(Model *m, float lr, float beta1, float beta2, float eps) { m->adam_step++; float bc1 = 1.0f - powf(beta1, m->adam_step); float bc2 = 1.0f - powf(beta2, m->adam_step); size_t idx = 0; #define ADAM_UPDATE(param, grad, size) do { \ for (size_t _i = 0; _i < (size_t)(size); _i++) { \ float g = (grad)[_i]; \ m->adam_m[idx] = beta1 * m->adam_m[idx] + (1-beta1) * g; \ m->adam_v[idx] = beta2 * m->adam_v[idx] + (1-beta2) * g * g; \ float m_hat = m->adam_m[idx] / bc1; \ float v_hat = m->adam_v[idx] / bc2; \ (param)[_i] -= lr * m_hat / (sqrtf(v_hat) + eps); \ idx++; \ } \ } while(0) int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size; for (int l = 0; l < N_LAYERS; l++) { ADAM_UPDATE(m->wq[l], m->grad_wq[l], d*d); ADAM_UPDATE(m->wk[l], m->grad_wk[l], d*d); ADAM_UPDATE(m->wv[l], m->grad_wv[l], d*d); ADAM_UPDATE(m->wo[l], m->grad_wo[l], d*d); ADAM_UPDATE(m->w1[l], m->grad_w1[l], hd*d); ADAM_UPDATE(m->w2[l], m->grad_w2[l], d*hd); ADAM_UPDATE(m->w3[l], m->grad_w3[l], hd*d); } ADAM_UPDATE(m->wcls, m->grad_wcls, vs*d); ADAM_UPDATE(m->token_embedding, m->grad_emb, vs*d); #undef ADAM_UPDATE }