From 2b3b7ae5ccf072774b9b8f5a2036b89fed75aa39 Mon Sep 17 00:00:00 2001 From: tastyheadphones Date: Tue, 3 Mar 2026 11:42:42 +0900 Subject: [PATCH] Fix token sampling underflow on short datasets --- training/train_large.m | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/training/train_large.m b/training/train_large.m index e58ce08..e33f2eb 100644 --- a/training/train_large.m +++ b/training/train_large.m @@ -274,11 +274,17 @@ int main(int argc, char *argv[]) { int data_fd = open(DATA_PATH, O_RDONLY); if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; } struct stat st; fstat(data_fd, &st); - size_t data_len = st.st_size; - uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); - if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } - size_t n_tokens = data_len / 2; - printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); + size_t data_len = st.st_size; + uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); + if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } + size_t n_tokens = data_len / 2; + if (n_tokens <= (size_t)(SEQ + 1)) { + printf("Token data too short: need at least %d tokens, got %zu\n", SEQ + 2, n_tokens); + munmap(token_data, data_len); + close(data_fd); + return 1; + } + printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); // Gradient buffers shared across layers (reused each step) float *dy = (float*)malloc(SEQ*DIM*4); // gradient flowing backward