Replace assert() with non-fatal bounds checks on token IDs

Follow-up to PR #31 — assert() aborts on bad tokens, which is too
harsh for training. Skip bad tokens with a warning instead.
This commit is contained in:
maderix 2026-03-04 04:41:38 -08:00
parent 05fc8f85e3
commit e986572e90
1 changed files with 4 additions and 4 deletions

View File

@ -1,7 +1,7 @@
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
#pragma once
#include "stories_config.h"
#include <assert.h>
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
float *rms_tmp = (float*)malloc(S * sizeof(float));
@ -95,7 +95,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
// loss
int tgt = targets[t];
assert(tgt >= 0 && tgt < V && "target token ID out of vocab range");
if (tgt < 0 || tgt >= V) { fprintf(stderr, "WARN: target token %d out of vocab range [0,%d), skipping\n", tgt, V); continue; }
total_loss -= logf(row[tgt] + 1e-10f);
// gradient: softmax - one_hot, then /S
row[tgt] -= 1.0f;
@ -112,7 +112,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
if (tok < 0 || tok >= VOCAB) { fprintf(stderr, "WARN: token %d out of range [0,%d)\n", tok, VOCAB); continue; }
for (int d = 0; d < dim; d++) {
x[d*seq + t] = embed[tok*dim + d];
}
@ -123,7 +123,7 @@ static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, i
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
if (tok < 0 || tok >= VOCAB) { continue; }
for (int d = 0; d < dim; d++) {
d_embed[tok*dim + d] += dx[d*seq + t];
}