From 380237af1f01fea7a9db8820accd25e2835ef9e9 Mon Sep 17 00:00:00 2001 From: Erik Bray Date: Tue, 3 Mar 2026 14:21:53 +0100 Subject: [PATCH] [fix] Token sampling underflow fix (upstream PR #17): prevent size_t wraparound on short datasets in both train_large variants --- training/train_large.m | 36 +++++++++++++++++++++++++++++++----- training/train_large_ane.m | 23 ++++++++++++++++++++--- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/training/train_large.m b/training/train_large.m index e58ce08..f71bf52 100644 --- a/training/train_large.m +++ b/training/train_large.m @@ -5,16 +5,29 @@ #include "stories_mil.h" #include "stories_cpu_ops.h" -#define CKPT_PATH "ane_stories110M_ckpt.bin" -#define MODEL_PATH "../../assets/models/stories110M.bin" -#define DATA_PATH "tinystories_data00.bin" +#define DEFAULT_CKPT_PATH "ane_stories110M_ckpt.bin" +#define DEFAULT_MODEL_PATH "../../assets/models/stories110M.bin" +#define DEFAULT_DATA_PATH "tinystories_data00.bin" + +static const char *get_path(const char *env_var, const char *default_val) { + const char *v = getenv(env_var); + return (v && v[0]) ? v : default_val; +} + +#define CKPT_PATH get_path("ANE_CKPT_PATH", DEFAULT_CKPT_PATH) +#define MODEL_PATH get_path("ANE_MODEL_PATH", DEFAULT_MODEL_PATH) +#define DATA_PATH get_path("ANE_DATA_PATH", DEFAULT_DATA_PATH) // ===== Weight loading from llama2.c format ===== static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) { FILE *f = fopen(path, "rb"); if (!f) { printf("Cannot open %s\n", path); return false; } Llama2Config cfg; - fread(&cfg, sizeof(cfg), 1, f); + // Validate config read — gatekeeper before any dimension-based logic (CRIT-03) + if (fread(&cfg, sizeof(cfg), 1, f) != 1) { + printf(" ERROR: Config read failed (truncated file?)\n"); + fclose(f); return false; + } printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n", cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len); if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) { @@ -112,6 +125,7 @@ static void save_checkpoint(const char *path, int step, int total_steps, float l LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final, float *embed, AdamState *aembed) { FILE *f = fopen(path, "wb"); + if (!f) { fprintf(stderr, "save_checkpoint: cannot open %s\n", path); return; } // CRIT-03 CkptHdr h = {0}; h.magic = 0x424C5A54; h.version = 2; h.step = step; h.total_steps = total_steps; @@ -152,7 +166,11 @@ static bool load_checkpoint(const char *path, int *step, int *total_steps, float FILE *f = fopen(path, "rb"); if (!f) return false; CkptHdr h; - fread(&h, sizeof(h), 1, f); + // Validate header read before magic-byte check (CRIT-03) + if (fread(&h, sizeof(h), 1, f) != 1) { + fprintf(stderr, "load_checkpoint: header read failed\n"); + fclose(f); return false; + } if (h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; } *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss; *cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall; @@ -185,6 +203,7 @@ int main(int argc, char *argv[]) { @autoreleasepool { setbuf(stdout, NULL); ane_init(); + init_accum_steps(); mach_timebase_info(&g_tb); int total_steps = 10000; @@ -236,6 +255,7 @@ int main(int argc, char *argv[]) { if (!resuming) { printf("=== ANE Training: Stories110M (12 layers) ===\n"); printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS); + printf("model=%s data=%s ckpt=%s\n", MODEL_PATH, DATA_PATH, CKPT_PATH); if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) { printf("Pretrained load failed, using random init\n"); srand48(42); @@ -278,6 +298,12 @@ int main(int argc, char *argv[]) { uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } size_t n_tokens = data_len / 2; + if (n_tokens <= (size_t)(SEQ + 1)) { + printf("Token data too short: need at least %d tokens, got %zu\n", SEQ + 2, n_tokens); + munmap(token_data, data_len); + close(data_fd); + return 1; + } printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); // Gradient buffers shared across layers (reused each step) diff --git a/training/train_large_ane.m b/training/train_large_ane.m index d7a99ef..52b7dd8 100644 --- a/training/train_large_ane.m +++ b/training/train_large_ane.m @@ -16,9 +16,18 @@ #include "ane_rmsnorm_bwd.h" #include "ane_classifier.h" -#define CKPT_PATH "ane_stories110M_ckpt.bin" -#define MODEL_PATH "../../assets/models/stories110M.bin" -#define DATA_PATH "tinystories_data00.bin" +#define DEFAULT_CKPT_PATH "ane_stories110M_ckpt.bin" +#define DEFAULT_MODEL_PATH "../../assets/models/stories110M.bin" +#define DEFAULT_DATA_PATH "tinystories_data00.bin" + +static const char *get_path(const char *env_var, const char *default_val) { + const char *v = getenv(env_var); + return (v && v[0]) ? v : default_val; +} + +#define CKPT_PATH get_path("ANE_CKPT_PATH", DEFAULT_CKPT_PATH) +#define MODEL_PATH get_path("ANE_MODEL_PATH", DEFAULT_MODEL_PATH) +#define DATA_PATH get_path("ANE_DATA_PATH", DEFAULT_DATA_PATH) // ===== Weight loading from llama2.c format ===== static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) { @@ -196,6 +205,7 @@ int main(int argc, char *argv[]) { @autoreleasepool { setbuf(stdout, NULL); ane_init(); + init_accum_steps(); mach_timebase_info(&g_tb); int total_steps = 10000; @@ -236,6 +246,7 @@ int main(int argc, char *argv[]) { if (!resuming) { printf("=== ANE Training: Stories110M (ANE-offloaded) ===\n"); printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS); + printf("model=%s data=%s ckpt=%s\n", MODEL_PATH, DATA_PATH, CKPT_PATH); printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n"); if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) { printf("Pretrained load failed, using random init\n"); @@ -263,6 +274,12 @@ int main(int argc, char *argv[]) { uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } size_t n_tokens = data_len / 2; + if (n_tokens <= (size_t)(SEQ + 1)) { + printf("Token data too short: need at least %d tokens, got %zu\n", SEQ + 2, n_tokens); + munmap(token_data, data_len); + close(data_fd); + return 1; + } printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); // Gradient buffers