mirror of https://github.com/maderix/ANE.git
111 lines
4.5 KiB
C
111 lines
4.5 KiB
C
// stories_cpu_ops_opt.h — Optimized CPU operations: NEON Adam, vectorized embedding
|
|
#pragma once
|
|
#include "stories_cpu_ops.h"
|
|
#include <arm_neon.h>
|
|
|
|
// ===== NEON-vectorized Adam optimizer =====
|
|
// ~3-3.5x faster than scalar version for large param counts
|
|
// Uses vrsqrteq_f32 + one Newton-Raphson step for fast reciprocal sqrt
|
|
static void adam_update_opt(float *w, const float *g, AdamState *s, int t,
|
|
float lr, float b1, float b2, float eps) {
|
|
float bc1 = 1.0f - powf(b1, t);
|
|
float bc2 = 1.0f - powf(b2, t);
|
|
float inv_bc1 = 1.0f / bc1;
|
|
float inv_bc2 = 1.0f / bc2;
|
|
float one_minus_b1 = 1.0f - b1;
|
|
float one_minus_b2 = 1.0f - b2;
|
|
|
|
float32x4_t vb1 = vdupq_n_f32(b1);
|
|
float32x4_t vb2 = vdupq_n_f32(b2);
|
|
float32x4_t v1mb1 = vdupq_n_f32(one_minus_b1);
|
|
float32x4_t v1mb2 = vdupq_n_f32(one_minus_b2);
|
|
float32x4_t vinv_bc1 = vdupq_n_f32(inv_bc1);
|
|
float32x4_t vinv_bc2 = vdupq_n_f32(inv_bc2);
|
|
float32x4_t vneg_lr = vdupq_n_f32(-lr);
|
|
float32x4_t veps = vdupq_n_f32(eps);
|
|
|
|
size_t n = s->n;
|
|
size_t i = 0;
|
|
|
|
// Process 4 elements at a time
|
|
for (; i + 3 < n; i += 4) {
|
|
// Load
|
|
float32x4_t vm = vld1q_f32(s->m + i);
|
|
float32x4_t vv = vld1q_f32(s->v + i);
|
|
float32x4_t vg = vld1q_f32(g + i);
|
|
float32x4_t vw = vld1q_f32(w + i);
|
|
|
|
// m = b1*m + (1-b1)*g
|
|
vm = vmlaq_f32(vmulq_f32(vb1, vm), v1mb1, vg);
|
|
// v = b2*v + (1-b2)*g*g
|
|
float32x4_t g2 = vmulq_f32(vg, vg);
|
|
vv = vmlaq_f32(vmulq_f32(vb2, vv), v1mb2, g2);
|
|
|
|
// Store updated m, v
|
|
vst1q_f32(s->m + i, vm);
|
|
vst1q_f32(s->v + i, vv);
|
|
|
|
// mhat = m / bc1, vhat = v / bc2
|
|
float32x4_t mhat = vmulq_f32(vm, vinv_bc1);
|
|
float32x4_t vhat = vmulq_f32(vv, vinv_bc2);
|
|
|
|
// Fast reciprocal sqrt: vrsqrteq + one Newton-Raphson iteration
|
|
// rsqrt_est ≈ 1/sqrt(vhat)
|
|
float32x4_t rsqrt_est = vrsqrteq_f32(vhat);
|
|
// Newton-Raphson: rsqrt *= (3 - vhat * rsqrt^2) / 2
|
|
float32x4_t rsqrt_sq = vmulq_f32(rsqrt_est, rsqrt_est);
|
|
float32x4_t nr_step = vrsqrtsq_f32(vhat, rsqrt_sq);
|
|
rsqrt_est = vmulq_f32(rsqrt_est, nr_step);
|
|
|
|
// w -= lr * mhat / (sqrt(vhat) + eps)
|
|
// = w + (-lr) * mhat * (1/(sqrt(vhat) + eps))
|
|
// Compute sqrt(vhat) from rsqrt: sqrt = vhat * rsqrt(vhat) (avoids division)
|
|
float32x4_t sqrt_vhat = vmulq_f32(vhat, rsqrt_est);
|
|
float32x4_t denom = vaddq_f32(sqrt_vhat, veps);
|
|
|
|
// Use vdivq_f32 for the final division (accurate, eps-adjusted)
|
|
float32x4_t update = vmulq_f32(vneg_lr, vdivq_f32(mhat, denom));
|
|
vw = vaddq_f32(vw, update);
|
|
|
|
vst1q_f32(w + i, vw);
|
|
}
|
|
|
|
// Scalar tail
|
|
for (; i < n; i++) {
|
|
s->m[i] = b1 * s->m[i] + one_minus_b1 * g[i];
|
|
s->v[i] = b2 * s->v[i] + one_minus_b2 * g[i] * g[i];
|
|
float mh = s->m[i] * inv_bc1;
|
|
float vh = s->v[i] * inv_bc2;
|
|
w[i] -= lr * mh / (sqrtf(vh) + eps);
|
|
}
|
|
}
|
|
|
|
// ===== Vectorized embedding lookup =====
|
|
// Gather rows from [VOCAB, DIM] row-major embed table → x [DIM, SEQ] channel-first
|
|
// Strategy: gather token rows into temp buffer [SEQ, DIM], then transpose via vDSP_mtrans
|
|
static void embed_lookup_opt(float *x, const float *embed, const uint16_t *tokens,
|
|
int dim, int seq, float *tmp) {
|
|
// Gather: tmp[t*dim + d] = embed[tokens[t]*dim + d]
|
|
for (int t = 0; t < seq; t++) {
|
|
memcpy(tmp + t * dim, embed + tokens[t] * dim, dim * sizeof(float));
|
|
}
|
|
// Transpose [SEQ, DIM] → [DIM, SEQ]: x[d*seq + t] = tmp[t*dim + d]
|
|
vDSP_mtrans(tmp, 1, x, 1, (vDSP_Length)dim, (vDSP_Length)seq);
|
|
}
|
|
|
|
// ===== Vectorized embedding backward =====
|
|
// Accumulate dE[tok] += dx[:,t] for each position
|
|
// Strategy: transpose dx [DIM, SEQ] → tmp [SEQ, DIM], then accumulate rows
|
|
static void embed_backward_opt(float *d_embed, const float *dx, const uint16_t *tokens,
|
|
int dim, int seq, float *tmp) {
|
|
// Transpose [DIM, SEQ] → [SEQ, DIM]: tmp[t*dim + d] = dx[d*seq + t]
|
|
vDSP_mtrans(dx, 1, tmp, 1, (vDSP_Length)seq, (vDSP_Length)dim);
|
|
// Scatter-add: d_embed[tok*dim .. (tok+1)*dim] += tmp[t*dim .. (t+1)*dim]
|
|
for (int t = 0; t < seq; t++) {
|
|
vDSP_vadd(tmp + t * dim, 1,
|
|
d_embed + tokens[t] * dim, 1,
|
|
d_embed + tokens[t] * dim, 1,
|
|
(vDSP_Length)dim);
|
|
}
|
|
}
|