mirror of https://github.com/maderix/ANE.git
fix: correctness & safety improvements
- Validate all fread() return values in model_load_weights (model.h) - Check ane_eval() return values in ane_conv_eval (forward.h) and ane_eval_k (tiny_train.m) - Log error details on ANE eval failure (ane_runtime.h) - Thread-safe RMSNorm: replace global g_rms_tmp with local allocation (stories_cpu_ops.h) - Bounds-check token indices in cross_entropy_loss, embed_lookup, embed_backward - Atomic checkpoint writes via tmp+rename pattern (tiny_train.m) - Non-destructive recompile: compile new kernels first, swap only on success (model.h) - Validate fread() in load_checkpoint (tiny_train.m)
This commit is contained in:
parent
443194bca4
commit
541bf4ec90
|
|
@ -141,9 +141,14 @@ static void ane_read_output(ANEKernel *k, int idx, void *data, size_t bytes) {
|
||||||
|
|
||||||
static bool ane_eval(ANEKernel *k) {
|
static bool ane_eval(ANEKernel *k) {
|
||||||
NSError *e = nil;
|
NSError *e = nil;
|
||||||
return ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
|
BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
|
||||||
k->model, @selector(evaluateWithQoS:options:request:error:),
|
k->model, @selector(evaluateWithQoS:options:request:error:),
|
||||||
21, @{}, k->request, &e);
|
21, @{}, k->request, &e);
|
||||||
|
if (!ok) {
|
||||||
|
fprintf(stderr, "ANE eval failed: %s\n",
|
||||||
|
e ? [[e description] UTF8String] : "unknown error");
|
||||||
|
}
|
||||||
|
return ok;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ane_free(ANEKernel *k) {
|
static void ane_free(ANEKernel *k) {
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@
|
||||||
// ANE conv eval: input [S, in_dim] row-major → transpose to [in_dim, S] channels-first
|
// ANE conv eval: input [S, in_dim] row-major → transpose to [in_dim, S] channels-first
|
||||||
// ANE computes conv(W, x) with baked W → output [out_dim, S]
|
// ANE computes conv(W, x) with baked W → output [out_dim, S]
|
||||||
// Transpose back to [S, out_dim] row-major
|
// Transpose back to [S, out_dim] row-major
|
||||||
static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
|
static bool ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
|
||||||
int S, int in_dim, int out_dim) {
|
int S, int in_dim, int out_dim) {
|
||||||
float *x_t = (float*)malloc(S * in_dim * sizeof(float));
|
float *x_t = (float*)malloc(S * in_dim * sizeof(float));
|
||||||
for (int t = 0; t < S; t++)
|
for (int t = 0; t < S; t++)
|
||||||
|
|
@ -15,7 +15,11 @@ static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
|
||||||
x_t[i*S + t] = x[t*in_dim + i];
|
x_t[i*S + t] = x[t*in_dim + i];
|
||||||
|
|
||||||
ane_write_input(kernel, 0, x_t, S * in_dim * sizeof(float));
|
ane_write_input(kernel, 0, x_t, S * in_dim * sizeof(float));
|
||||||
ane_eval(kernel);
|
bool ok = ane_eval(kernel);
|
||||||
|
if (!ok) {
|
||||||
|
free(x_t);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
float *y_t = (float*)malloc(S * out_dim * sizeof(float));
|
float *y_t = (float*)malloc(S * out_dim * sizeof(float));
|
||||||
ane_read_output(kernel, 0, y_t, S * out_dim * sizeof(float));
|
ane_read_output(kernel, 0, y_t, S * out_dim * sizeof(float));
|
||||||
|
|
@ -25,6 +29,7 @@ static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
|
||||||
y[t*out_dim + i] = y_t[i*S + t];
|
y[t*out_dim + i] = y_t[i*S + t];
|
||||||
|
|
||||||
free(x_t); free(y_t);
|
free(x_t); free(y_t);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// CPU matmul fallback: y = W @ x, W[out_dim, in_dim], x[S, in_dim] → y[S, out_dim]
|
// CPU matmul fallback: y = W @ x, W[out_dim, in_dim], x[S, in_dim] → y[S, out_dim]
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,10 @@ typedef struct {
|
||||||
static int model_load_weights(Model *m, const char *path) {
|
static int model_load_weights(Model *m, const char *path) {
|
||||||
FILE *f = fopen(path, "rb");
|
FILE *f = fopen(path, "rb");
|
||||||
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
|
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
|
||||||
fread(&m->cfg, sizeof(Config), 1, f);
|
if (fread(&m->cfg, sizeof(Config), 1, f) != 1) {
|
||||||
|
fprintf(stderr, "ERROR: failed to read config from %s\n", path);
|
||||||
|
fclose(f); return -1;
|
||||||
|
}
|
||||||
bool shared = m->cfg.vocab_size > 0;
|
bool shared = m->cfg.vocab_size > 0;
|
||||||
if (m->cfg.vocab_size < 0) m->cfg.vocab_size = -m->cfg.vocab_size;
|
if (m->cfg.vocab_size < 0) m->cfg.vocab_size = -m->cfg.vocab_size;
|
||||||
|
|
||||||
|
|
@ -89,7 +92,10 @@ static int model_load_weights(Model *m, const char *path) {
|
||||||
int d = m->cfg.dim, hd = m->cfg.hidden_dim, nl = m->cfg.n_layers, vs = m->cfg.vocab_size;
|
int d = m->cfg.dim, hd = m->cfg.hidden_dim, nl = m->cfg.n_layers, vs = m->cfg.vocab_size;
|
||||||
|
|
||||||
m->token_embedding = (float*)malloc(vs * d * sizeof(float));
|
m->token_embedding = (float*)malloc(vs * d * sizeof(float));
|
||||||
fread(m->token_embedding, sizeof(float), vs * d, f);
|
if (fread(m->token_embedding, sizeof(float), vs * d, f) != (size_t)(vs * d)) {
|
||||||
|
fprintf(stderr, "ERROR: short read on token_embedding (file truncated?)\n");
|
||||||
|
fclose(f); return -1;
|
||||||
|
}
|
||||||
|
|
||||||
float *rms_att_all = (float*)malloc(nl * d * sizeof(float));
|
float *rms_att_all = (float*)malloc(nl * d * sizeof(float));
|
||||||
float *wq_all = (float*)malloc(nl * d * d * sizeof(float));
|
float *wq_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||||
|
|
@ -101,15 +107,24 @@ static int model_load_weights(Model *m, const char *path) {
|
||||||
float *w2_all = (float*)malloc(nl * d * hd * sizeof(float));
|
float *w2_all = (float*)malloc(nl * d * hd * sizeof(float));
|
||||||
float *w3_all = (float*)malloc(nl * hd * d * sizeof(float));
|
float *w3_all = (float*)malloc(nl * hd * d * sizeof(float));
|
||||||
|
|
||||||
fread(rms_att_all, sizeof(float), nl * d, f);
|
#define FREAD_CHECK(buf, count, file, label) do { \
|
||||||
fread(wq_all, sizeof(float), nl * d * d, f);
|
size_t _n = fread(buf, sizeof(float), count, file); \
|
||||||
fread(wk_all, sizeof(float), nl * d * d, f);
|
if (_n != (size_t)(count)) { \
|
||||||
fread(wv_all, sizeof(float), nl * d * d, f);
|
fprintf(stderr, "ERROR: short read on %s: got %zu, expected %zu (file truncated?)\n", \
|
||||||
fread(wo_all, sizeof(float), nl * d * d, f);
|
label, _n, (size_t)(count)); \
|
||||||
fread(rms_ffn_all, sizeof(float), nl * d, f);
|
fclose(file); return -1; \
|
||||||
fread(w1_all, sizeof(float), nl * hd * d, f);
|
} \
|
||||||
fread(w2_all, sizeof(float), nl * d * hd, f);
|
} while(0)
|
||||||
fread(w3_all, sizeof(float), nl * hd * d, f);
|
|
||||||
|
FREAD_CHECK(rms_att_all, nl * d, f, "rms_att");
|
||||||
|
FREAD_CHECK(wq_all, nl * d * d, f, "wq");
|
||||||
|
FREAD_CHECK(wk_all, nl * d * d, f, "wk");
|
||||||
|
FREAD_CHECK(wv_all, nl * d * d, f, "wv");
|
||||||
|
FREAD_CHECK(wo_all, nl * d * d, f, "wo");
|
||||||
|
FREAD_CHECK(rms_ffn_all, nl * d, f, "rms_ffn");
|
||||||
|
FREAD_CHECK(w1_all, nl * hd * d, f, "w1");
|
||||||
|
FREAD_CHECK(w2_all, nl * d * hd, f, "w2");
|
||||||
|
FREAD_CHECK(w3_all, nl * hd * d, f, "w3");
|
||||||
|
|
||||||
for (int l = 0; l < nl; l++) {
|
for (int l = 0; l < nl; l++) {
|
||||||
m->rms_att_w[l] = (float*)malloc(d * sizeof(float));
|
m->rms_att_w[l] = (float*)malloc(d * sizeof(float));
|
||||||
|
|
@ -135,14 +150,15 @@ static int model_load_weights(Model *m, const char *path) {
|
||||||
free(rms_ffn_all); free(w1_all); free(w2_all); free(w3_all);
|
free(rms_ffn_all); free(w1_all); free(w2_all); free(w3_all);
|
||||||
|
|
||||||
m->rms_final_w = (float*)malloc(d * sizeof(float));
|
m->rms_final_w = (float*)malloc(d * sizeof(float));
|
||||||
fread(m->rms_final_w, sizeof(float), d, f);
|
FREAD_CHECK(m->rms_final_w, d, f, "rms_final");
|
||||||
|
|
||||||
if (shared) {
|
if (shared) {
|
||||||
m->wcls = m->token_embedding;
|
m->wcls = m->token_embedding;
|
||||||
} else {
|
} else {
|
||||||
m->wcls = (float*)malloc(vs * d * sizeof(float));
|
m->wcls = (float*)malloc(vs * d * sizeof(float));
|
||||||
fread(m->wcls, sizeof(float), vs * d, f);
|
FREAD_CHECK(m->wcls, vs * d, f, "wcls");
|
||||||
}
|
}
|
||||||
|
#undef FREAD_CHECK
|
||||||
fclose(f);
|
fclose(f);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
@ -188,32 +204,45 @@ static int model_compile_kernels(Model *m, int seq_len) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recompile all kernels after weight update — unload all first to avoid ANE model limit
|
// Recompile all kernels after weight update — compile new first, then swap
|
||||||
static int model_recompile_kernels(Model *m) {
|
static int model_recompile_kernels(Model *m) {
|
||||||
int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size;
|
int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size;
|
||||||
int S = m->seq_len;
|
int S = m->seq_len;
|
||||||
// Phase 1: unload+free all
|
|
||||||
|
// Phase 1: compile new kernels into temporaries
|
||||||
|
ANEKernel *new_q[N_LAYERS], *new_k[N_LAYERS], *new_v[N_LAYERS], *new_o[N_LAYERS];
|
||||||
|
ANEKernel *new_w1[N_LAYERS], *new_w2[N_LAYERS], *new_w3[N_LAYERS];
|
||||||
|
for (int l = 0; l < N_LAYERS; l++) {
|
||||||
|
new_q[l] = compile_conv_kernel(m->wq[l], d, d, S);
|
||||||
|
new_k[l] = compile_conv_kernel(m->wk[l], d, d, S);
|
||||||
|
new_v[l] = compile_conv_kernel(m->wv[l], d, d, S);
|
||||||
|
new_o[l] = compile_conv_kernel(m->wo[l], d, d, S);
|
||||||
|
new_w1[l] = compile_conv_kernel(m->w1[l], d, hd, S);
|
||||||
|
new_w2[l] = compile_conv_kernel(m->w2[l], hd, d, S);
|
||||||
|
new_w3[l] = compile_conv_kernel(m->w3[l], d, hd, S);
|
||||||
|
if (!new_q[l] || !new_k[l] || !new_v[l] || !new_o[l] ||
|
||||||
|
!new_w1[l] || !new_w2[l] || !new_w3[l]) {
|
||||||
|
// Cleanup partially compiled new kernels
|
||||||
|
for (int i = 0; i <= l; i++) {
|
||||||
|
ane_free(new_q[i]); ane_free(new_k[i]); ane_free(new_v[i]); ane_free(new_o[i]);
|
||||||
|
ane_free(new_w1[i]); ane_free(new_w2[i]); ane_free(new_w3[i]);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "Recompile failed at layer %d, keeping old kernels\n", l);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ANEKernel *new_cls = compile_conv_kernel(m->wcls, d, vs, S);
|
||||||
|
|
||||||
|
// Phase 2: all compiles succeeded — swap and free old
|
||||||
for (int l = 0; l < N_LAYERS; l++) {
|
for (int l = 0; l < N_LAYERS; l++) {
|
||||||
ane_free(m->kern_q[l]); ane_free(m->kern_k[l]); ane_free(m->kern_v[l]); ane_free(m->kern_o[l]);
|
ane_free(m->kern_q[l]); ane_free(m->kern_k[l]); ane_free(m->kern_v[l]); ane_free(m->kern_o[l]);
|
||||||
ane_free(m->kern_w1[l]); ane_free(m->kern_w2[l]); ane_free(m->kern_w3[l]);
|
ane_free(m->kern_w1[l]); ane_free(m->kern_w2[l]); ane_free(m->kern_w3[l]);
|
||||||
m->kern_q[l]=m->kern_k[l]=m->kern_v[l]=m->kern_o[l]=NULL;
|
m->kern_q[l] = new_q[l]; m->kern_k[l] = new_k[l];
|
||||||
m->kern_w1[l]=m->kern_w2[l]=m->kern_w3[l]=NULL;
|
m->kern_v[l] = new_v[l]; m->kern_o[l] = new_o[l];
|
||||||
|
m->kern_w1[l] = new_w1[l]; m->kern_w2[l] = new_w2[l]; m->kern_w3[l] = new_w3[l];
|
||||||
}
|
}
|
||||||
if (m->kern_cls) { ane_free(m->kern_cls); m->kern_cls=NULL; }
|
if (m->kern_cls) ane_free(m->kern_cls);
|
||||||
// Phase 2: recompile all
|
m->kern_cls = new_cls; // may be NULL for large vocab — forward uses CPU fallback
|
||||||
for (int l = 0; l < N_LAYERS; l++) {
|
|
||||||
m->kern_q[l] = compile_conv_kernel(m->wq[l], d, d, S);
|
|
||||||
m->kern_k[l] = compile_conv_kernel(m->wk[l], d, d, S);
|
|
||||||
m->kern_v[l] = compile_conv_kernel(m->wv[l], d, d, S);
|
|
||||||
m->kern_o[l] = compile_conv_kernel(m->wo[l], d, d, S);
|
|
||||||
m->kern_w1[l] = compile_conv_kernel(m->w1[l], d, hd, S);
|
|
||||||
m->kern_w2[l] = compile_conv_kernel(m->w2[l], hd, d, S);
|
|
||||||
m->kern_w3[l] = compile_conv_kernel(m->w3[l], d, hd, S);
|
|
||||||
if (!m->kern_q[l] || !m->kern_k[l] || !m->kern_v[l] || !m->kern_o[l] ||
|
|
||||||
!m->kern_w1[l] || !m->kern_w2[l] || !m->kern_w3[l]) return -1;
|
|
||||||
}
|
|
||||||
m->kern_cls = compile_conv_kernel(m->wcls, d, vs, S);
|
|
||||||
// cls may fail for large vocab — that's OK, forward uses CPU fallback
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,14 @@
|
||||||
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
|
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "stories_config.h"
|
#include "stories_config.h"
|
||||||
|
#include <assert.h>
|
||||||
static float *g_rms_tmp = NULL;
|
|
||||||
|
|
||||||
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
|
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
|
||||||
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
|
float *rms_tmp = (float*)malloc(S * sizeof(float));
|
||||||
float *ss = (float*)calloc(S, sizeof(float));
|
float *ss = (float*)calloc(S, sizeof(float));
|
||||||
for (int i=0; i<d; i++) {
|
for (int i=0; i<d; i++) {
|
||||||
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vmul(x+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
vDSP_vadd(rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
||||||
}
|
}
|
||||||
float invd = 1.0f/d, eps=1e-5f;
|
float invd = 1.0f/d, eps=1e-5f;
|
||||||
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
||||||
|
|
@ -18,15 +17,15 @@ static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
|
||||||
vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, (vDSP_Length)S);
|
vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, (vDSP_Length)S);
|
||||||
vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, (vDSP_Length)S);
|
vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, (vDSP_Length)S);
|
||||||
}
|
}
|
||||||
free(ss);
|
free(ss); free(rms_tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
|
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
|
||||||
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
|
float *rms_tmp = (float*)malloc(S * sizeof(float));
|
||||||
float *ss = (float*)calloc(S, sizeof(float));
|
float *ss = (float*)calloc(S, sizeof(float));
|
||||||
for (int i=0; i<d; i++) {
|
for (int i=0; i<d; i++) {
|
||||||
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vmul(x+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
vDSP_vadd(rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
||||||
}
|
}
|
||||||
float invd = 1.0f/d, eps=1e-5f;
|
float invd = 1.0f/d, eps=1e-5f;
|
||||||
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
||||||
|
|
@ -34,23 +33,23 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c
|
||||||
int n = S; vvrsqrtf(rrms, ss, &n);
|
int n = S; vvrsqrtf(rrms, ss, &n);
|
||||||
float *dot = (float*)calloc(S, sizeof(float));
|
float *dot = (float*)calloc(S, sizeof(float));
|
||||||
for (int i=0; i<d; i++) {
|
for (int i=0; i<d; i++) {
|
||||||
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vmul(dy+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
vDSP_vsma(g_rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
|
vDSP_vsma(rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
|
||||||
}
|
}
|
||||||
vDSP_vmul(rrms, 1, rrms, 1, ss, 1, (vDSP_Length)S);
|
vDSP_vmul(rrms, 1, rrms, 1, ss, 1, (vDSP_Length)S);
|
||||||
vDSP_vsmul(ss, 1, &invd, ss, 1, (vDSP_Length)S);
|
vDSP_vsmul(ss, 1, &invd, ss, 1, (vDSP_Length)S);
|
||||||
vDSP_vmul(dot, 1, ss, 1, dot, 1, (vDSP_Length)S);
|
vDSP_vmul(dot, 1, ss, 1, dot, 1, (vDSP_Length)S);
|
||||||
for (int i=0; i<d; i++) {
|
for (int i=0; i<d; i++) {
|
||||||
vDSP_vmul(x+i*S, 1, dot, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vmul(x+i*S, 1, dot, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
vDSP_vsub(g_rms_tmp, 1, dy+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vsub(rms_tmp, 1, dy+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vmul(rms_tmp, 1, rrms, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
vDSP_vsmul(g_rms_tmp, 1, &w[i], dx+i*S, 1, (vDSP_Length)S);
|
vDSP_vsmul(rms_tmp, 1, &w[i], dx+i*S, 1, (vDSP_Length)S);
|
||||||
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vmul(dy+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
vDSP_vmul(rms_tmp, 1, rrms, 1, rms_tmp, 1, (vDSP_Length)S);
|
||||||
float s; vDSP_sve(g_rms_tmp, 1, &s, (vDSP_Length)S);
|
float s; vDSP_sve(rms_tmp, 1, &s, (vDSP_Length)S);
|
||||||
dw[i] += s;
|
dw[i] += s;
|
||||||
}
|
}
|
||||||
free(ss); free(rrms); free(dot);
|
free(ss); free(rrms); free(dot); free(rms_tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
|
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
|
||||||
|
|
@ -96,6 +95,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
|
||||||
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
|
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
|
||||||
// loss
|
// loss
|
||||||
int tgt = targets[t];
|
int tgt = targets[t];
|
||||||
|
assert(tgt >= 0 && tgt < V && "target token ID out of vocab range");
|
||||||
total_loss -= logf(row[tgt] + 1e-10f);
|
total_loss -= logf(row[tgt] + 1e-10f);
|
||||||
// gradient: softmax - one_hot, then /S
|
// gradient: softmax - one_hot, then /S
|
||||||
row[tgt] -= 1.0f;
|
row[tgt] -= 1.0f;
|
||||||
|
|
@ -112,6 +112,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
|
||||||
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
|
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
|
||||||
for (int t = 0; t < seq; t++) {
|
for (int t = 0; t < seq; t++) {
|
||||||
int tok = tokens[t];
|
int tok = tokens[t];
|
||||||
|
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
|
||||||
for (int d = 0; d < dim; d++) {
|
for (int d = 0; d < dim; d++) {
|
||||||
x[d*seq + t] = embed[tok*dim + d];
|
x[d*seq + t] = embed[tok*dim + d];
|
||||||
}
|
}
|
||||||
|
|
@ -122,6 +123,7 @@ static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, i
|
||||||
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
|
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
|
||||||
for (int t = 0; t < seq; t++) {
|
for (int t = 0; t < seq; t++) {
|
||||||
int tok = tokens[t];
|
int tok = tokens[t];
|
||||||
|
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
|
||||||
for (int d = 0; d < dim; d++) {
|
for (int d = 0; d < dim; d++) {
|
||||||
d_embed[tok*dim + d] += dx[d*seq + t];
|
d_embed[tok*dim + d] += dx[d*seq + t];
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,7 @@ static void free_kern(Kern *k) {
|
||||||
free(k);
|
free(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ch, int sp) {
|
static bool ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ch, int sp) {
|
||||||
float *tmp = (float*)malloc(in_ch * sp * sizeof(float));
|
float *tmp = (float*)malloc(in_ch * sp * sizeof(float));
|
||||||
for (int t = 0; t < sp; t++)
|
for (int t = 0; t < sp; t++)
|
||||||
for (int c = 0; c < in_ch; c++)
|
for (int c = 0; c < in_ch; c++)
|
||||||
|
|
@ -151,8 +151,13 @@ static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_
|
||||||
NSError *e = nil;
|
NSError *e = nil;
|
||||||
id mdl = (__bridge id)k->model;
|
id mdl = (__bridge id)k->model;
|
||||||
id req = (__bridge id)k->request;
|
id req = (__bridge id)k->request;
|
||||||
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
|
BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
|
||||||
mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
|
mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
|
||||||
|
if (!ok) {
|
||||||
|
fprintf(stderr, "ANE eval failed: %s\n",
|
||||||
|
e ? [[e description] UTF8String] : "unknown error");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
float *tmp2 = (float*)malloc(out_ch * sp * sizeof(float));
|
float *tmp2 = (float*)malloc(out_ch * sp * sizeof(float));
|
||||||
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
|
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
|
||||||
memcpy(tmp2, IOSurfaceGetBaseAddress(k->ioOut), out_ch * sp * sizeof(float));
|
memcpy(tmp2, IOSurfaceGetBaseAddress(k->ioOut), out_ch * sp * sizeof(float));
|
||||||
|
|
@ -161,6 +166,7 @@ static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_
|
||||||
for (int c = 0; c < out_ch; c++)
|
for (int c = 0; c < out_ch; c++)
|
||||||
out[t*out_ch + c] = tmp2[c*sp + t];
|
out[t*out_ch + c] = tmp2[c*sp + t];
|
||||||
free(tmp2);
|
free(tmp2);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Checkpoint: save/restore training state for exec() restart ===
|
// === Checkpoint: save/restore training state for exec() restart ===
|
||||||
|
|
@ -179,21 +185,25 @@ static void save_checkpoint(const char *path, int step, float loss,
|
||||||
int D, int H, int S, int total_steps, float lr,
|
int D, int H, int S, int total_steps, float lr,
|
||||||
const float *W1, const float *W2,
|
const float *W1, const float *W2,
|
||||||
double cc, double ct, double cw, int cs, int cb) {
|
double cc, double ct, double cw, int cs, int cb) {
|
||||||
FILE *f = fopen(path, "wb");
|
char tmp_path[512];
|
||||||
|
snprintf(tmp_path, sizeof(tmp_path), "%s.tmp", path);
|
||||||
|
FILE *f = fopen(tmp_path, "wb");
|
||||||
|
if (!f) { fprintf(stderr, "Failed to open %s for checkpoint\n", tmp_path); return; }
|
||||||
CkptHeader hdr = {step, loss, D, H, S, total_steps, lr, cc, ct, cw, cs, cb};
|
CkptHeader hdr = {step, loss, D, H, S, total_steps, lr, cc, ct, cw, cs, cb};
|
||||||
fwrite(&hdr, sizeof(hdr), 1, f);
|
fwrite(&hdr, sizeof(hdr), 1, f);
|
||||||
fwrite(W1, sizeof(float), H * D, f);
|
fwrite(W1, sizeof(float), H * D, f);
|
||||||
fwrite(W2, sizeof(float), D * H, f);
|
fwrite(W2, sizeof(float), D * H, f);
|
||||||
fclose(f);
|
fclose(f);
|
||||||
|
rename(tmp_path, path); // atomic on POSIX
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool load_checkpoint(const char *path, CkptHeader *hdr,
|
static bool load_checkpoint(const char *path, CkptHeader *hdr,
|
||||||
float *W1, float *W2, int H, int D) {
|
float *W1, float *W2, int H, int D) {
|
||||||
FILE *f = fopen(path, "rb");
|
FILE *f = fopen(path, "rb");
|
||||||
if (!f) return false;
|
if (!f) return false;
|
||||||
fread(hdr, sizeof(CkptHeader), 1, f);
|
if (fread(hdr, sizeof(CkptHeader), 1, f) != 1) { fclose(f); return false; }
|
||||||
fread(W1, sizeof(float), H * D, f);
|
if (fread(W1, sizeof(float), H * D, f) != (size_t)(H * D)) { fclose(f); return false; }
|
||||||
fread(W2, sizeof(float), D * H, f);
|
if (fread(W2, sizeof(float), D * H, f) != (size_t)(D * H)) { fclose(f); return false; }
|
||||||
fclose(f);
|
fclose(f);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue