mirror of https://github.com/maderix/ANE.git
CLI fixes + --no-ane-extras flag + README benchmark table
- Fix positional arg parsing (model_path, steps, lr were silently ignored) - Add --model, --ckpt flags; forward ckpt_path across exec() restarts - Add --no-ane-extras to disable ANE classifier/softmax/rmsnorm_bwd - CPU fallback for softmax/classifier/rmsnorm_bwd when extras disabled - Update README with 4-way benchmark comparison table (20 steps)
This commit is contained in:
parent
cb474e1537
commit
4c14ed0e25
|
|
@ -8,43 +8,68 @@ Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly
|
||||||
|
|
||||||
- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256
|
- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256
|
||||||
- **109.53M params** (84.95M transformer + 24.58M embedding)
|
- **109.53M params** (84.95M transformer + 24.58M embedding)
|
||||||
- **72 ANE kernels** per compile (60 weight-bearing, 12 weight-free sdpaBwd2)
|
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask — decompose into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
|
||||||
- **6 kernel types per layer**: fwdAttn, fwdFFN, ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd
|
|
||||||
|
|
||||||
## Performance
|
## Three Training Pipelines
|
||||||
|
|
||||||
| Component | Time (ms/step) |
|
### 1. Static Baseline (`train_large`)
|
||||||
|-----------|---------------|
|
Original pipeline. Weights baked as constants in MIL kernels — recompile every 10 steps via `exec()` restart.
|
||||||
| ANE eval | 9.6 |
|
|
||||||
| IO (fp16 conversion) | 4.1 |
|
- 60 weight-bearing + 12 weight-free kernels = 72 per compile batch
|
||||||
| Classifier (cblas) | 9.1 |
|
- Classifier + softmax + RMSNorm backward on CPU
|
||||||
| Cross-entropy + residuals | 14.4 |
|
- **106.7 ms/step**, 7.6s compile per restart
|
||||||
| RMSNorm | 0.1 |
|
|
||||||
| **Total** | **107 ms/step** |
|
### 2. Static + ANE Extras (`train_large_ane`) — PR#19
|
||||||
|
Offloads classifier forward (32K conv), softmax, final RMSNorm, and RMSNorm backward to ANE. Bridge API for C-callable ANE access.
|
||||||
|
|
||||||
|
- 86 kernels per compile batch (+24 rmsnorm_bwd, +1 classifier, +1 finalRms)
|
||||||
|
- **91.8 ms/step** (14% faster), 9.6s compile per restart
|
||||||
|
- Use `--no-ane-extras` to disable and fall back to CPU (for debugging)
|
||||||
|
|
||||||
|
### 3. Dynamic Weight Pipeline (`training_dynamic/`)
|
||||||
|
Weights passed via IOSurface spatial dimension — compile 9 kernels once at startup, no recompilation needed.
|
||||||
|
|
||||||
|
- 9 shared kernels across all 12 layers
|
||||||
|
- **111 ms/step**, 0.4s one-time compile
|
||||||
|
- No exec() restart, no compile limit issues
|
||||||
|
|
||||||
|
## Performance Comparison (20 Steps)
|
||||||
|
|
||||||
|
| | Static Baseline | PR#19 + ANE extras | PR#19 no extras | Dynamic |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| **Wall time** | **10.1s** | **11.7s** | **10.7s** | **~2.6s** |
|
||||||
|
| Compile | 7.6s (75.7%) | 9.6s (81.6%) | 7.5s (69.7%) | 0.4s (15%) |
|
||||||
|
| Train | 2.1s (21.2%) | 1.8s (15.6%) | 2.9s (27.4%) | 2.2s (85%) |
|
||||||
|
| **ms/step** | **106.7** | **91.8** | **147.0** | **111** |
|
||||||
|
| Kernels/restart | 72 | 86 | 60 | 9 (once) |
|
||||||
|
| ANE TFLOPS | 0.87 | 1.15 | 0.72 | — |
|
||||||
|
| Total TFLOPS | 1.63 | 1.90 | 1.19 | — |
|
||||||
|
|
||||||
|
**Key insights:**
|
||||||
|
- Dynamic wins on wall time for any practical run length (3.9x faster at 20 steps)
|
||||||
|
- PR#19 has the best per-step throughput (92ms) but compile overhead dominates short runs
|
||||||
|
- Static restarts every 10 steps, so dynamic's zero-recompile advantage compounds
|
||||||
|
|
||||||
## Files
|
## Files
|
||||||
|
|
||||||
| File | Description |
|
| File | Description |
|
||||||
|------|-------------|
|
|------|-------------|
|
||||||
| `train_large.m` | Main training loop — 12-layer forward/backward, checkpoint, exec() restart |
|
| `train_large.m` | Static baseline — 72 kernels, classifier/softmax on CPU |
|
||||||
| `stories_config.h` | Model config, structs, alloc helpers |
|
| `train_large_ane.m` | PR#19 — 86 kernels, classifier/softmax/rmsnorm_bwd on ANE |
|
||||||
|
| `training_dynamic/train.m` | Dynamic pipeline — 9 kernels, weights via IOSurface |
|
||||||
|
| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels |
|
||||||
|
| `training_dynamic/config.h` | Model config (DIM=768, HIDDEN=2048, etc.) |
|
||||||
|
| `training_dynamic/io.h` | IOSurface I/O + MIL compilation helpers |
|
||||||
|
| `training_dynamic/cpu_ops.h` | CPU ops (SiLU backward, cross-entropy, Adam) |
|
||||||
|
| `stories_config.h` | Static pipeline config, structs, alloc helpers |
|
||||||
| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval |
|
| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval |
|
||||||
| `stories_mil.h` | MIL program generators for all 6 ANE kernel types |
|
| `stories_mil.h` | MIL generators for static pipeline (6 kernel types) |
|
||||||
| `stories_cpu_ops.h` | vDSP-vectorized RMSNorm, cross-entropy, Adam, embedding ops |
|
| `stories_cpu_ops.h` | vDSP-vectorized RMSNorm, cross-entropy, Adam |
|
||||||
| `dashboard.py` | TUI dashboard — loss curve, power/CPU/memory graphs, text generation |
|
| `ane_classifier.h` | ANE classifier fwd (32K conv), softmax kernels |
|
||||||
| `tokenize.py` | Extract pretokenized TinyStories data |
|
| `ane_rmsnorm_bwd.h` | ANE rmsnorm backward kernel |
|
||||||
|
| `dashboard.py` | TUI dashboard — loss curve, power/CPU/memory graphs |
|
||||||
| `Makefile` | Build targets |
|
| `Makefile` | Build targets |
|
||||||
|
|
||||||
## How it works
|
|
||||||
|
|
||||||
1. **Forward pass**: Each layer runs fwdAttn (QKV + SDPA + Wo) and fwdFFN (W1 + SiLU(W3) + W2) on ANE via MIL-compiled kernels. Final RMSNorm + classifier matmul on CPU (cblas).
|
|
||||||
|
|
||||||
2. **Backward pass**: Reverse layer order. ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd on ANE. Weight gradients (dW) via async cblas_sgemm on CPU. RMSNorm backward via vDSP.
|
|
||||||
|
|
||||||
3. **Compile budget**: ANE has a ~119 compile limit per process. With 72 kernels per batch, we run 10 accumulation steps then `exec()` restart with checkpoint resume.
|
|
||||||
|
|
||||||
4. **Data**: Real TinyStories text (20M tokens), mmap'd uint16 token IDs, random position sampling per step.
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### 1. Download Training Data
|
### 1. Download Training Data
|
||||||
|
|
@ -53,69 +78,63 @@ Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly
|
||||||
bash download_data.sh
|
bash download_data.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from [enio/TinyStories](https://huggingface.co/datasets/enio/TinyStories) on HuggingFace. Produces `tinystories_data00.bin` (~41 MB, ~20M tokens).
|
Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from HuggingFace. Produces `tinystories_data00.bin` (~41 MB, ~20M tokens).
|
||||||
|
|
||||||
### 2. Build & Train
|
### 2. Build & Train
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Baseline: classifier + softmax on CPU
|
# Static baseline (classifier + softmax on CPU)
|
||||||
make train_large
|
make train_large
|
||||||
./train_large --steps 100 # quick test
|
./train_large stories110M.bin 256 100 1e-4
|
||||||
./train_large # full 10k steps
|
./train_large --model stories110M.bin --steps 100 --lr 1e-4
|
||||||
./train_large --resume # resume from checkpoint
|
|
||||||
|
|
||||||
# ANE-offloaded: classifier + softmax on ANE (faster)
|
# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
|
||||||
make train_large_ane
|
make train_large_ane
|
||||||
./train_large_ane --steps 100
|
./train_large_ane stories110M.bin 256 100 1e-4
|
||||||
|
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras
|
||||||
|
|
||||||
|
# Dynamic pipeline (no recompilation)
|
||||||
|
cd training_dynamic && make train
|
||||||
|
./train --scratch # train from random init
|
||||||
|
./train # resume from checkpoint
|
||||||
|
./train --steps 200 --lr 1e-4 # custom steps/lr
|
||||||
```
|
```
|
||||||
|
|
||||||
**CLI flags:** `--steps N` (default 10000), `--lr F` (default 3e-4), `--resume`.
|
**CLI flags (all pipelines):**
|
||||||
|
- `--steps N` (default 10000)
|
||||||
|
- `--lr F` (default 3e-4)
|
||||||
|
- `--model PATH` — pretrained weights file
|
||||||
|
- `--ckpt PATH` — checkpoint file (preserved across exec() restarts)
|
||||||
|
- `--resume` — resume from checkpoint
|
||||||
|
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd
|
||||||
|
|
||||||
### 3. Monitor with Dashboard
|
### 3. Monitor with Dashboard
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install blessed psutil numpy
|
pip install blessed psutil numpy
|
||||||
sudo python3 dashboard.py # live mode (needs powermetrics)
|
sudo python3 dashboard.py # static pipeline
|
||||||
sudo python3 dashboard.py --resume # attach to resumed training
|
sudo python3 dashboard.py --dynamic # dynamic pipeline
|
||||||
```
|
```
|
||||||
|
|
||||||
### 4. Benchmarking
|
### 4. Benchmarking
|
||||||
|
|
||||||
Both programs print an **Efficiency Report** at completion:
|
All programs print an **Efficiency Report** at completion:
|
||||||
|
|
||||||
```
|
```
|
||||||
=== Efficiency Report ===
|
=== Efficiency Report ===
|
||||||
Total steps: 100
|
Total steps: 20
|
||||||
Avg train: 107.0 ms/step
|
Wall time: 11738 ms (11.7 s)
|
||||||
ANE TFLOPS: 2.45 sustained
|
Compile time: 9583 ms (81.6%)
|
||||||
ANE utilization: 15.5% of 15.8 TFLOPS
|
Train time: 1835 ms (15.6%)
|
||||||
|
Avg train: 91.8 ms/step
|
||||||
|
ANE TFLOPS: 1.15 sustained
|
||||||
```
|
```
|
||||||
|
|
||||||
Per-batch timing breakdown during training:
|
## Key Techniques
|
||||||
|
|
||||||
```
|
- **NEON vectorized fp16↔fp32**: ARM NEON intrinsics for fast IOSurface data transfer
|
||||||
ane=9.6 io=4.1 cls=9.1 elem=14.4 rms=0.1 cblas_wait=2.3 ms/step
|
|
||||||
```
|
|
||||||
|
|
||||||
| Metric | What it measures |
|
|
||||||
|--------|-----------------|
|
|
||||||
| `ane` | ANE kernel evaluation |
|
|
||||||
| `io` | fp16↔fp32 IOSurface transfer |
|
|
||||||
| `cls` | Classifier matmul (CPU cblas) |
|
|
||||||
| `elem` | Embedding, residual adds, cross-entropy |
|
|
||||||
| `rms` | RMSNorm forward/backward |
|
|
||||||
| `cblas_wait` | Waiting for async dW gradient sgemms |
|
|
||||||
|
|
||||||
Compare baseline vs ANE-offloaded:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
make train_large && ./train_large --steps 100
|
|
||||||
make train_large_ane && ./train_large_ane --steps 100
|
|
||||||
```
|
|
||||||
|
|
||||||
## Key techniques
|
|
||||||
|
|
||||||
- **NEON vectorized fp16<->fp32**: ARM NEON intrinsics for fast IOSurface data transfer
|
|
||||||
- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar
|
- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar
|
||||||
- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE
|
- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE
|
||||||
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask, so we decompose attention into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
|
- **Vocab compaction** (dynamic): 32K → 9.2K active tokens, 3.5x reduction in classifier work
|
||||||
|
- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all 12 layers
|
||||||
|
- **exec() restart**: Workaround for ANE ~119 compile limit per process
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,8 @@
|
||||||
#include "stories_mil.h"
|
#include "stories_mil.h"
|
||||||
#include "stories_cpu_ops.h"
|
#include "stories_cpu_ops.h"
|
||||||
|
|
||||||
#define CKPT_PATH "ane_stories110M_ckpt.bin"
|
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
|
||||||
#define MODEL_PATH "../../assets/models/stories110M.bin"
|
#define MODEL_PATH_DEFAULT "stories110M.bin"
|
||||||
#define DATA_PATH "tinystories_data00.bin"
|
#define DATA_PATH "tinystories_data00.bin"
|
||||||
|
|
||||||
// ===== Weight loading from llama2.c format =====
|
// ===== Weight loading from llama2.c format =====
|
||||||
|
|
@ -193,11 +193,23 @@ int main(int argc, char *argv[]) {
|
||||||
int adam_t = 0, start_step = 0;
|
int adam_t = 0, start_step = 0;
|
||||||
|
|
||||||
// Parse args
|
// Parse args
|
||||||
|
const char *ckpt_path = CKPT_PATH_DEFAULT;
|
||||||
|
const char *model_path = MODEL_PATH_DEFAULT;
|
||||||
bool do_resume = false;
|
bool do_resume = false;
|
||||||
|
int pos = 0;
|
||||||
for (int i=1; i<argc; i++) {
|
for (int i=1; i<argc; i++) {
|
||||||
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
|
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
|
||||||
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
|
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
|
||||||
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
|
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
|
||||||
|
else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i];
|
||||||
|
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
|
||||||
|
else if (argv[i][0] != '-') {
|
||||||
|
if (pos == 0) model_path = argv[i];
|
||||||
|
else if (pos == 1) { /* seq - compile-time constant */ }
|
||||||
|
else if (pos == 2) total_steps = atoi(argv[i]);
|
||||||
|
else if (pos == 3) lr = atof(argv[i]);
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocate per-layer state
|
// Allocate per-layer state
|
||||||
|
|
@ -228,7 +240,7 @@ int main(int argc, char *argv[]) {
|
||||||
float resume_loss = 0;
|
float resume_loss = 0;
|
||||||
bool resuming = false;
|
bool resuming = false;
|
||||||
if (do_resume) {
|
if (do_resume) {
|
||||||
resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
|
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
|
||||||
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
|
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
|
||||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||||
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
|
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
|
||||||
|
|
@ -236,7 +248,7 @@ int main(int argc, char *argv[]) {
|
||||||
if (!resuming) {
|
if (!resuming) {
|
||||||
printf("=== ANE Training: Stories110M (12 layers) ===\n");
|
printf("=== ANE Training: Stories110M (12 layers) ===\n");
|
||||||
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
|
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
|
||||||
if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
|
if (!load_pretrained(lw, rms_final, embed, model_path)) {
|
||||||
printf("Pretrained load failed, using random init\n");
|
printf("Pretrained load failed, using random init\n");
|
||||||
srand48(42);
|
srand48(42);
|
||||||
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
||||||
|
|
@ -322,13 +334,13 @@ int main(int argc, char *argv[]) {
|
||||||
if (g_compile_count + TOTAL_WEIGHT_KERNELS > MAX_COMPILES) {
|
if (g_compile_count + TOTAL_WEIGHT_KERNELS > MAX_COMPILES) {
|
||||||
for (int L=0; L<NLAYERS; L++) { free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]); }
|
for (int L=0; L<NLAYERS; L++) { free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]); }
|
||||||
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
||||||
save_checkpoint(CKPT_PATH, step, total_steps, lr, last_loss,
|
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
|
||||||
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
|
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
|
||||||
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
|
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
|
||||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||||
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
|
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
execl(argv[0], argv[0], "--resume", NULL);
|
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, NULL);
|
||||||
perror("execl"); return 1;
|
perror("execl"); return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@
|
||||||
#include "ane_rmsnorm_bwd.h"
|
#include "ane_rmsnorm_bwd.h"
|
||||||
#include "ane_classifier.h"
|
#include "ane_classifier.h"
|
||||||
|
|
||||||
#define CKPT_PATH "ane_stories110M_ckpt.bin"
|
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
|
||||||
#define MODEL_PATH "../../assets/models/stories110M.bin"
|
#define MODEL_PATH_DEFAULT "stories110M.bin"
|
||||||
#define DATA_PATH "tinystories_data00.bin"
|
#define DATA_PATH "tinystories_data00.bin"
|
||||||
|
|
||||||
// ===== Weight loading from llama2.c format =====
|
// ===== Weight loading from llama2.c format =====
|
||||||
|
|
@ -202,11 +202,25 @@ int main(int argc, char *argv[]) {
|
||||||
float lr = 3e-4f;
|
float lr = 3e-4f;
|
||||||
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
|
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
|
||||||
int adam_t = 0, start_step = 0;
|
int adam_t = 0, start_step = 0;
|
||||||
|
const char *ckpt_path = CKPT_PATH_DEFAULT;
|
||||||
|
const char *model_path = MODEL_PATH_DEFAULT;
|
||||||
bool do_resume = false;
|
bool do_resume = false;
|
||||||
|
bool ane_extras = true; // classifier, softmax, rmsnorm_bwd on ANE
|
||||||
|
int pos = 0;
|
||||||
for (int i=1; i<argc; i++) {
|
for (int i=1; i<argc; i++) {
|
||||||
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
|
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
|
||||||
|
else if (strcmp(argv[i], "--no-ane-extras") == 0) ane_extras = false;
|
||||||
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
|
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
|
||||||
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
|
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
|
||||||
|
else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i];
|
||||||
|
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
|
||||||
|
else if (argv[i][0] != '-') {
|
||||||
|
if (pos == 0) model_path = argv[i];
|
||||||
|
else if (pos == 1) { /* seq - compile-time constant */ }
|
||||||
|
else if (pos == 2) total_steps = atoi(argv[i]);
|
||||||
|
else if (pos == 3) lr = atof(argv[i]);
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LayerWeights lw[NLAYERS]; LayerAdam la[NLAYERS];
|
LayerWeights lw[NLAYERS]; LayerAdam la[NLAYERS];
|
||||||
|
|
@ -228,7 +242,7 @@ int main(int argc, char *argv[]) {
|
||||||
float resume_loss = 0;
|
float resume_loss = 0;
|
||||||
bool resuming = false;
|
bool resuming = false;
|
||||||
if (do_resume) {
|
if (do_resume) {
|
||||||
resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
|
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
|
||||||
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
|
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
|
||||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||||
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
|
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
|
||||||
|
|
@ -236,8 +250,9 @@ int main(int argc, char *argv[]) {
|
||||||
if (!resuming) {
|
if (!resuming) {
|
||||||
printf("=== ANE Training: Stories110M (ANE-offloaded) ===\n");
|
printf("=== ANE Training: Stories110M (ANE-offloaded) ===\n");
|
||||||
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
|
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
|
||||||
printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n");
|
if (ane_extras) printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n");
|
||||||
if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
|
else printf("ANE extras DISABLED (classifier/softmax/rmsnorm_bwd on CPU)\n");
|
||||||
|
if (!load_pretrained(lw, rms_final, embed, model_path)) {
|
||||||
printf("Pretrained load failed, using random init\n");
|
printf("Pretrained load failed, using random init\n");
|
||||||
srand48(42);
|
srand48(42);
|
||||||
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
||||||
|
|
@ -301,9 +316,12 @@ int main(int argc, char *argv[]) {
|
||||||
memset(rmsFFNBwd, 0, sizeof(rmsFFNBwd));
|
memset(rmsFFNBwd, 0, sizeof(rmsFFNBwd));
|
||||||
|
|
||||||
// Softmax kernel (no weights — compile once)
|
// Softmax kernel (no weights — compile once)
|
||||||
Kern *softmaxKern = compile_softmax_kern();
|
Kern *softmaxKern = NULL;
|
||||||
if (!softmaxKern) { printf("softmax compile failed\n"); return 1; }
|
if (ane_extras) {
|
||||||
printf("Softmax kernel compiled (no weights)\n");
|
softmaxKern = compile_softmax_kern();
|
||||||
|
if (!softmaxKern) { printf("softmax compile failed\n"); return 1; }
|
||||||
|
printf("Softmax kernel compiled (no weights)\n");
|
||||||
|
}
|
||||||
|
|
||||||
// Final RMSNorm and classifier are recompiled per batch since they have baked weights
|
// Final RMSNorm and classifier are recompiled per batch since they have baked weights
|
||||||
Kern *finalRmsKern = NULL, *classifierKern = NULL;
|
Kern *finalRmsKern = NULL, *classifierKern = NULL;
|
||||||
|
|
@ -320,8 +338,8 @@ int main(int argc, char *argv[]) {
|
||||||
int step = start_step;
|
int step = start_step;
|
||||||
while (step < total_steps) {
|
while (step < total_steps) {
|
||||||
// Check compile budget — account for new kernels
|
// Check compile budget — account for new kernels
|
||||||
// Per batch: 60 layer kernels + 24 rmsnorm_bwd + 1 classifier + 1 final_rms = 86
|
// Per batch: 60 layer kernels [+ 24 rmsnorm_bwd + 1 classifier + 1 final_rms = 86 with extras]
|
||||||
int kernels_needed = TOTAL_WEIGHT_KERNELS + 2*NLAYERS + 2;
|
int kernels_needed = TOTAL_WEIGHT_KERNELS + (ane_extras ? 2*NLAYERS + 2 : 0);
|
||||||
if (g_compile_count + kernels_needed > MAX_COMPILES) {
|
if (g_compile_count + kernels_needed > MAX_COMPILES) {
|
||||||
for (int L=0; L<NLAYERS; L++) {
|
for (int L=0; L<NLAYERS; L++) {
|
||||||
free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]);
|
free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]);
|
||||||
|
|
@ -329,13 +347,16 @@ int main(int argc, char *argv[]) {
|
||||||
}
|
}
|
||||||
free_kern(softmaxKern); free_kern(finalRmsKern); free_kern(classifierKern);
|
free_kern(softmaxKern); free_kern(finalRmsKern); free_kern(classifierKern);
|
||||||
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
||||||
save_checkpoint(CKPT_PATH, step, total_steps, lr, last_loss,
|
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
|
||||||
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
|
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
|
||||||
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
|
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
|
||||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||||
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
|
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
execl(argv[0], argv[0], "--resume", NULL);
|
if (ane_extras)
|
||||||
|
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, NULL);
|
||||||
|
else
|
||||||
|
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, "--no-ane-extras", NULL);
|
||||||
perror("execl"); return 1;
|
perror("execl"); return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -350,13 +371,15 @@ int main(int argc, char *argv[]) {
|
||||||
printf("\nCompile failed at layer %d\n", L);
|
printf("\nCompile failed at layer %d\n", L);
|
||||||
compile_ok = false; break;
|
compile_ok = false; break;
|
||||||
}
|
}
|
||||||
// NEW: Compile RMSNorm backward kernels for this layer
|
// Compile RMSNorm backward kernels for this layer (if ane_extras)
|
||||||
free_kern(rmsAttBwd[L]); free_kern(rmsFFNBwd[L]);
|
if (ane_extras) {
|
||||||
rmsAttBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_att);
|
free_kern(rmsAttBwd[L]); free_kern(rmsFFNBwd[L]);
|
||||||
rmsFFNBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_ffn);
|
rmsAttBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_att);
|
||||||
if (!rmsAttBwd[L] || !rmsFFNBwd[L]) {
|
rmsFFNBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_ffn);
|
||||||
printf("\nrmsnorm_bwd compile failed at layer %d\n", L);
|
if (!rmsAttBwd[L] || !rmsFFNBwd[L]) {
|
||||||
compile_ok = false; break;
|
printf("\nrmsnorm_bwd compile failed at layer %d\n", L);
|
||||||
|
compile_ok = false; break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!compile_ok) { g_compile_count = MAX_COMPILES; continue; }
|
if (!compile_ok) { g_compile_count = MAX_COMPILES; continue; }
|
||||||
|
|
@ -369,18 +392,19 @@ int main(int argc, char *argv[]) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NEW: Compile final RMSNorm and classifier with current weights
|
// Compile final RMSNorm and classifier with current weights (if ane_extras)
|
||||||
free_kern(finalRmsKern); free_kern(classifierKern);
|
if (ane_extras) {
|
||||||
finalRmsKern = compile_final_rmsnorm_kern(rms_final);
|
free_kern(finalRmsKern); free_kern(classifierKern);
|
||||||
classifierKern = compile_classifier_fwd(embed);
|
finalRmsKern = compile_final_rmsnorm_kern(rms_final);
|
||||||
if (!finalRmsKern || !classifierKern) {
|
classifierKern = compile_classifier_fwd(embed);
|
||||||
printf("finalRms or classifier compile failed\n");
|
if (!finalRmsKern || !classifierKern) {
|
||||||
g_compile_count = MAX_COMPILES; continue;
|
printf("finalRms or classifier compile failed\n");
|
||||||
}
|
g_compile_count = MAX_COMPILES; continue;
|
||||||
// Re-compile softmax if needed
|
}
|
||||||
if (!softmaxKern) {
|
if (!softmaxKern) {
|
||||||
softmaxKern = compile_softmax_kern();
|
softmaxKern = compile_softmax_kern();
|
||||||
if (!softmaxKern) { printf("softmax recompile failed\n"); return 1; }
|
if (!softmaxKern) { printf("softmax recompile failed\n"); return 1; }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
double cms = tb_ms(mach_absolute_time() - tc);
|
double cms = tb_ms(mach_absolute_time() - tc);
|
||||||
|
|
@ -444,26 +468,46 @@ int main(int argc, char *argv[]) {
|
||||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
|
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHANGED: Final RMSNorm on ANE (was CPU)
|
|
||||||
t0=mach_absolute_time();
|
t0=mach_absolute_time();
|
||||||
io_write_fp16(finalRmsKern->ioIn, x_cur, DIM, SEQ);
|
if (ane_extras) {
|
||||||
ane_eval(finalRmsKern);
|
// Final RMSNorm on ANE
|
||||||
io_read_fp16(finalRmsKern->ioOut, x_final, 0, DIM, SEQ);
|
io_write_fp16(finalRmsKern->ioIn, x_cur, DIM, SEQ);
|
||||||
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
ane_eval(finalRmsKern);
|
||||||
|
io_read_fp16(finalRmsKern->ioOut, x_final, 0, DIM, SEQ);
|
||||||
|
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
||||||
|
|
||||||
// CHANGED: Classifier on ANE (was CPU cblas)
|
// Classifier on ANE
|
||||||
io_write_fp16(classifierKern->ioIn, x_final, DIM, SEQ);
|
io_write_fp16(classifierKern->ioIn, x_final, DIM, SEQ);
|
||||||
ane_eval(classifierKern);
|
ane_eval(classifierKern);
|
||||||
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
||||||
|
|
||||||
// CHANGED: Softmax on ANE, then read probs back for NLL on CPU
|
// Softmax on ANE
|
||||||
io_copy(softmaxKern->ioIn, 0, classifierKern->ioOut, 0, VOCAB, SEQ);
|
io_copy(softmaxKern->ioIn, 0, classifierKern->ioOut, 0, VOCAB, SEQ);
|
||||||
ane_eval(softmaxKern);
|
ane_eval(softmaxKern);
|
||||||
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
||||||
|
|
||||||
// Read probs back for NLL loss + gradient (needs target indexing — CPU)
|
io_read_fp16(softmaxKern->ioOut, probs, 0, VOCAB, SEQ);
|
||||||
io_read_fp16(softmaxKern->ioOut, probs, 0, VOCAB, SEQ);
|
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
} else {
|
||||||
|
// CPU fallback: rmsnorm + classifier + softmax
|
||||||
|
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
|
||||||
|
t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
|
||||||
|
|
||||||
|
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
|
||||||
|
VOCAB, SEQ, DIM, 1.0f,
|
||||||
|
embed, DIM, x_final, SEQ, 0.0f, probs, SEQ);
|
||||||
|
t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1;
|
||||||
|
|
||||||
|
// CPU softmax
|
||||||
|
for (int t = 0; t < SEQ; t++) {
|
||||||
|
float maxv = -1e30f;
|
||||||
|
for (int v = 0; v < VOCAB; v++) { float val = probs[v*SEQ+t]; if (val > maxv) maxv = val; }
|
||||||
|
float sum = 0;
|
||||||
|
for (int v = 0; v < VOCAB; v++) { probs[v*SEQ+t] = expf(probs[v*SEQ+t] - maxv); sum += probs[v*SEQ+t]; }
|
||||||
|
for (int v = 0; v < VOCAB; v++) probs[v*SEQ+t] /= sum;
|
||||||
|
}
|
||||||
|
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
|
||||||
|
}
|
||||||
|
|
||||||
// NLL loss + gradient on CPU: dlogits = probs - one_hot(targets)
|
// NLL loss + gradient on CPU: dlogits = probs - one_hot(targets)
|
||||||
float total_loss = 0;
|
float total_loss = 0;
|
||||||
|
|
@ -531,17 +575,19 @@ int main(int argc, char *argv[]) {
|
||||||
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
|
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
|
||||||
});
|
});
|
||||||
|
|
||||||
// CHANGED: RMSNorm2 backward on ANE
|
// RMSNorm2 backward
|
||||||
// Write concat(dx_ffn, x2) into rmsnorm_bwd kernel
|
if (ane_extras) {
|
||||||
io_write_fp16_at(rmsFFNBwd[L]->ioIn, 0, dx_ffn, DIM, SEQ);
|
io_write_fp16_at(rmsFFNBwd[L]->ioIn, 0, dx_ffn, DIM, SEQ);
|
||||||
io_write_fp16_at(rmsFFNBwd[L]->ioIn, DIM, ac->x2, DIM, SEQ);
|
io_write_fp16_at(rmsFFNBwd[L]->ioIn, DIM, ac->x2, DIM, SEQ);
|
||||||
ane_eval(rmsFFNBwd[L]);
|
ane_eval(rmsFFNBwd[L]);
|
||||||
io_read_fp16(rmsFFNBwd[L]->ioOut, dx2, 0, DIM, SEQ);
|
io_read_fp16(rmsFFNBwd[L]->ioOut, dx2, 0, DIM, SEQ);
|
||||||
// dw for rmsnorm_ffn still on CPU (accumulate per step)
|
}
|
||||||
|
// dw for rmsnorm_ffn on CPU (accumulate per step)
|
||||||
{
|
{
|
||||||
float *dw_tmp = (float*)calloc(DIM, 4);
|
float *dw_tmp = (float*)calloc(DIM, 4);
|
||||||
float *dx_scratch = (float*)malloc(SEQ*DIM*4);
|
float *dx_scratch = (float*)malloc(SEQ*DIM*4);
|
||||||
rmsnorm_bwd(dx_scratch, dw_tmp, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
|
rmsnorm_bwd(dx_scratch, dw_tmp, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
|
||||||
|
if (!ane_extras) memcpy(dx2, dx_scratch, SEQ*DIM*4);
|
||||||
for(int i=0;i<DIM;i++) gr->rms_ffn[i] += dw_tmp[i];
|
for(int i=0;i<DIM;i++) gr->rms_ffn[i] += dw_tmp[i];
|
||||||
free(dx_scratch); free(dw_tmp);
|
free(dx_scratch); free(dw_tmp);
|
||||||
}
|
}
|
||||||
|
|
@ -591,17 +637,20 @@ int main(int argc, char *argv[]) {
|
||||||
ane_eval(kern[L].qkvBwd);
|
ane_eval(kern[L].qkvBwd);
|
||||||
io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
|
io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
|
||||||
|
|
||||||
// CHANGED: RMSNorm1 backward on ANE
|
// RMSNorm1 backward
|
||||||
io_write_fp16_at(rmsAttBwd[L]->ioIn, 0, dx_attn, DIM, SEQ);
|
|
||||||
io_write_fp16_at(rmsAttBwd[L]->ioIn, DIM, ac->layer_in, DIM, SEQ);
|
|
||||||
ane_eval(rmsAttBwd[L]);
|
|
||||||
float *dx_rms1 = (float*)malloc(SEQ*DIM*4);
|
float *dx_rms1 = (float*)malloc(SEQ*DIM*4);
|
||||||
io_read_fp16(rmsAttBwd[L]->ioOut, dx_rms1, 0, DIM, SEQ);
|
if (ane_extras) {
|
||||||
// dw for rmsnorm_att still on CPU
|
io_write_fp16_at(rmsAttBwd[L]->ioIn, 0, dx_attn, DIM, SEQ);
|
||||||
|
io_write_fp16_at(rmsAttBwd[L]->ioIn, DIM, ac->layer_in, DIM, SEQ);
|
||||||
|
ane_eval(rmsAttBwd[L]);
|
||||||
|
io_read_fp16(rmsAttBwd[L]->ioOut, dx_rms1, 0, DIM, SEQ);
|
||||||
|
}
|
||||||
|
// dw for rmsnorm_att on CPU
|
||||||
{
|
{
|
||||||
float *dw_tmp = (float*)calloc(DIM, 4);
|
float *dw_tmp = (float*)calloc(DIM, 4);
|
||||||
float *dx_scratch = (float*)malloc(SEQ*DIM*4);
|
float *dx_scratch = (float*)malloc(SEQ*DIM*4);
|
||||||
rmsnorm_bwd(dx_scratch, dw_tmp, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
|
rmsnorm_bwd(dx_scratch, dw_tmp, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
|
||||||
|
if (!ane_extras) memcpy(dx_rms1, dx_scratch, SEQ*DIM*4);
|
||||||
for(int i=0;i<DIM;i++) gr->rms_att[i] += dw_tmp[i];
|
for(int i=0;i<DIM;i++) gr->rms_att[i] += dw_tmp[i];
|
||||||
free(dx_scratch); free(dw_tmp);
|
free(dx_scratch); free(dw_tmp);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue