From e986572e9095c0f646ddee4ac03c487e586de0aa Mon Sep 17 00:00:00 2001 From: maderix Date: Wed, 4 Mar 2026 04:41:38 -0800 Subject: [PATCH] Replace assert() with non-fatal bounds checks on token IDs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to PR #31 — assert() aborts on bad tokens, which is too harsh for training. Skip bad tokens with a warning instead. --- training/stories_cpu_ops.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/training/stories_cpu_ops.h b/training/stories_cpu_ops.h index ae4dfdf..cd103c5 100644 --- a/training/stories_cpu_ops.h +++ b/training/stories_cpu_ops.h @@ -1,7 +1,7 @@ // stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax #pragma once #include "stories_config.h" -#include <assert.h> + static void rmsnorm(float *out, const float *x, const float *w, int d, int S) { float *rms_tmp = (float*)malloc(S * sizeof(float)); @@ -95,7 +95,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1 vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V); // loss int tgt = targets[t]; - assert(tgt >= 0 && tgt < V && "target token ID out of vocab range"); + if (tgt < 0 || tgt >= V) { fprintf(stderr, "WARN: target token %d out of vocab range [0,%d), skipping\n", tgt, V); continue; } total_loss -= logf(row[tgt] + 1e-10f); // gradient: softmax - one_hot, then /S row[tgt] -= 1.0f; @@ -112,7 +112,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1 static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) { for (int t = 0; t < seq; t++) { int tok = tokens[t]; - assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range"); + if (tok < 0 || tok >= VOCAB) { fprintf(stderr, "WARN: token %d out of range [0,%d)\n", tok, VOCAB); continue; } for (int d = 0; d < dim; d++) { x[d*seq + t] = embed[tok*dim + d]; } @@ -123,7 +123,7 @@ static void embed_lookup(float *x, const 
float *embed, const uint16_t *tokens, i static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) { for (int t = 0; t < seq; t++) { int tok = tokens[t]; - assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range"); + if (tok < 0 || tok >= VOCAB) { continue; } for (int d = 0; d < dim; d++) { d_embed[tok*dim + d] += dx[d*seq + t]; }