mirror of https://github.com/maderix/ANE.git
149 lines
5.9 KiB
C
149 lines
5.9 KiB
C
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
|
|
#pragma once
|
|
#include "stories_config.h"
|
|
|
|
static float *g_rms_tmp = NULL;
|
|
|
|
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
|
|
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
|
|
float *ss = (float*)calloc(S, sizeof(float));
|
|
for (int i=0; i<d; i++) {
|
|
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
|
}
|
|
float invd = 1.0f/d, eps=1e-5f;
|
|
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
|
int n = S; vvrsqrtf(ss, ss, &n);
|
|
for (int i=0; i<d; i++) {
|
|
vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, (vDSP_Length)S);
|
|
vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, (vDSP_Length)S);
|
|
}
|
|
free(ss);
|
|
}
|
|
|
|
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
|
|
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
|
|
float *ss = (float*)calloc(S, sizeof(float));
|
|
for (int i=0; i<d; i++) {
|
|
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
|
}
|
|
float invd = 1.0f/d, eps=1e-5f;
|
|
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
|
float *rrms = (float*)malloc(S*4);
|
|
int n = S; vvrsqrtf(rrms, ss, &n);
|
|
float *dot = (float*)calloc(S, sizeof(float));
|
|
for (int i=0; i<d; i++) {
|
|
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
vDSP_vsma(g_rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
|
|
}
|
|
vDSP_vmul(rrms, 1, rrms, 1, ss, 1, (vDSP_Length)S);
|
|
vDSP_vsmul(ss, 1, &invd, ss, 1, (vDSP_Length)S);
|
|
vDSP_vmul(dot, 1, ss, 1, dot, 1, (vDSP_Length)S);
|
|
for (int i=0; i<d; i++) {
|
|
vDSP_vmul(x+i*S, 1, dot, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
vDSP_vsub(g_rms_tmp, 1, dy+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
vDSP_vsmul(g_rms_tmp, 1, &w[i], dx+i*S, 1, (vDSP_Length)S);
|
|
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
|
float s; vDSP_sve(g_rms_tmp, 1, &s, (vDSP_Length)S);
|
|
dw[i] += s;
|
|
}
|
|
free(ss); free(rrms); free(dot);
|
|
}
|
|
|
|
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
|
|
float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
|
|
size_t n = s->n;
|
|
float one_minus_b1 = 1.0f - b1;
|
|
float one_minus_b2 = 1.0f - b2;
|
|
float neg_lr_over_bc1 = -lr / bc1;
|
|
float inv_bc2 = 1.0f / bc2;
|
|
|
|
// m = b1*m + (1-b1)*g
|
|
vDSP_vsmul(s->m, 1, &b1, s->m, 1, n);
|
|
vDSP_vsma(g, 1, &one_minus_b1, s->m, 1, s->m, 1, n);
|
|
|
|
// v = b2*v + (1-b2)*g^2
|
|
float *tmp = (float*)malloc(n * sizeof(float));
|
|
vDSP_vsq(g, 1, tmp, 1, n);
|
|
vDSP_vsmul(s->v, 1, &b2, s->v, 1, n);
|
|
vDSP_vsma(tmp, 1, &one_minus_b2, s->v, 1, s->v, 1, n);
|
|
|
|
// update = m / (sqrt(v/bc2) + eps), then w -= (lr/bc1) * update
|
|
vDSP_vsmul(s->v, 1, &inv_bc2, tmp, 1, n);
|
|
int nn = (int)n;
|
|
vvsqrtf(tmp, tmp, &nn);
|
|
vDSP_vsadd(tmp, 1, &eps, tmp, 1, n);
|
|
vDSP_vdiv(tmp, 1, s->m, 1, tmp, 1, n);
|
|
vDSP_vsma(tmp, 1, &neg_lr_over_bc1, w, 1, w, 1, n);
|
|
|
|
free(tmp);
|
|
}
|
|
|
|
// Cross-entropy loss + gradient for logits (column-major: [VOCAB, SEQ])
|
|
// logits[v*SEQ+t] = logit for vocab v, position t
|
|
// targets[t] = target token id for position t
|
|
// Returns mean CE loss, writes dlogits = softmax(logits) - one_hot(targets)
|
|
// Data is column-major [V, S], but we process per-column (stride=1 within col is v*S+t, stride between v's is S)
|
|
// For vDSP: transpose to row-major scratch [S, V] to vectorize softmax per position
|
|
static float cross_entropy_loss(float *dlogits, const float *logits, const uint16_t *targets, int V, int S) {
|
|
// Work in transposed layout [S, V] where each row is one position's logits (contiguous)
|
|
float *buf = (float*)malloc(S * V * 4);
|
|
// Transpose [V,S] → [S,V]: buf[t*V+v] = logits[v*S+t]
|
|
vDSP_mtrans(logits, 1, buf, 1, (vDSP_Length)S, (vDSP_Length)V);
|
|
|
|
float total_loss = 0;
|
|
float invS = 1.0f / S;
|
|
for (int t = 0; t < S; t++) {
|
|
float *row = buf + t * V;
|
|
// max
|
|
float maxv;
|
|
vDSP_maxv(row, 1, &maxv, (vDSP_Length)V);
|
|
// row -= maxv
|
|
float neg_max = -maxv;
|
|
vDSP_vsadd(row, 1, &neg_max, row, 1, (vDSP_Length)V);
|
|
// exp in-place
|
|
int n = V;
|
|
vvexpf(row, row, &n);
|
|
// sum
|
|
float sum;
|
|
vDSP_sve(row, 1, &sum, (vDSP_Length)V);
|
|
// normalize
|
|
float inv_sum = 1.0f / sum;
|
|
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
|
|
// loss
|
|
int tgt = targets[t];
|
|
total_loss -= logf(row[tgt] + 1e-10f);
|
|
// gradient: softmax - one_hot, then /S
|
|
row[tgt] -= 1.0f;
|
|
vDSP_vsmul(row, 1, &invS, row, 1, (vDSP_Length)V);
|
|
}
|
|
// Transpose back [S,V] → [V,S]
|
|
vDSP_mtrans(buf, 1, dlogits, 1, (vDSP_Length)V, (vDSP_Length)S);
|
|
free(buf);
|
|
return total_loss / S;
|
|
}
|
|
|
|
// Embedding lookup: token_ids → x [DIM, SEQ] (channel-first)
|
|
// embed is [VOCAB, DIM] row-major (vocab_size rows, dim cols)
|
|
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
|
|
for (int t = 0; t < seq; t++) {
|
|
int tok = tokens[t];
|
|
for (int d = 0; d < dim; d++) {
|
|
x[d*seq + t] = embed[tok*dim + d];
|
|
}
|
|
}
|
|
}
|
|
|
|
// Embedding backward: accumulate dE[tok] += dx[:,t] for each position
|
|
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
|
|
for (int t = 0; t < seq; t++) {
|
|
int tok = tokens[t];
|
|
for (int d = 0; d < dim; d++) {
|
|
d_embed[tok*dim + d] += dx[d*seq + t];
|
|
}
|
|
}
|
|
}
|