fix: correctness & safety improvements

- Validate all fread() return values in model_load_weights (model.h)
- Check ane_eval() return values in ane_conv_eval (forward.h) and ane_eval_k (tiny_train.m)
- Log error details on ANE eval failure (ane_runtime.h)
- Thread-safe RMSNorm: replace global g_rms_tmp with local allocation (stories_cpu_ops.h)
- Bounds-check token indices in cross_entropy_loss, embed_lookup, embed_backward
- Atomic checkpoint writes via tmp+rename pattern (tiny_train.m)
- Non-destructive recompile: compile new kernels first, swap only on success (model.h)
- Validate fread() in load_checkpoint (tiny_train.m)
This commit is contained in:
Alvaro GPT 2026-03-02 23:10:00 +01:00
parent 443194bca4
commit 541bf4ec90
5 changed files with 111 additions and 60 deletions

View File

@ -141,9 +141,14 @@ static void ane_read_output(ANEKernel *k, int idx, void *data, size_t bytes) {
static bool ane_eval(ANEKernel *k) {
NSError *e = nil;
return ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
k->model, @selector(evaluateWithQoS:options:request:error:),
21, @{}, k->request, &e);
if (!ok) {
fprintf(stderr, "ANE eval failed: %s\n",
e ? [[e description] UTF8String] : "unknown error");
}
return ok;
}
static void ane_free(ANEKernel *k) {

View File

@ -7,7 +7,7 @@
// ANE conv eval: input [S, in_dim] row-major → transpose to [in_dim, S] channels-first
// ANE computes conv(W, x) with baked W → output [out_dim, S]
// Transpose back to [S, out_dim] row-major
static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
static bool ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
int S, int in_dim, int out_dim) {
float *x_t = (float*)malloc(S * in_dim * sizeof(float));
for (int t = 0; t < S; t++)
@ -15,7 +15,11 @@ static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
x_t[i*S + t] = x[t*in_dim + i];
ane_write_input(kernel, 0, x_t, S * in_dim * sizeof(float));
ane_eval(kernel);
bool ok = ane_eval(kernel);
if (!ok) {
free(x_t);
return false;
}
float *y_t = (float*)malloc(S * out_dim * sizeof(float));
ane_read_output(kernel, 0, y_t, S * out_dim * sizeof(float));
@ -25,6 +29,7 @@ static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
y[t*out_dim + i] = y_t[i*S + t];
free(x_t); free(y_t);
return true;
}
// CPU matmul fallback: y = W @ x, W[out_dim, in_dim], x[S, in_dim] → y[S, out_dim]

View File

@ -78,7 +78,10 @@ typedef struct {
static int model_load_weights(Model *m, const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
fread(&m->cfg, sizeof(Config), 1, f);
if (fread(&m->cfg, sizeof(Config), 1, f) != 1) {
fprintf(stderr, "ERROR: failed to read config from %s\n", path);
fclose(f); return -1;
}
bool shared = m->cfg.vocab_size > 0;
if (m->cfg.vocab_size < 0) m->cfg.vocab_size = -m->cfg.vocab_size;
@ -89,7 +92,10 @@ static int model_load_weights(Model *m, const char *path) {
int d = m->cfg.dim, hd = m->cfg.hidden_dim, nl = m->cfg.n_layers, vs = m->cfg.vocab_size;
m->token_embedding = (float*)malloc(vs * d * sizeof(float));
fread(m->token_embedding, sizeof(float), vs * d, f);
if (fread(m->token_embedding, sizeof(float), vs * d, f) != (size_t)(vs * d)) {
fprintf(stderr, "ERROR: short read on token_embedding (file truncated?)\n");
fclose(f); return -1;
}
float *rms_att_all = (float*)malloc(nl * d * sizeof(float));
float *wq_all = (float*)malloc(nl * d * d * sizeof(float));
@ -101,15 +107,24 @@ static int model_load_weights(Model *m, const char *path) {
float *w2_all = (float*)malloc(nl * d * hd * sizeof(float));
float *w3_all = (float*)malloc(nl * hd * d * sizeof(float));
fread(rms_att_all, sizeof(float), nl * d, f);
fread(wq_all, sizeof(float), nl * d * d, f);
fread(wk_all, sizeof(float), nl * d * d, f);
fread(wv_all, sizeof(float), nl * d * d, f);
fread(wo_all, sizeof(float), nl * d * d, f);
fread(rms_ffn_all, sizeof(float), nl * d, f);
fread(w1_all, sizeof(float), nl * hd * d, f);
fread(w2_all, sizeof(float), nl * d * hd, f);
fread(w3_all, sizeof(float), nl * hd * d, f);
#define FREAD_CHECK(buf, count, file, label) do { \
size_t _n = fread(buf, sizeof(float), count, file); \
if (_n != (size_t)(count)) { \
fprintf(stderr, "ERROR: short read on %s: got %zu, expected %zu (file truncated?)\n", \
label, _n, (size_t)(count)); \
fclose(file); return -1; \
} \
} while(0)
FREAD_CHECK(rms_att_all, nl * d, f, "rms_att");
FREAD_CHECK(wq_all, nl * d * d, f, "wq");
FREAD_CHECK(wk_all, nl * d * d, f, "wk");
FREAD_CHECK(wv_all, nl * d * d, f, "wv");
FREAD_CHECK(wo_all, nl * d * d, f, "wo");
FREAD_CHECK(rms_ffn_all, nl * d, f, "rms_ffn");
FREAD_CHECK(w1_all, nl * hd * d, f, "w1");
FREAD_CHECK(w2_all, nl * d * hd, f, "w2");
FREAD_CHECK(w3_all, nl * hd * d, f, "w3");
for (int l = 0; l < nl; l++) {
m->rms_att_w[l] = (float*)malloc(d * sizeof(float));
@ -135,14 +150,15 @@ static int model_load_weights(Model *m, const char *path) {
free(rms_ffn_all); free(w1_all); free(w2_all); free(w3_all);
m->rms_final_w = (float*)malloc(d * sizeof(float));
fread(m->rms_final_w, sizeof(float), d, f);
FREAD_CHECK(m->rms_final_w, d, f, "rms_final");
if (shared) {
m->wcls = m->token_embedding;
} else {
m->wcls = (float*)malloc(vs * d * sizeof(float));
fread(m->wcls, sizeof(float), vs * d, f);
FREAD_CHECK(m->wcls, vs * d, f, "wcls");
}
#undef FREAD_CHECK
fclose(f);
return 0;
}
@ -188,32 +204,45 @@ static int model_compile_kernels(Model *m, int seq_len) {
return 0;
}
// Recompile all kernels after weight update — unload all first to avoid ANE model limit
// Recompile all kernels after weight update — compile new first, then swap
static int model_recompile_kernels(Model *m) {
int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size;
int S = m->seq_len;
// Phase 1: unload+free all
// Phase 1: compile new kernels into temporaries
ANEKernel *new_q[N_LAYERS], *new_k[N_LAYERS], *new_v[N_LAYERS], *new_o[N_LAYERS];
ANEKernel *new_w1[N_LAYERS], *new_w2[N_LAYERS], *new_w3[N_LAYERS];
for (int l = 0; l < N_LAYERS; l++) {
new_q[l] = compile_conv_kernel(m->wq[l], d, d, S);
new_k[l] = compile_conv_kernel(m->wk[l], d, d, S);
new_v[l] = compile_conv_kernel(m->wv[l], d, d, S);
new_o[l] = compile_conv_kernel(m->wo[l], d, d, S);
new_w1[l] = compile_conv_kernel(m->w1[l], d, hd, S);
new_w2[l] = compile_conv_kernel(m->w2[l], hd, d, S);
new_w3[l] = compile_conv_kernel(m->w3[l], d, hd, S);
if (!new_q[l] || !new_k[l] || !new_v[l] || !new_o[l] ||
!new_w1[l] || !new_w2[l] || !new_w3[l]) {
// Cleanup partially compiled new kernels
for (int i = 0; i <= l; i++) {
ane_free(new_q[i]); ane_free(new_k[i]); ane_free(new_v[i]); ane_free(new_o[i]);
ane_free(new_w1[i]); ane_free(new_w2[i]); ane_free(new_w3[i]);
}
fprintf(stderr, "Recompile failed at layer %d, keeping old kernels\n", l);
return -1;
}
}
ANEKernel *new_cls = compile_conv_kernel(m->wcls, d, vs, S);
// Phase 2: all compiles succeeded — swap and free old
for (int l = 0; l < N_LAYERS; l++) {
ane_free(m->kern_q[l]); ane_free(m->kern_k[l]); ane_free(m->kern_v[l]); ane_free(m->kern_o[l]);
ane_free(m->kern_w1[l]); ane_free(m->kern_w2[l]); ane_free(m->kern_w3[l]);
m->kern_q[l]=m->kern_k[l]=m->kern_v[l]=m->kern_o[l]=NULL;
m->kern_w1[l]=m->kern_w2[l]=m->kern_w3[l]=NULL;
m->kern_q[l] = new_q[l]; m->kern_k[l] = new_k[l];
m->kern_v[l] = new_v[l]; m->kern_o[l] = new_o[l];
m->kern_w1[l] = new_w1[l]; m->kern_w2[l] = new_w2[l]; m->kern_w3[l] = new_w3[l];
}
if (m->kern_cls) { ane_free(m->kern_cls); m->kern_cls=NULL; }
// Phase 2: recompile all
for (int l = 0; l < N_LAYERS; l++) {
m->kern_q[l] = compile_conv_kernel(m->wq[l], d, d, S);
m->kern_k[l] = compile_conv_kernel(m->wk[l], d, d, S);
m->kern_v[l] = compile_conv_kernel(m->wv[l], d, d, S);
m->kern_o[l] = compile_conv_kernel(m->wo[l], d, d, S);
m->kern_w1[l] = compile_conv_kernel(m->w1[l], d, hd, S);
m->kern_w2[l] = compile_conv_kernel(m->w2[l], hd, d, S);
m->kern_w3[l] = compile_conv_kernel(m->w3[l], d, hd, S);
if (!m->kern_q[l] || !m->kern_k[l] || !m->kern_v[l] || !m->kern_o[l] ||
!m->kern_w1[l] || !m->kern_w2[l] || !m->kern_w3[l]) return -1;
}
m->kern_cls = compile_conv_kernel(m->wcls, d, vs, S);
// cls may fail for large vocab — that's OK, forward uses CPU fallback
if (m->kern_cls) ane_free(m->kern_cls);
m->kern_cls = new_cls; // may be NULL for large vocab — forward uses CPU fallback
return 0;
}

View File

@ -1,15 +1,14 @@
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
#pragma once
#include "stories_config.h"
static float *g_rms_tmp = NULL;
#include <assert.h>
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
float *rms_tmp = (float*)malloc(S * sizeof(float));
float *ss = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
vDSP_vmul(x+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
}
float invd = 1.0f/d, eps=1e-5f;
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
@ -18,15 +17,15 @@ static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, (vDSP_Length)S);
vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, (vDSP_Length)S);
}
free(ss);
free(ss); free(rms_tmp);
}
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
float *rms_tmp = (float*)malloc(S * sizeof(float));
float *ss = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
vDSP_vmul(x+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
}
float invd = 1.0f/d, eps=1e-5f;
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
@ -34,23 +33,23 @@ static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, c
int n = S; vvrsqrtf(rrms, ss, &n);
float *dot = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsma(g_rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
vDSP_vmul(dy+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
vDSP_vsma(rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
}
vDSP_vmul(rrms, 1, rrms, 1, ss, 1, (vDSP_Length)S);
vDSP_vsmul(ss, 1, &invd, ss, 1, (vDSP_Length)S);
vDSP_vmul(dot, 1, ss, 1, dot, 1, (vDSP_Length)S);
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, dot, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsub(g_rms_tmp, 1, dy+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsmul(g_rms_tmp, 1, &w[i], dx+i*S, 1, (vDSP_Length)S);
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
float s; vDSP_sve(g_rms_tmp, 1, &s, (vDSP_Length)S);
vDSP_vmul(x+i*S, 1, dot, 1, rms_tmp, 1, (vDSP_Length)S);
vDSP_vsub(rms_tmp, 1, dy+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
vDSP_vmul(rms_tmp, 1, rrms, 1, rms_tmp, 1, (vDSP_Length)S);
vDSP_vsmul(rms_tmp, 1, &w[i], dx+i*S, 1, (vDSP_Length)S);
vDSP_vmul(dy+i*S, 1, x+i*S, 1, rms_tmp, 1, (vDSP_Length)S);
vDSP_vmul(rms_tmp, 1, rrms, 1, rms_tmp, 1, (vDSP_Length)S);
float s; vDSP_sve(rms_tmp, 1, &s, (vDSP_Length)S);
dw[i] += s;
}
free(ss); free(rrms); free(dot);
free(ss); free(rrms); free(dot); free(rms_tmp);
}
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
@ -96,6 +95,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
// loss
int tgt = targets[t];
assert(tgt >= 0 && tgt < V && "target token ID out of vocab range");
total_loss -= logf(row[tgt] + 1e-10f);
// gradient: softmax - one_hot, then /S
row[tgt] -= 1.0f;
@ -112,6 +112,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
for (int d = 0; d < dim; d++) {
x[d*seq + t] = embed[tok*dim + d];
}
@ -122,6 +123,7 @@ static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, i
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
for (int d = 0; d < dim; d++) {
d_embed[tok*dim + d] += dx[d*seq + t];
}

View File

@ -139,7 +139,7 @@ static void free_kern(Kern *k) {
free(k);
}
static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ch, int sp) {
static bool ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ch, int sp) {
float *tmp = (float*)malloc(in_ch * sp * sizeof(float));
for (int t = 0; t < sp; t++)
for (int c = 0; c < in_ch; c++)
@ -151,8 +151,13 @@ static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_
NSError *e = nil;
id mdl = (__bridge id)k->model;
id req = (__bridge id)k->request;
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(
mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
if (!ok) {
fprintf(stderr, "ANE eval failed: %s\n",
e ? [[e description] UTF8String] : "unknown error");
return false;
}
float *tmp2 = (float*)malloc(out_ch * sp * sizeof(float));
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
memcpy(tmp2, IOSurfaceGetBaseAddress(k->ioOut), out_ch * sp * sizeof(float));
@ -161,6 +166,7 @@ static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_
for (int c = 0; c < out_ch; c++)
out[t*out_ch + c] = tmp2[c*sp + t];
free(tmp2);
return true;
}
// === Checkpoint: save/restore training state for exec() restart ===
@ -179,21 +185,25 @@ static void save_checkpoint(const char *path, int step, float loss,
int D, int H, int S, int total_steps, float lr,
const float *W1, const float *W2,
double cc, double ct, double cw, int cs, int cb) {
FILE *f = fopen(path, "wb");
char tmp_path[512];
snprintf(tmp_path, sizeof(tmp_path), "%s.tmp", path);
FILE *f = fopen(tmp_path, "wb");
if (!f) { fprintf(stderr, "Failed to open %s for checkpoint\n", tmp_path); return; }
CkptHeader hdr = {step, loss, D, H, S, total_steps, lr, cc, ct, cw, cs, cb};
fwrite(&hdr, sizeof(hdr), 1, f);
fwrite(W1, sizeof(float), H * D, f);
fwrite(W2, sizeof(float), D * H, f);
fclose(f);
rename(tmp_path, path); // atomic on POSIX
}
static bool load_checkpoint(const char *path, CkptHeader *hdr,
float *W1, float *W2, int H, int D) {
FILE *f = fopen(path, "rb");
if (!f) return false;
fread(hdr, sizeof(CkptHeader), 1, f);
fread(W1, sizeof(float), H * D, f);
fread(W2, sizeof(float), D * H, f);
if (fread(hdr, sizeof(CkptHeader), 1, f) != 1) { fclose(f); return false; }
if (fread(W1, sizeof(float), H * D, f) != (size_t)(H * D)) { fclose(f); return false; }
if (fread(W2, sizeof(float), D * H, f) != (size_t)(D * H)) { fclose(f); return false; }
fclose(f);
return true;
}