Merge pull request #29 from nabbilkhan/contrib/fix-training-data-paths

Fix hardcoded TinyStories data path in train_large/train_large_ane
This commit is contained in:
Manjeet Singh 2026-03-04 17:48:43 +05:30 committed by GitHub
commit 032f866f2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 60 additions and 45 deletions

View File

@ -83,15 +83,17 @@ Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from HuggingFace. Pr
### 2. Build & Train ### 2. Build & Train
```bash ```bash
# Static baseline (classifier + softmax on CPU) # Static baseline (classifier + softmax on CPU)
make train_large make train_large
./train_large stories110M.bin 256 100 1e-4 ./train_large stories110M.bin 256 100 1e-4
./train_large --model stories110M.bin --steps 100 --lr 1e-4 ./train_large --model stories110M.bin --steps 100 --lr 1e-4
./train_large --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
make train_large_ane # PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
./train_large_ane stories110M.bin 256 100 1e-4 make train_large_ane
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras ./train_large_ane stories110M.bin 256 100 1e-4
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras
./train_large_ane --data ./tinystories_data00.bin --steps 100 --lr 1e-4
# Dynamic pipeline (no recompilation) # Dynamic pipeline (no recompilation)
cd training_dynamic && make train cd training_dynamic && make train
@ -100,13 +102,14 @@ cd training_dynamic && make train
./train --steps 200 --lr 1e-4 # custom steps/lr ./train --steps 200 --lr 1e-4 # custom steps/lr
``` ```
**CLI flags (all pipelines):** **CLI flags (`train_large` / `train_large_ane`):**
- `--steps N` (default 10000) - `--steps N` (default 10000)
- `--lr F` (default 3e-4) - `--lr F` (default 3e-4)
- `--model PATH` — pretrained weights file - `--model PATH` — pretrained weights file
- `--ckpt PATH` — checkpoint file (preserved across exec() restarts) - `--data PATH` — tokenized TinyStories `.bin` file (default: `tinystories_data00.bin`)
- `--resume` — resume from checkpoint - `--ckpt PATH` — checkpoint file (preserved across exec() restarts)
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd - `--resume` — resume from checkpoint
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd
### 3. Monitor with Dashboard ### 3. Monitor with Dashboard

View File

@ -5,9 +5,9 @@
#include "stories_mil.h" #include "stories_mil.h"
#include "stories_cpu_ops.h" #include "stories_cpu_ops.h"
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin" #define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
#define MODEL_PATH_DEFAULT "stories110M.bin" #define MODEL_PATH_DEFAULT "stories110M.bin"
#define DATA_PATH "tinystories_data00.bin" #define DATA_PATH_DEFAULT "tinystories_data00.bin"
// ===== Weight loading from llama2.c format ===== // ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) { static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
@ -192,22 +192,24 @@ int main(int argc, char *argv[]) {
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f; float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0; int adam_t = 0, start_step = 0;
// Parse args // Parse args
const char *ckpt_path = CKPT_PATH_DEFAULT; const char *ckpt_path = CKPT_PATH_DEFAULT;
const char *model_path = MODEL_PATH_DEFAULT; const char *model_path = MODEL_PATH_DEFAULT;
bool do_resume = false; const char *data_path = DATA_PATH_DEFAULT;
int pos = 0; bool do_resume = false;
for (int i=1; i<argc; i++) { int pos = 0;
if (strcmp(argv[i], "--resume") == 0) do_resume = true; for (int i=1; i<argc; i++) {
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]); if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]); else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i]; else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i]; else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i];
else if (argv[i][0] != '-') { else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
if (pos == 0) model_path = argv[i]; else if (strcmp(argv[i], "--data") == 0 && i+1<argc) data_path = argv[++i];
else if (pos == 1) { /* seq - compile-time constant */ } else if (argv[i][0] != '-') {
else if (pos == 2) total_steps = atoi(argv[i]); if (pos == 0) model_path = argv[i];
else if (pos == 3) lr = atof(argv[i]); else if (pos == 1) { /* seq - compile-time constant */ }
else if (pos == 2) total_steps = atoi(argv[i]);
else if (pos == 3) lr = atof(argv[i]);
pos++; pos++;
} }
} }
@ -283,8 +285,12 @@ int main(int argc, char *argv[]) {
} }
// mmap token data // mmap token data
int data_fd = open(DATA_PATH, O_RDONLY); int data_fd = open(data_path, O_RDONLY);
if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; } if (data_fd < 0) {
printf("Cannot open token data: %s\n", data_path);
printf("Hint: run `bash download_data.sh` in training/ or pass --data /path/to/tinystories_data00.bin\n");
return 1;
}
struct stat st; fstat(data_fd, &st); struct stat st; fstat(data_fd, &st);
size_t data_len = st.st_size; size_t data_len = st.st_size;
uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
@ -346,9 +352,9 @@ int main(int argc, char *argv[]) {
lw, la, rms_final, &arms_final, embed, &aembed); lw, la, rms_final, &arms_final, embed, &aembed);
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss); printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
fflush(stdout); fflush(stdout);
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, NULL); execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, "--data", data_path, NULL);
perror("execl"); return 1; perror("execl"); return 1;
} }
// Compile all layers' weight-bearing kernels // Compile all layers' weight-bearing kernels
uint64_t tc = mach_absolute_time(); uint64_t tc = mach_absolute_time();

View File

@ -9,7 +9,7 @@
// NLL loss + gradient (needs target indexing) // NLL loss + gradient (needs target indexing)
// //
// Build: make train_large_ane // Build: make train_large_ane
// Run: ./train_large_ane [--resume] [--steps N] [--lr F] // Run: ./train_large_ane [--resume] [--steps N] [--lr F] [--data PATH]
#include "stories_io.h" #include "stories_io.h"
#include "stories_mil.h" #include "stories_mil.h"
#include "stories_cpu_ops.h" #include "stories_cpu_ops.h"
@ -18,7 +18,7 @@
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin" #define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
#define MODEL_PATH_DEFAULT "stories110M.bin" #define MODEL_PATH_DEFAULT "stories110M.bin"
#define DATA_PATH "tinystories_data00.bin" #define DATA_PATH_DEFAULT "tinystories_data00.bin"
// ===== Weight loading from llama2.c format ===== // ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) { static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
@ -204,6 +204,7 @@ int main(int argc, char *argv[]) {
int adam_t = 0, start_step = 0; int adam_t = 0, start_step = 0;
const char *ckpt_path = CKPT_PATH_DEFAULT; const char *ckpt_path = CKPT_PATH_DEFAULT;
const char *model_path = MODEL_PATH_DEFAULT; const char *model_path = MODEL_PATH_DEFAULT;
const char *data_path = DATA_PATH_DEFAULT;
bool do_resume = false; bool do_resume = false;
bool ane_extras = true; // classifier, softmax, rmsnorm_bwd on ANE bool ane_extras = true; // classifier, softmax, rmsnorm_bwd on ANE
int pos = 0; int pos = 0;
@ -214,6 +215,7 @@ int main(int argc, char *argv[]) {
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]); else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i]; else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i];
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i]; else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
else if (strcmp(argv[i], "--data") == 0 && i+1<argc) data_path = argv[++i];
else if (argv[i][0] != '-') { else if (argv[i][0] != '-') {
if (pos == 0) model_path = argv[i]; if (pos == 0) model_path = argv[i];
else if (pos == 1) { /* seq - compile-time constant */ } else if (pos == 1) { /* seq - compile-time constant */ }
@ -271,8 +273,12 @@ int main(int argc, char *argv[]) {
} }
// mmap token data // mmap token data
int data_fd = open(DATA_PATH, O_RDONLY); int data_fd = open(data_path, O_RDONLY);
if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; } if (data_fd < 0) {
printf("Cannot open token data: %s\n", data_path);
printf("Hint: run `bash download_data.sh` in training/ or pass --data /path/to/tinystories_data00.bin\n");
return 1;
}
struct stat st; fstat(data_fd, &st); struct stat st; fstat(data_fd, &st);
size_t data_len = st.st_size; size_t data_len = st.st_size;
uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
@ -354,9 +360,9 @@ int main(int argc, char *argv[]) {
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss); printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
fflush(stdout); fflush(stdout);
if (ane_extras) if (ane_extras)
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, NULL); execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, "--data", data_path, NULL);
else else
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, "--no-ane-extras", NULL); execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, "--data", data_path, "--no-ane-extras", NULL);
perror("execl"); return 1; perror("execl"); return 1;
} }