diff --git a/training/ane_runtime.h b/training/ane_runtime.h index 585d0f0..58bcb79 100644 --- a/training/ane_runtime.h +++ b/training/ane_runtime.h @@ -141,9 +141,14 @@ static void ane_read_output(ANEKernel *k, int idx, void *data, size_t bytes) { static bool ane_eval(ANEKernel *k) { NSError *e = nil; - return ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)( + BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)( k->model, @selector(evaluateWithQoS:options:request:error:), 21, @{}, k->request, &e); + if (!ok) { + fprintf(stderr, "ANE eval failed: %s\n", + e ? [[e description] UTF8String] : "unknown error"); + } + return ok; } static void ane_free(ANEKernel *k) { diff --git a/training/forward.h b/training/forward.h index adcf898..1a2a31f 100644 --- a/training/forward.h +++ b/training/forward.h @@ -7,7 +7,7 @@ // ANE conv eval: input [S, in_dim] row-major → transpose to [in_dim, S] channels-first // ANE computes conv(W, x) with baked W → output [out_dim, S] // Transpose back to [S, out_dim] row-major -static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y, +static bool ane_conv_eval(ANEKernel *kernel, const float *x, float *y, int S, int in_dim, int out_dim) { float *x_t = (float*)malloc(S * in_dim * sizeof(float)); for (int t = 0; t < S; t++) @@ -15,7 +15,11 @@ static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y, x_t[i*S + t] = x[t*in_dim + i]; ane_write_input(kernel, 0, x_t, S * in_dim * sizeof(float)); - ane_eval(kernel); + bool ok = ane_eval(kernel); + if (!ok) { + free(x_t); + return false; + } float *y_t = (float*)malloc(S * out_dim * sizeof(float)); ane_read_output(kernel, 0, y_t, S * out_dim * sizeof(float)); @@ -25,6 +29,7 @@ static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y, y[t*out_dim + i] = y_t[i*S + t]; free(x_t); free(y_t); + return true; } // CPU matmul fallback: y = W @ x, W[out_dim, in_dim], x[S, in_dim] → y[S, out_dim] diff --git a/training/model.h b/training/model.h index 6cee52f..4e68ebc 100644 --- a/training/model.h +++ b/training/model.h @@ -78,7 +78,10 @@ typedef struct { static int model_load_weights(Model *m, const char *path) { FILE *f = fopen(path, "rb"); if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; } - fread(&m->cfg, sizeof(Config), 1, f); + if (fread(&m->cfg, sizeof(Config), 1, f) != 1) { + fprintf(stderr, "ERROR: failed to read config from %s\n", path); + fclose(f); return -1; + } bool shared = m->cfg.vocab_size > 0; if (m->cfg.vocab_size < 0) m->cfg.vocab_size = -m->cfg.vocab_size; @@ -89,7 +92,10 @@ static int model_load_weights(Model *m, const char *path) { int d = m->cfg.dim, hd = m->cfg.hidden_dim, nl = m->cfg.n_layers, vs = m->cfg.vocab_size; m->token_embedding = (float*)malloc(vs * d * sizeof(float)); - fread(m->token_embedding, sizeof(float), vs * d, f); + if (fread(m->token_embedding, sizeof(float), vs * d, f) != (size_t)(vs * d)) { + fprintf(stderr, "ERROR: short read on token_embedding (file truncated?)\n"); + fclose(f); return -1; + } float *rms_att_all = (float*)malloc(nl * d * sizeof(float)); float *wq_all = (float*)malloc(nl * d * d * sizeof(float)); @@ -101,15 +107,24 @@ static int model_load_weights(Model *m, const char *path) { float *w2_all = (float*)malloc(nl * d * hd * sizeof(float)); float *w3_all = (float*)malloc(nl * hd * d * sizeof(float)); - fread(rms_att_all, sizeof(float), nl * d, f); - fread(wq_all, sizeof(float), nl * d * d, f); - fread(wk_all, sizeof(float), nl * d * d, f); - fread(wv_all, sizeof(float), nl * d * d, f); - fread(wo_all, sizeof(float), nl * d * d, f); - fread(rms_ffn_all, sizeof(float), nl * d, f); - fread(w1_all, sizeof(float), nl * hd * d, f); - fread(w2_all, sizeof(float), nl * d * hd, f); - fread(w3_all, sizeof(float), nl * hd * d, f); + #define FREAD_CHECK(buf, count, file, label) do { \ + size_t _n = fread(buf, sizeof(float), count, file); \ + if (_n != (size_t)(count)) { \ + fprintf(stderr, "ERROR: short read on %s: got %zu, expected %zu (file truncated?)\n", \ + label, _n, (size_t)(count)); \ + fclose(file); return -1; \ + } \ + } while(0) + + FREAD_CHECK(rms_att_all, nl * d, f, "rms_att"); + FREAD_CHECK(wq_all, nl * d * d, f, "wq"); + FREAD_CHECK(wk_all, nl * d * d, f, "wk"); + FREAD_CHECK(wv_all, nl * d * d, f, "wv"); + FREAD_CHECK(wo_all, nl * d * d, f, "wo"); + FREAD_CHECK(rms_ffn_all, nl * d, f, "rms_ffn"); + FREAD_CHECK(w1_all, nl * hd * d, f, "w1"); + FREAD_CHECK(w2_all, nl * d * hd, f, "w2"); + FREAD_CHECK(w3_all, nl * hd * d, f, "w3"); for (int l = 0; l < nl; l++) { m->rms_att_w[l] = (float*)malloc(d * sizeof(float)); @@ -135,14 +150,15 @@ static int model_load_weights(Model *m, const char *path) { free(rms_ffn_all); free(w1_all); free(w2_all); free(w3_all); m->rms_final_w = (float*)malloc(d * sizeof(float)); - fread(m->rms_final_w, sizeof(float), d, f); + FREAD_CHECK(m->rms_final_w, d, f, "rms_final"); if (shared) { m->wcls = m->token_embedding; } else { m->wcls = (float*)malloc(vs * d * sizeof(float)); - fread(m->wcls, sizeof(float), vs * d, f); + FREAD_CHECK(m->wcls, vs * d, f, "wcls"); } + #undef FREAD_CHECK fclose(f); return 0; } @@ -188,32 +204,45 @@ static int model_compile_kernels(Model *m, int seq_len) { return 0; } -// Recompile all kernels after weight update — unload all first to avoid ANE model limit +// Recompile all kernels after weight update — compile new first, then swap static int model_recompile_kernels(Model *m) { int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size; int S = m->seq_len; - // Phase 1: unload+free all + + // Phase 1: compile new kernels into temporaries + ANEKernel *new_q[N_LAYERS], *new_k[N_LAYERS], *new_v[N_LAYERS], *new_o[N_LAYERS]; + ANEKernel *new_w1[N_LAYERS], *new_w2[N_LAYERS], *new_w3[N_LAYERS]; + for (int l = 0; l < N_LAYERS; l++) { + new_q[l] = compile_conv_kernel(m->wq[l], d, d, S); + new_k[l] = compile_conv_kernel(m->wk[l], d, d, S); + new_v[l] = compile_conv_kernel(m->wv[l], d, d, S); + new_o[l] = compile_conv_kernel(m->wo[l], d, d, S); + new_w1[l] = compile_conv_kernel(m->w1[l], d, hd, S); + new_w2[l] = compile_conv_kernel(m->w2[l], hd, d, S); + new_w3[l] = compile_conv_kernel(m->w3[l], d, hd, S); + if (!new_q[l] || !new_k[l] || !new_v[l] || !new_o[l] || + !new_w1[l] || !new_w2[l] || !new_w3[l]) { + // Cleanup partially compiled new kernels + for (int i = 0; i <= l; i++) { + ane_free(new_q[i]); ane_free(new_k[i]); ane_free(new_v[i]); ane_free(new_o[i]); + ane_free(new_w1[i]); ane_free(new_w2[i]); ane_free(new_w3[i]); + } + fprintf(stderr, "Recompile failed at layer %d, keeping old kernels\n", l); + return -1; + } + } + ANEKernel *new_cls = compile_conv_kernel(m->wcls, d, vs, S); + + // Phase 2: all compiles succeeded — swap and free old for (int l = 0; l < N_LAYERS; l++) { ane_free(m->kern_q[l]); ane_free(m->kern_k[l]); ane_free(m->kern_v[l]); ane_free(m->kern_o[l]); ane_free(m->kern_w1[l]); ane_free(m->kern_w2[l]); ane_free(m->kern_w3[l]); - m->kern_q[l]=m->kern_k[l]=m->kern_v[l]=m->kern_o[l]=NULL; - m->kern_w1[l]=m->kern_w2[l]=m->kern_w3[l]=NULL; + m->kern_q[l] = new_q[l]; m->kern_k[l] = new_k[l]; + m->kern_v[l] = new_v[l]; m->kern_o[l] = new_o[l]; + m->kern_w1[l] = new_w1[l]; m->kern_w2[l] = new_w2[l]; m->kern_w3[l] = new_w3[l]; } - if (m->kern_cls) { ane_free(m->kern_cls); m->kern_cls=NULL; } - // Phase 2: recompile all - for (int l = 0; l < N_LAYERS; l++) { - m->kern_q[l] = compile_conv_kernel(m->wq[l], d, d, S); - m->kern_k[l] = compile_conv_kernel(m->wk[l], d, d, S); - m->kern_v[l] = compile_conv_kernel(m->wv[l], d, d, S); - m->kern_o[l] = compile_conv_kernel(m->wo[l], d, d, S); - m->kern_w1[l] = compile_conv_kernel(m->w1[l], d, hd, S); - m->kern_w2[l] = compile_conv_kernel(m->w2[l], hd, d, S); - m->kern_w3[l] = compile_conv_kernel(m->w3[l], d, hd, S); - if (!m->kern_q[l] || !m->kern_k[l] || !m->kern_v[l] || !m->kern_o[l] || - !m->kern_w1[l] || !m->kern_w2[l] || !m->kern_w3[l]) return -1; - } - m->kern_cls = compile_conv_kernel(m->wcls, d, vs, S); - // cls may fail for large vocab — that's OK, forward uses CPU fallback + if (m->kern_cls) ane_free(m->kern_cls); + m->kern_cls = new_cls; // may be NULL for large vocab — forward uses CPU fallback return 0; } diff --git a/training/stories_cpu_ops.h b/training/stories_cpu_ops.h index c9f2cfa..ae4dfdf 100644 --- a/training/stories_cpu_ops.h +++ b/training/stories_cpu_ops.h @@ -1,15 +1,14 @@ // stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax #pragma once #include "stories_config.h" - -static float *g_rms_tmp = NULL; +#include static void rmsnorm(float *out, const float *x, const float *w, int d, int S) { - if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4); + float *rms_tmp = (float*)malloc(S * sizeof(float)); float *ss = (float*)calloc(S, sizeof(float)); for (int i=0; i= 0 && tgt < V && "target token ID out of vocab range"); total_loss -= logf(row[tgt] + 1e-10f); // gradient: softmax - one_hot, then /S row[tgt] -= 1.0f; @@ -112,6 +112,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1 static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) { for (int t = 0; t < seq; t++) { int tok = tokens[t]; + assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range"); for (int d = 0; d < dim; d++) { x[d*seq + t] = embed[tok*dim + d]; } @@ -122,6 +123,7 @@ static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, i static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) { for (int t = 0; t < seq; t++) { int tok = tokens[t]; + assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range"); for (int d = 0; d < dim; d++) { d_embed[tok*dim + d] += dx[d*seq + t]; } diff --git a/training/tiny_train.m b/training/tiny_train.m index e1e9d7d..0449dba 100644 --- a/training/tiny_train.m +++ b/training/tiny_train.m @@ -139,7 +139,7 @@ static void free_kern(Kern *k) { free(k); } -static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ch, int sp) { +static bool ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ch, int sp) { float *tmp = (float*)malloc(in_ch * sp * sizeof(float)); for (int t = 0; t < sp; t++) for (int c = 0; c < in_ch; c++) @@ -151,8 +151,13 @@ static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ NSError *e = nil; id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; - ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)( + BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)( mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e); + if (!ok) { + fprintf(stderr, "ANE eval failed: %s\n", + e ? [[e description] UTF8String] : "unknown error"); + return false; + } float *tmp2 = (float*)malloc(out_ch * sp * sizeof(float)); IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL); memcpy(tmp2, IOSurfaceGetBaseAddress(k->ioOut), out_ch * sp * sizeof(float)); @@ -161,6 +166,7 @@ static void ane_eval_k(Kern *k, const float *in, float *out, int in_ch, int out_ for (int c = 0; c < out_ch; c++) out[t*out_ch + c] = tmp2[c*sp + t]; free(tmp2); + return true; } // === Checkpoint: save/restore training state for exec() restart === @@ -179,21 +185,25 @@ static void save_checkpoint(const char *path, int step, float loss, int D, int H, int S, int total_steps, float lr, const float *W1, const float *W2, double cc, double ct, double cw, int cs, int cb) { - FILE *f = fopen(path, "wb"); + char tmp_path[512]; + snprintf(tmp_path, sizeof(tmp_path), "%s.tmp", path); + FILE *f = fopen(tmp_path, "wb"); + if (!f) { fprintf(stderr, "Failed to open %s for checkpoint\n", tmp_path); return; } CkptHeader hdr = {step, loss, D, H, S, total_steps, lr, cc, ct, cw, cs, cb}; fwrite(&hdr, sizeof(hdr), 1, f); fwrite(W1, sizeof(float), H * D, f); fwrite(W2, sizeof(float), D * H, f); fclose(f); + rename(tmp_path, path); // atomic on POSIX } static bool load_checkpoint(const char *path, CkptHeader *hdr, float *W1, float *W2, int H, int D) { FILE *f = fopen(path, "rb"); if (!f) return false; - fread(hdr, sizeof(CkptHeader), 1, f); - fread(W1, sizeof(float), H * D, f); - fread(W2, sizeof(float), D * H, f); + if (fread(hdr, sizeof(CkptHeader), 1, f) != 1) { fclose(f); return false; } + if (fread(W1, sizeof(float), H * D, f) != (size_t)(H * D)) { fclose(f); return false; } + if (fread(W2, sizeof(float), D * H, f) != (size_t)(D * H)) { fclose(f); return false; } fclose(f); return true; }