mirror of https://github.com/maderix/ANE.git
Merge 7ea45c2fab into 20cd236f61
This commit is contained in:
commit
668c236a08
|
|
@ -8,7 +8,7 @@ HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
|
|||
HEADERS_ANE = $(HEADERS_LARGE) ane_rmsnorm_bwd.h ane_classifier.h

# Link Accelerate: the CPU backward pass / Adam step pulled in by train.m
# call cblas_sgemm and vDSP routines.
train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS) -framework Accelerate

train_large: train_large.m $(HEADERS_LARGE)
	$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate
|
||||
|
|
|
|||
|
|
@ -4,25 +4,21 @@
|
|||
#include "forward.h"
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
// dW += dy @ x^T — dy: [S, out_dim], x: [S, in_dim], dW: [out_dim, in_dim]
|
||||
// dW += dy^T @ x — dy: [S, out_dim], x: [S, in_dim], dW: [out_dim, in_dim]
|
||||
static void cpu_accum_dW(float *dW, const float *dy, const float *x, int S, int out_dim, int in_dim) {
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int i = 0; i < out_dim; i++)
|
||||
for (int j = 0; j < in_dim; j++)
|
||||
dW[i*in_dim+j] += dy[t*out_dim+i] * x[t*in_dim+j];
|
||||
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
|
||||
out_dim, in_dim, S, 1.0f,
|
||||
dy, out_dim, x, in_dim, 1.0f, dW, in_dim);
|
||||
}
|
||||
|
||||
// dx = W^T @ dy — W: [out_dim, in_dim], dy: [S, out_dim] → dx: [S, in_dim]
|
||||
static void cpu_matmul_backward_dx(const float *W, const float *dy, float *dx,
|
||||
int S, int out_dim, int in_dim) {
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int j = 0; j < in_dim; j++) {
|
||||
float sum = 0;
|
||||
for (int i = 0; i < out_dim; i++)
|
||||
sum += W[i*in_dim+j] * dy[t*out_dim+i];
|
||||
dx[t*in_dim+j] = sum;
|
||||
}
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
S, in_dim, out_dim, 1.0f,
|
||||
dy, out_dim, W, in_dim, 0.0f, dx, in_dim);
|
||||
}
|
||||
|
||||
static void cpu_rmsnorm_backward(float *dx, const float *dy, const float *x, const float *w,
|
||||
|
|
@ -278,18 +274,30 @@ static void model_adam_step(Model *m, float lr, float beta1, float beta2, float
|
|||
m->adam_step++;
|
||||
float bc1 = 1.0f - powf(beta1, m->adam_step);
|
||||
float bc2 = 1.0f - powf(beta2, m->adam_step);
|
||||
float neg_lr_over_bc1 = -lr / bc1;
|
||||
float inv_bc2 = 1.0f / bc2;
|
||||
float one_minus_b1 = 1.0f - beta1;
|
||||
float one_minus_b2 = 1.0f - beta2;
|
||||
size_t idx = 0;
|
||||
|
||||
// Adam update for one contiguous chunk of `size` floats.
//   param: parameters, updated in place
//   grad:  gradients for the same elements (read only)
// The chunk's moments live at m->adam_m[idx..] and m->adam_v[idx..]; idx is
// advanced by `size` exactly once so consecutive invocations walk the flat
// moment buffers. Relies on these names from the enclosing function: m, idx,
// lr, beta1, beta2, eps, bc1, bc2, one_minus_b1, one_minus_b2, inv_bc2,
// neg_lr_over_bc1. Each macro argument is evaluated once.
// If the scratch buffer cannot be allocated the same update is applied with a
// scalar loop instead of dereferencing NULL.
#define ADAM_UPDATE(param, grad, size) do { \
    float *_p = (param); \
    const float *_g = (grad); \
    size_t _n = (size_t)(size); \
    float *_m = m->adam_m + idx; \
    float *_v = m->adam_v + idx; \
    float *_tmp = _n ? (float*)malloc(_n * sizeof *_tmp) : NULL; \
    if (_tmp) { \
        /* m = beta1*m + (1-beta1)*g */ \
        vDSP_vsmul(_m, 1, &beta1, _m, 1, _n); \
        vDSP_vsma(_g, 1, &one_minus_b1, _m, 1, _m, 1, _n); \
        /* v = beta2*v + (1-beta2)*g^2 */ \
        vDSP_vsq(_g, 1, _tmp, 1, _n); \
        vDSP_vsmul(_v, 1, &beta2, _v, 1, _n); \
        vDSP_vsma(_tmp, 1, &one_minus_b2, _v, 1, _v, 1, _n); \
        /* tmp = sqrt(v/bc2) + eps; vvsqrtf takes an int count, so feed it */ \
        /* in pieces that always fit instead of truncating _n. */ \
        vDSP_vsmul(_v, 1, &inv_bc2, _tmp, 1, _n); \
        for (size_t _o = 0; _o < _n; ) { \
            size_t _c = _n - _o; \
            if (_c > ((size_t)1 << 30)) _c = (size_t)1 << 30; \
            int _ci = (int)_c; \
            vvsqrtf(_tmp + _o, _tmp + _o, &_ci); \
            _o += _c; \
        } \
        vDSP_vsadd(_tmp, 1, &eps, _tmp, 1, _n); \
        /* tmp = m / tmp (vDSP_vdiv takes the divisor first) */ \
        vDSP_vdiv(_tmp, 1, _m, 1, _tmp, 1, _n); \
        /* param += (-lr/bc1) * tmp */ \
        vDSP_vsma(_tmp, 1, &neg_lr_over_bc1, _p, 1, _p, 1, _n); \
        free(_tmp); \
    } else { \
        /* Empty chunk (loop runs zero times) or allocation failure. */ \
        for (size_t _i = 0; _i < _n; _i++) { \
            float _gi = _g[_i]; \
            _m[_i] = beta1 * _m[_i] + one_minus_b1 * _gi; \
            _v[_i] = beta2 * _v[_i] + one_minus_b2 * _gi * _gi; \
            _p[_i] -= lr * (_m[_i] / bc1) / (sqrtf(_v[_i] / bc2) + eps); \
        } \
    } \
    idx += _n; \
} while(0)
|
||||
|
||||
int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size;
|
||||
|
|
|
|||
|
|
@ -54,12 +54,31 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c
|
|||
|
||||
// One Adam step for a single parameter tensor of s->n floats.
//   w:  parameters, updated in place
//   g:  gradient for the same elements (read only)
//   s:  optimizer state; s->m and s->v hold the running averages of g and g*g
//       and are updated in place
//   t:  step number used for bias correction; must be >= 1, because t == 0
//       makes bc1 and bc2 zero and the update divides by them
//   lr, b1, b2, eps: learning rate, decay rates, and denominator offset
// Per element:
//   m = b1*m + (1-b1)*g;  v = b2*v + (1-b2)*g*g;
//   w -= lr * (m/bc1) / (sqrt(v/bc2) + eps)
// The step is applied exactly once. If the scratch buffer cannot be allocated
// the same formula runs as a scalar loop instead of dereferencing NULL.
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
    const size_t n = s->n;
    if (n == 0) return;

    const float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
    const float one_minus_b1 = 1.0f - b1;
    const float one_minus_b2 = 1.0f - b2;

    float *tmp = (float*)malloc(n * sizeof *tmp);
    if (!tmp) {
        // Allocation failed: fall back to the scalar form of the update.
        for (size_t i = 0; i < n; i++) {
            s->m[i] = b1 * s->m[i] + one_minus_b1 * g[i];
            s->v[i] = b2 * s->v[i] + one_minus_b2 * g[i] * g[i];
            float mh = s->m[i] / bc1, vh = s->v[i] / bc2;
            w[i] -= lr * mh / (sqrtf(vh) + eps);
        }
        return;
    }

    const float neg_lr_over_bc1 = -lr / bc1;
    const float inv_bc2 = 1.0f / bc2;

    // m = b1*m + (1-b1)*g
    vDSP_vsmul(s->m, 1, &b1, s->m, 1, n);
    vDSP_vsma(g, 1, &one_minus_b1, s->m, 1, s->m, 1, n);

    // v = b2*v + (1-b2)*g^2
    vDSP_vsq(g, 1, tmp, 1, n);
    vDSP_vsmul(s->v, 1, &b2, s->v, 1, n);
    vDSP_vsma(tmp, 1, &one_minus_b2, s->v, 1, s->v, 1, n);

    // tmp = sqrt(v/bc2) + eps. vvsqrtf takes an int element count, so feed it
    // in pieces that always fit rather than truncating n.
    vDSP_vsmul(s->v, 1, &inv_bc2, tmp, 1, n);
    const size_t max_piece = (size_t)1 << 30;
    for (size_t off = 0; off < n; ) {
        size_t piece = n - off;
        if (piece > max_piece) piece = max_piece;
        int count = (int)piece;
        vvsqrtf(tmp + off, tmp + off, &count);
        off += piece;
    }
    vDSP_vsadd(tmp, 1, &eps, tmp, 1, n);

    // tmp = m / tmp (vDSP_vdiv takes the divisor first), then
    // w += (-lr/bc1) * tmp
    vDSP_vdiv(tmp, 1, s->m, 1, tmp, 1, n);
    vDSP_vsma(tmp, 1, &neg_lr_over_bc1, w, 1, w, 1, n);

    free(tmp);
}
|
||||
|
||||
// Cross-entropy loss + gradient for logits (column-major: [VOCAB, SEQ])
|
||||
|
|
|
|||
Loading…
Reference in New Issue