mirror of https://github.com/maderix/ANE.git
Replace assert() with non-fatal bounds checks on token IDs
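Follow-up to PR #31 — assert() aborts the process on an out-of-range token ID, which is too harsh for training. Skip out-of-range tokens instead of aborting: cross_entropy_loss and embed_lookup print a warning to stderr, while embed_backward skips them silently.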
This commit is contained in:
parent
05fc8f85e3
commit
e986572e90
|
|
@ -1,7 +1,7 @@
|
|||
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
|
||||
#pragma once
|
||||
#include "stories_config.h"
|
||||
#include <assert.h>
|
||||
|
||||
|
||||
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
|
||||
float *rms_tmp = (float*)malloc(S * sizeof(float));
|
||||
|
|
@ -95,7 +95,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
|
|||
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
|
||||
// loss
|
||||
int tgt = targets[t];
|
||||
assert(tgt >= 0 && tgt < V && "target token ID out of vocab range");
|
||||
if (tgt < 0 || tgt >= V) { fprintf(stderr, "WARN: target token %d out of vocab range [0,%d), skipping\n", tgt, V); continue; }
|
||||
total_loss -= logf(row[tgt] + 1e-10f);
|
||||
// gradient: softmax - one_hot, then /S
|
||||
row[tgt] -= 1.0f;
|
||||
|
|
@ -112,7 +112,7 @@ static float cross_entropy_loss(float *dlogits, const float *logits, const uint1
|
|||
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
|
||||
for (int t = 0; t < seq; t++) {
|
||||
int tok = tokens[t];
|
||||
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
|
||||
if (tok < 0 || tok >= VOCAB) { fprintf(stderr, "WARN: token %d out of range [0,%d)\n", tok, VOCAB); continue; }
|
||||
for (int d = 0; d < dim; d++) {
|
||||
x[d*seq + t] = embed[tok*dim + d];
|
||||
}
|
||||
|
|
@ -123,7 +123,7 @@ static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, i
|
|||
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
|
||||
for (int t = 0; t < seq; t++) {
|
||||
int tok = tokens[t];
|
||||
assert(tok >= 0 && tok < VOCAB && "token ID out of embedding range");
|
||||
if (tok < 0 || tok >= VOCAB) { continue; }
|
||||
for (int d = 0; d < dim; d++) {
|
||||
d_embed[tok*dim + d] += dx[d*seq + t];
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue