ANE/training/forward.h

201 lines
8.0 KiB
C

// forward.h — Forward pass: ANE baked-weight conv for linears, CPU for element-wise
#pragma once
#include "model.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// ane_conv_eval — run one baked-weight ANE conv kernel on a row-major activation.
//
//   kernel  : ANE kernel handle; input slot 0 is written, output slot 0 is read
//   x       : input,  [S, in_dim]  row-major (read only)
//   y       : output, [S, out_dim] row-major (overwritten)
//
// The input is transposed to channels-first [in_dim, S] before being written to
// the kernel, and the kernel output is read back as [out_dim, S] and transposed
// into y. When g_fp16_io is set the staging buffers hold _Float16 and values are
// converted on the way in and out; otherwise they hold float.
//
// Staging buffers are heap-allocated per call. The process aborts with a
// message on stderr if they cannot be allocated.
static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
                          int S, int in_dim, int out_dim) {
    // Element counts in size_t so large S * dim products cannot overflow int.
    const size_t n_in = (size_t)S * (size_t)in_dim;
    const size_t n_out = (size_t)S * (size_t)out_dim;

    if (g_fp16_io) {
        // fp16 I/O path: transpose + float→fp16, write, eval, read, fp16→float + transpose.
        _Float16 *x_t = (_Float16*)malloc(n_in * sizeof(_Float16));
        _Float16 *y_t = (_Float16*)malloc(n_out * sizeof(_Float16));
        // malloc(0) may legitimately return NULL, so only treat NULL as failure
        // when a non-empty buffer was requested.
        if ((n_in != 0 && !x_t) || (n_out != 0 && !y_t)) {
            fprintf(stderr, "ane_conv_eval: out of memory\n");
            abort();
        }
        for (int t = 0; t < S; t++) {
            for (int i = 0; i < in_dim; i++) {
                x_t[(size_t)i * S + t] = (_Float16)x[(size_t)t * in_dim + i];
            }
        }
        ane_write_input(kernel, 0, x_t, n_in * sizeof(_Float16));
        ane_eval(kernel);
        ane_read_output(kernel, 0, y_t, n_out * sizeof(_Float16));
        for (int t = 0; t < S; t++) {
            for (int i = 0; i < out_dim; i++) {
                y[(size_t)t * out_dim + i] = (float)y_t[(size_t)i * S + t];
            }
        }
        free(x_t);
        free(y_t);
    } else {
        // fp32 I/O path: transpose, write, eval, read, transpose back.
        float *x_t = (float*)malloc(n_in * sizeof(float));
        float *y_t = (float*)malloc(n_out * sizeof(float));
        if ((n_in != 0 && !x_t) || (n_out != 0 && !y_t)) {
            fprintf(stderr, "ane_conv_eval: out of memory\n");
            abort();
        }
        for (int t = 0; t < S; t++) {
            for (int i = 0; i < in_dim; i++) {
                x_t[(size_t)i * S + t] = x[(size_t)t * in_dim + i];
            }
        }
        ane_write_input(kernel, 0, x_t, n_in * sizeof(float));
        ane_eval(kernel);
        ane_read_output(kernel, 0, y_t, n_out * sizeof(float));
        for (int t = 0; t < S; t++) {
            for (int i = 0; i < out_dim; i++) {
                y[(size_t)t * out_dim + i] = y_t[(size_t)i * S + t];
            }
        }
        free(x_t);
        free(y_t);
    }
}
// CPU matmul fallback. For every row t of x, computes y[t] = W · x[t].
//   W : [out_dim, in_dim] row-major
//   x : [S, in_dim]       row-major
//   y : [S, out_dim]      row-major (overwritten)
static void cpu_matmul(const float *W, const float *x, float *y, int S, int in_dim, int out_dim) {
    for (int row = 0; row < S; row++) {
        const float *x_row = x + row * in_dim;
        float *y_row = y + row * out_dim;
        for (int o = 0; o < out_dim; o++) {
            const float *w_row = W + o * in_dim;
            float acc = 0;
            for (int c = 0; c < in_dim; c++) {
                acc += w_row[c] * x_row[c];
            }
            y_row[o] = acc;
        }
    }
}
static void cpu_rmsnorm(float *out, const float *x, const float *w, int S, int D) {
for (int t = 0; t < S; t++) {
float ss = 0;
for (int i = 0; i < D; i++) ss += x[t*D+i] * x[t*D+i];
ss = 1.0f / sqrtf(ss / D + 1e-5f);
for (int i = 0; i < D; i++) out[t*D+i] = x[t*D+i] * ss * w[i];
}
}
static void cpu_rope(float *q, float *k, int S, int n_heads, int head_dim) {
for (int t = 0; t < S; t++)
for (int h = 0; h < n_heads; h++)
for (int i = 0; i < head_dim; i += 2) {
float freq = 1.0f / powf(10000.0f, (float)i / head_dim);
float val = t * freq;
float cos_v = cosf(val), sin_v = sinf(val);
int off = t * n_heads * head_dim + h * head_dim + i;
float q0 = q[off], q1 = q[off+1];
q[off] = q0 * cos_v - q1 * sin_v;
q[off+1] = q0 * sin_v + q1 * cos_v;
float k0 = k[off], k1 = k[off+1];
k[off] = k0 * cos_v - k1 * sin_v;
k[off+1] = k0 * sin_v + k1 * cos_v;
}
}
static void cpu_attention(float *out, const float *q, const float *k, const float *v,
int S, int n_heads, int head_dim) {
float scale = 1.0f / sqrtf((float)head_dim);
float *scores = (float*)malloc(S * S * sizeof(float));
for (int h = 0; h < n_heads; h++) {
int D = n_heads * head_dim;
for (int t = 0; t < S; t++) {
float mx = -1e9f;
for (int s = 0; s <= t; s++) {
float dot = 0;
for (int i = 0; i < head_dim; i++)
dot += q[t*D + h*head_dim + i] * k[s*D + h*head_dim + i];
scores[s] = dot * scale;
if (scores[s] > mx) mx = scores[s];
}
float sm = 0;
for (int s = 0; s <= t; s++) { scores[s] = expf(scores[s] - mx); sm += scores[s]; }
for (int s = 0; s <= t; s++) scores[s] /= sm;
for (int i = 0; i < head_dim; i++) {
float val = 0;
for (int s = 0; s <= t; s++)
val += scores[s] * v[s*D + h*head_dim + i];
out[t*D + h*head_dim + i] = val;
}
}
}
free(scores);
}
static inline float silu_f(float x) { return x / (1.0f + expf(-x)); }
// Forward pass — returns the mean next-token cross-entropy loss and saves the
// activations needed by the backward pass into m.
//
//   m       : model; weights / kernels are read, act_* buffers and logits are written
//   tokens  : m->seq_len token ids; tokens[t+1] is the target for position t
//   use_ane : true  → linear layers run through ane_conv_eval on the per-layer kernels
//             false → linear layers run through cpu_matmul on the weight arrays
//             (the classifier falls back to cpu_matmul when m->kern_cls is NULL)
//
// Per layer: save input → RMSNorm → Q/K/V projections → RoPE → causal attention →
// output projection + residual → RMSNorm → W1/W3 → SiLU(h1) * h3 → W2 + residual.
// Then final RMSNorm, classifier, and softmax cross-entropy averaged over the
// S - 1 positions that have a next token.
//
// Returns 0 when the sequence has fewer than two tokens (no targets). The
// process aborts with a message on stderr if scratch memory cannot be allocated.
static float model_forward(Model *m, const int *tokens, bool use_ane) {
    const int S = m->seq_len, d = m->cfg.dim, hd = m->cfg.hidden_dim;
    const int nh = m->cfg.n_heads, hdim = HEAD_DIM, vs = m->cfg.vocab_size;

    // Sizes in size_t so S * dim products cannot overflow int.
    const size_t n_act = (size_t)S * (size_t)d;
    const size_t n_hidden = (size_t)S * (size_t)hd;
    const size_t act_bytes = n_act * sizeof(float);

    // x is the residual stream; branch is a scratch buffer reused for the output
    // of each residual branch (attention projection, then FFN) in every layer.
    float *x = (float*)malloc(act_bytes);
    float *branch = (float*)malloc(act_bytes);
    if (n_act != 0 && (!x || !branch)) {
        fprintf(stderr, "model_forward: out of memory\n");
        abort();
    }

    // Token embedding lookup.
    for (int t = 0; t < S; t++) {
        memcpy(x + (size_t)t * d,
               m->token_embedding + (size_t)tokens[t] * d,
               (size_t)d * sizeof(float));
    }

    for (int l = 0; l < N_LAYERS; l++) {
        // ---- attention block ----
        memcpy(m->act_x[l], x, act_bytes);
        cpu_rmsnorm(m->act_xnorm[l], x, m->rms_att_w[l], S, d);
        if (use_ane) {
            ane_conv_eval(m->kern_q[l], m->act_xnorm[l], m->act_q[l], S, d, d);
            ane_conv_eval(m->kern_k[l], m->act_xnorm[l], m->act_k[l], S, d, d);
            ane_conv_eval(m->kern_v[l], m->act_xnorm[l], m->act_v[l], S, d, d);
        } else {
            cpu_matmul(m->wq[l], m->act_xnorm[l], m->act_q[l], S, d, d);
            cpu_matmul(m->wk[l], m->act_xnorm[l], m->act_k[l], S, d, d);
            cpu_matmul(m->wv[l], m->act_xnorm[l], m->act_v[l], S, d, d);
        }
        cpu_rope(m->act_q[l], m->act_k[l], S, nh, hdim);
        cpu_attention(m->act_attn_out[l], m->act_q[l], m->act_k[l], m->act_v[l], S, nh, hdim);

        if (use_ane) {
            ane_conv_eval(m->kern_o[l], m->act_attn_out[l], branch, S, d, d);
        } else {
            cpu_matmul(m->wo[l], m->act_attn_out[l], branch, S, d, d);
        }
        for (size_t i = 0; i < n_act; i++) x[i] += branch[i];  // residual add

        // ---- feed-forward block ----
        cpu_rmsnorm(m->act_ffn_in[l], x, m->rms_ffn_w[l], S, d);
        if (use_ane) {
            ane_conv_eval(m->kern_w1[l], m->act_ffn_in[l], m->act_h1[l], S, d, hd);
            ane_conv_eval(m->kern_w3[l], m->act_ffn_in[l], m->act_h3[l], S, d, hd);
        } else {
            cpu_matmul(m->w1[l], m->act_ffn_in[l], m->act_h1[l], S, d, hd);
            cpu_matmul(m->w3[l], m->act_ffn_in[l], m->act_h3[l], S, d, hd);
        }
        // Gated activation: SiLU(h1) * h3, element-wise over [S, hd].
        for (size_t i = 0; i < n_hidden; i++) {
            m->act_silu[l][i] = silu_f(m->act_h1[l][i]) * m->act_h3[l][i];
        }

        if (use_ane) {
            ane_conv_eval(m->kern_w2[l], m->act_silu[l], branch, S, hd, d);
        } else {
            cpu_matmul(m->w2[l], m->act_silu[l], branch, S, hd, d);
        }
        for (size_t i = 0; i < n_act; i++) x[i] += branch[i];  // residual add
    }

    // Final norm + classifier.
    memcpy(m->act_pre_final, x, act_bytes);
    cpu_rmsnorm(m->act_final, x, m->rms_final_w, S, d);
    if (use_ane && m->kern_cls) {
        ane_conv_eval(m->kern_cls, m->act_final, m->logits, S, d, vs);
    } else {
        cpu_matmul(m->wcls, m->act_final, m->logits, S, d, vs);
    }
    free(branch);
    free(x);

    // No position has a next-token target: avoid dividing by S - 1 <= 0.
    if (S < 2) {
        return 0.0f;
    }

    // Softmax cross-entropy, position t predicting tokens[t+1].
    float loss = 0;
    for (int t = 0; t < S - 1; t++) {
        const float *row = m->logits + (size_t)t * vs;

        float mx = -INFINITY;
        for (int i = 0; i < vs; i++) {
            if (row[i] > mx) mx = row[i];
        }
        float sm = 0;
        for (int i = 0; i < vs; i++) {
            sm += expf(row[i] - mx);
        }
        const float log_prob = row[tokens[t + 1]] - mx - logf(sm);
        loss -= log_prob;
    }
    return loss / (S - 1);
}