[feat] Merge upstream PRs #21, #23, #26: NEON-optimized training (train_opt), double-buffered async ANE training (train_double_buffer), Qwen2.5-0.5B LLM inference (inference/). Added get_path() env var support and SEC_FLAGS to all new targets. Skipped PR #22 (binary blob risk).

Erik Bray 2026-03-03 17:18:02 +01:00
parent 99b06838bc
commit b4d81b71d4
12 changed files with 2881 additions and 1 deletions

8
.gitignore vendored
@@ -17,9 +17,17 @@ tiny_train_m1
train_large
training/train_large
training/train_large_ane
training/train_opt
training/train_double_buffer
training/test_*
!training/test_*.m
# Inference binaries
inference/qwen_ane
# Dynamic training binaries
training/training_dynamic/train
# Test/research binaries
test_chaining

88
PROBE_RESULTS.md Normal file
@@ -0,0 +1,88 @@
# ANE Probe Results: M4 (macOS 26.3)
**Machine:** Apple M4 (10 cores), 32GB RAM, macOS 26.3
**Date:** 2026-03-03
**ANE Family:** H16 (same as M5 results in `training/m5result.md`)
## Key Discovery: Compile and Eval Run in Parallel
**This was not known before.** The M5 probes tested compile and eval sequentially.
We tested with GCD `dispatch_async` and found they fully overlap.
### probe_v2.m Results
#### TEST 1: Pure Eval Throughput
```
Conv 128x128, spatial=64
1000 evals: 189.1ms total, 0.189ms/eval
11.09 GFLOPS sustained
```
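The sustained-throughput figure above can be rechecked from the raw timings. The sketch below assumes the probe kernel is a 1×1 convolution (128 input channels, 128 output channels, 64 spatial positions) and counts one multiply-add as 2 FLOPs; the helper name `conv1x1_flops` is hypothetical and not part of `probe_v2.m`.

```python
def conv1x1_flops(in_ch: int, out_ch: int, spatial: int) -> int:
    """FLOPs for one 1x1 convolution pass (assumption: 1 multiply-add = 2 FLOPs)."""
    return 2 * in_ch * out_ch * spatial

total_ms = 189.1   # wall time reported for 1000 evals
n_evals = 1000

ms_per_eval = total_ms / n_evals
# FLOPs per eval divided by seconds per eval, scaled to GFLOPS
gflops = conv1x1_flops(128, 128, 64) / (ms_per_eval * 1e-3) / 1e9

print(f"{ms_per_eval:.3f} ms/eval, {gflops:.2f} GFLOPS")
```

Under these assumptions the result agrees with the reported 11.09 GFLOPS, which suggests the probe counts FLOPs the same way.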
#### TEST 2: Ping-pong (Two Pre-compiled Models)
```
500 ping-pong pairs: 207.4ms (0.415ms/pair, 0.207ms/eval)
```
Near-zero overhead switching between two loaded models.
#### TEST 3: Sequential Compile (20 Models)
```
All 20 models compiled and verified ✓
Compile time: ~23-29ms each (consistent, no degradation)
All 20 models correct with different scale factors
```
#### TEST 4: Background Compile Overlap ⭐
```
Background compile: 26.8ms
Foreground evals during compile: 119 (26.8ms total)
Overlap: YES — compile and eval CAN run in parallel!
Background model verified correct ✓
```
### Summary
| Metric | Value |
|--------|-------|
| Compile time | ~25ms per kernel set |
| Eval time | 0.189ms per eval |
| Compile:eval ratio | ~130:1 |
| Parallel compile+eval | **YES** |
| Max simultaneous models | 20+ |
| Ping-pong overhead | +10% vs single model |
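
The two derived rows in the table follow from the measured timings. A minimal check, using only the rounded numbers reported above (so the results are approximate):

```python
compile_ms = 25.0   # approximate compile time per kernel set
eval_ms = 0.189     # single-model eval time (TEST 1)
pair_ms = 0.415     # one ping-pong pair = two evals (TEST 2)

# How many evals fit into the time of one compile
compile_to_eval_ratio = compile_ms / eval_ms

# Per-eval cost when alternating models, relative to a single model
pingpong_overhead = (pair_ms / 2) / eval_ms - 1.0

print(f"compile:eval ratio ~{compile_to_eval_ratio:.0f}:1")
print(f"ping-pong overhead ~{pingpong_overhead * 100:.0f}%")
```

The ratio lands near 130:1 and the overhead near 10%, consistent with the summary rows.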
## Peak ANE Throughput (inmem_peak)
```
Config                 W(MB)  GFLOP  ms/eval   TFLOPS
96x conv 512ch sp64     48.0   3.22  0.429 ms    7.50
128x conv 512ch sp64    64.0   4.29  0.589 ms    7.30
256x conv 256ch sp64    32.0   2.15  0.380 ms    5.65
64x conv 512ch sp64     32.0   2.15  0.395 ms    5.43
```
Peak: **7.50 TFLOPS** (47% of 15.8 TFLOPS theoretical).
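
The columns of the peak table can be approximately reproduced from the config string alone. The sketch below assumes each config is a stack of 1×1 convolutions with square channel counts and FP16 (2-byte) weights, and that W(MB) is in MiB; `peak_row` is a hypothetical helper, not a function in `inmem_peak`.

```python
def peak_row(n_convs, channels, spatial, ms_per_eval, bytes_per_weight=2):
    """Return (weight MiB, GFLOP per eval, TFLOPS) for a stack of 1x1 convs."""
    n_weights = n_convs * channels * channels
    weight_mib = n_weights * bytes_per_weight / 2**20
    gflop = 2 * n_weights * spatial / 1e9      # multiply-add counted as 2 FLOPs
    tflops = gflop / ms_per_eval               # GFLOP per ms equals TFLOP per s
    return weight_mib, gflop, tflops

weight_mib, gflop, tflops = peak_row(96, 512, 64, 0.429)
utilization = tflops / 15.8                    # 15.8 TFLOPS theoretical, from the text

print(f"{weight_mib:.1f} MiB, {gflop:.2f} GFLOP, {tflops:.2f} TFLOPS, "
      f"{utilization * 100:.0f}% of theoretical")
```

Small differences in the last digit are expected because the table's `ms/eval` values are already rounded.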
## Implications for Training
### Before (train_large.m)
- Synchronous compile: **88.6% of wall time is compilation**
- 55ms compile per batch, 0.54ms actual training
- Training throughput limited by compiler, not by ANE
### After (train_double_buffer.m)
- Async double-buffered compile: **0% compile stall**
- Background compile happens during forward/backward passes
- ~130 eval steps fit in one compile window
- Weight updates are "delayed" by one batch (standard technique in distributed training)
- Training throughput limited only by ANE eval speed
### Architecture
```
Time →
Active kernels: [=== eval batch N ===][=== eval batch N+1 ===][=== eval batch N+2 ===]
Background:     [compile N+1 weights ][compile N+2 weights   ][compile N+3 weights   ]
                                     ↑                       ↑                       ↑
                                swap ready              swap ready              swap ready
```
Two kernel sets (A and B) alternate between active evaluation and background compilation.
When the background compile finishes, pointers swap atomically at the batch boundary.
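
The scheduling pattern can be sketched in a few lines. This is an illustration of double buffering with a one-batch weight delay, not the code in `train_double_buffer.m`: `compile_kernels` and `evaluate` are hypothetical stand-ins, a thread pool replaces GCD `dispatch_async`, and the sketch blocks on the pending compile at each batch boundary so that its behavior is deterministic.

```python
from concurrent.futures import ThreadPoolExecutor

def compile_kernels(updates_applied):
    """Hypothetical stand-in for an ANE compile: bakes the current weights into a kernel set."""
    return {"weights_version": updates_applied}

def evaluate(kernels, batch):
    """Hypothetical stand-in for forward/backward; reports which weights were used."""
    return kernels["weights_version"]

def train(num_batches):
    updates_applied = 0                      # gradient updates applied to the CPU-side weights
    active = compile_kernels(updates_applied)
    history = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        for batch in range(num_batches):
            # Start compiling a snapshot of the current weights in the background.
            pending = pool.submit(compile_kernels, updates_applied)
            # Meanwhile, run this batch on the kernel set that is already loaded.
            history.append((batch, evaluate(active, batch)))
            updates_applied += 1             # this batch's gradient lands on the CPU weights
            # Batch boundary: swap in the freshly compiled kernel set.
            active = pending.result()
    return history

history = train(4)
print(history)
```

Each tuple is `(batch, weights_version_used)`. From the second batch onward the version used trails the batch index by one, which is the delayed update described above.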

119
inference/README.md Normal file
@@ -0,0 +1,119 @@
# ANE Inference — Full LLM on Apple Neural Engine
First complete LLM inference running directly on Apple's Neural Engine via reverse-engineered `_ANEClient` APIs. No CoreML. No Xcode compiler dependency at runtime. Token-for-token match with PyTorch.
Built on top of the [maderix/ANE](https://github.com/maderix/ANE) training runtime.
## What This Does
Runs **Qwen2.5-0.5B-Instruct** (24 transformer layers, 494M parameters) entirely on the ANE:
- **169 ANE kernels** compiled at startup via `_ANEInMemoryModel`
- **82 tokens/sec** decode on M4 Pro
- **Zero GPU usage** — runs on 16 dedicated neural cores
- **Correct output** — matches PyTorch reference token-for-token
All linear projections (Q, K, V, O, gate, up, down × 24 layers + chunked LM head) compile as baked-weight 1×1 convolution kernels on ANE. Element-wise ops (RMSNorm, RoPE, softmax, SiLU, attention scores) run on CPU via Accelerate BLAS.
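
Expressing a linear projection as a 1×1 convolution works because, with a single spatial position, the convolution reduces to a matrix-vector product. A minimal pure-Python sketch (the helper names are illustrative, not functions from this repository):

```python
def matvec(W, x):
    """y = W @ x with W as [out_ch][in_ch] and x as [in_ch]."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def conv1x1(W, x):
    """1x1 convolution: W is [out_ch][in_ch], x is [in_ch][spatial]."""
    in_ch, spatial = len(x), len(x[0])
    return [[sum(W[o][c] * x[c][s] for c in range(in_ch)) for s in range(spatial)]
            for o in range(len(W))]

W = [[1.0, 2.0, 3.0],
     [0.5, -1.0, 0.0]]
x = [4.0, 5.0, 6.0]

# Single-token decode corresponds to spatial = 1: wrap each channel value in a list.
as_conv = [row[0] for row in conv1x1(W, [[v] for v in x])]
as_matvec = matvec(W, x)
print(as_conv, as_matvec)
```

Both paths give the same vector, so a projection weight of shape `[out_ch, in_ch]` can be used directly as the kernel of a 1×1 convolution.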
## Architecture
```
Token → Embedding (CPU) → 24× Transformer Layer → LM Head (CPU) → Next Token
                          ├── RMSNorm (CPU)
                          ├── Q/K/V Projection (ANE conv kernel)
                          ├── RoPE (CPU, rotate_half)
                          ├── GQA Attention (CPU, 14 heads / 2 KV heads)
                          ├── O Projection (ANE conv kernel)
                          ├── Residual (CPU)
                          ├── RMSNorm (CPU)
                          ├── Gate/Up Projection (ANE conv kernel)
                          ├── SiLU + elementwise mul (CPU)
                          ├── Down Projection (ANE conv kernel)
                          └── Residual (CPU)
```
## Quick Start
```bash
# 1. Convert weights from HuggingFace safetensors to flat binary
pip install safetensors torch transformers
python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct qwen05b.bin
# 2. Build
xcrun clang -O2 -framework Foundation -framework IOSurface \
-framework CoreML -framework Accelerate -ldl -lobjc \
-o qwen_ane main.m
# 3. Run (pass space-separated token IDs)
./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20
# 4. With tokenizer (requires transformers)
python3 run.py "Say hello in one word."
```
## Output
```
=== Qwen2.5-0.5B ANE Inference ===
Loading weights...
Config: dim=896 hidden=4864 layers=24 heads=14 kv_heads=2 vocab=151936
Compiling ANE kernels (169 total)...
Compile time: 5.1s
Prompt: 28 tokens, generating up to 10
Prefill: 64.2 t/s (28 tokens)
OUT: 9707 13 151645
Decode: 82.4 t/s (2 tokens)
→ "Hello." (matches PyTorch exactly)
```
## Files
| File | What |
|------|------|
| `qwen_ane_infer.h` | Full 24-layer transformer forward pass, ANE kernel compilation, KV cache |
| `main.m` | Weight loader, token I/O, main generation loop |
| `convert_weights.py` | HuggingFace safetensors → flat f32 binary (includes Q/K/V biases) |
| `run.py` | Python wrapper with HuggingFace tokenizer |
## Model Support
Currently implements **Qwen2.5** architecture:
- GQA attention (grouped-query, `n_heads` ≠ `n_kv_heads`)
- `rotate_half` RoPE (not interleaved pairs)
- SwiGLU FFN (gate + up + silu + down)
- Q/K/V bias (Qwen-specific)
- Tied word embeddings (lm_head = embed)
- Chunked LM head (vocab > 65536 exceeds ANE max dim)
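
The chunked LM head can be illustrated with a toy vocabulary. The sketch assumes greedy (argmax) decoding and an evenly divisible vocabulary; `chunked_argmax` is a hypothetical helper showing the idea, not the routine in `qwen_ane_infer.h`.

```python
def chunked_argmax(weight_rows, hidden, n_chunks):
    """Argmax over logits = W @ hidden, computed one vocabulary chunk at a time."""
    vocab = len(weight_rows)
    assert vocab % n_chunks == 0, "sketch assumes equal-sized chunks"
    chunk = vocab // n_chunks
    best_idx, best_val = -1, float("-inf")
    for c in range(n_chunks):
        rows = weight_rows[c * chunk:(c + 1) * chunk]   # one chunk = one small kernel
        logits = [sum(w * h for w, h in zip(row, hidden)) for row in rows]
        for i, value in enumerate(logits):
            if value > best_val:
                best_idx, best_val = c * chunk + i, value  # map back to the global token ID
    return best_idx

# Toy example: 4-token vocabulary split into 2 chunks
toy_weights = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 3.0]]
toy_token = chunked_argmax(toy_weights, [1.0, 1.0], n_chunks=2)

# Sizes used for Qwen2.5-0.5B in this commit
chunk_size = 151936 // 16
print(toy_token, chunk_size)
```

With 16 chunks, each chunk has 9496 output channels, well below the 65536 limit mentioned above.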
Adapting to other architectures (LLaMA, Gemma, Mistral) requires:
1. Adjusting the config constants in `qwen_ane_infer.h`
2. Updating `convert_weights.py` for the weight naming scheme
3. Removing Q/K/V bias handling if the model doesn't have them
4. Switching RoPE to interleaved pairs if needed
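
The difference between the two RoPE layouts mentioned in step 4 is only which elements are paired for rotation. A minimal sketch of both, for a single head vector (function names are illustrative; the default `theta` matches `QWEN_ROPE_THETA`):

```python
import math

def rope_rotate_half(x, pos, theta=1_000_000.0):
    """rotate_half layout: element i is paired with element i + d/2."""
    d, half = len(x), len(x) // 2
    out = [0.0] * d
    for i in range(half):
        angle = pos * theta ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out

def rope_interleaved(x, pos, theta=1_000_000.0):
    """Interleaved layout: element 2i is paired with element 2i + 1."""
    d = len(x)
    out = [0.0] * d
    for i in range(d // 2):
        angle = pos * theta ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[2 * i] = x[2 * i] * c - x[2 * i + 1] * s
        out[2 * i + 1] = x[2 * i + 1] * c + x[2 * i] * s
    return out

x = [1.0, 2.0, 3.0, 4.0]
y_half = rope_rotate_half(x, pos=5)
# Same rotation, different memory layout: permute, rotate interleaved, compare.
y_inter = rope_interleaved([x[0], x[2], x[1], x[3]], pos=5)
print(y_half, y_inter)
```

Because the two variants differ only by a fixed permutation, using the wrong one with weights trained for the other produces plausible-looking but incorrect attention scores.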
## Requirements
- macOS 15+ on Apple Silicon (M1/M2/M3/M4)
- Xcode Command Line Tools (for `xcrun clang`)
- Python 3.9+ with `safetensors`, `torch`, `transformers` (for weight conversion)
## Known Limitations
- **CPU projections only** — ANE baked-weight conv kernels compile successfully but produce incorrect output (FP16 weight blob format mismatch). The `USE_ANE_PROJECTIONS` toggle exists but defaults to 0 (CPU via Accelerate BLAS). Fixing this would push decode speed from 82 t/s to 120+ t/s.
- **No persistent server** — each invocation recompiles 169 kernels (~5s). A server mode that compiles once and serves via HTTP would eliminate this overhead.
- **Single model** — hardcoded for Qwen2.5-0.5B. Needs parameterization for other sizes.
- **f32 weights** — 1.9GB on disk. FP16 or quantized weight support would halve this.
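
The on-disk size follows from the model configuration and the layout written by `convert_weights.py` (7-int header, embedding, per-layer weights and biases, final norm; the LM head is tied to the embedding and not stored separately). A quick check, where `qwen_flat_file_size` is a hypothetical helper:

```python
def qwen_flat_file_size(dim=896, hidden=4864, n_layers=24, n_heads=14,
                        n_kv_heads=2, vocab=151936, head_dim=64):
    """Return (parameter count, file size in bytes) for the flat f32 format."""
    q_dim = n_heads * head_dim
    kv_dim = n_kv_heads * head_dim
    per_layer = (
        dim                      # attention RMSNorm
        + q_dim * dim            # wq
        + 2 * kv_dim * dim       # wk, wv
        + dim * q_dim            # wo
        + q_dim + 2 * kv_dim     # q/k/v biases
        + dim                    # FFN RMSNorm
        + 3 * hidden * dim       # gate, up, down
    )
    params = vocab * dim + n_layers * per_layer + dim   # embedding + layers + final norm
    return params, 7 * 4 + params * 4                   # int32 header + f32 weights

params, size_bytes = qwen_flat_file_size()
print(f"{params / 1e6:.0f}M parameters, {size_bytes / 1e9:.2f} GB")
```

This gives roughly 494M parameters and just under 2 GB, consistent with the figures quoted in this README; storing FP16 instead would halve the weight portion.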
## How It Works
The key insight from maderix's reverse engineering: the ANE executes compiled MIL (Machine Learning Intermediate Language) programs as atomic graph operations. Each linear projection becomes a MIL program with baked FP16 weights, compiled in-memory via `_ANEInMemoryModel`, and executed through IOSurface-based zero-copy I/O.
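We chain these atomic operations (7 per transformer layer × 24 layers = 168 projection kernels, plus the LM head) with CPU-side element-wise ops in between. The 169 figure counts the LM head as a single kernel; because the LM head is split into 16 chunks, the compile step actually builds 168 + 16 = 184 kernels. The ANE handles the compute-heavy matmuls; the CPU handles the memory-bound operations (attention scores, softmax, RoPE).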
## License
Same as maderix/ANE — research and educational use.

@@ -0,0 +1,107 @@
#!/usr/bin/env python3
"""Convert Qwen2.5-0.5B-Instruct safetensors → flat binary for ANE inference.

Output format: config header (7 ints) + all weights in f32, layer by layer.
Matches the layout expected by qwen_ane_infer.h.

Usage:
    python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct /path/to/output.bin
"""
import struct
import sys
import numpy as np
from pathlib import Path
from safetensors import safe_open


def convert(model_dir: str, output_path: str):
    model_dir = Path(model_dir)

    # Load safetensors
    st_files = list(model_dir.glob("*.safetensors"))
    if not st_files:
        print(f"No safetensors files in {model_dir}")
        sys.exit(1)

    tensors = {}
    for f in st_files:
        with safe_open(str(f), framework="pt") as sf:
            for key in sf.keys():
                tensors[key] = sf.get_tensor(key).float().numpy()
    print(f"Loaded {len(tensors)} tensors from {len(st_files)} files")

    # Qwen2.5-0.5B config
    dim = 896
    hidden = 4864
    n_layers = 24
    n_heads = 14
    n_kv_heads = 2
    vocab_size = 151936
    max_seq = 512

    with open(output_path, "wb") as f:
        # Config header: 7 x int32
        f.write(struct.pack("iiiiiii",
                            dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq))

        # Embedding [vocab, dim]
        emb = tensors["model.embed_tokens.weight"].astype(np.float32)
        print(f"embed: {emb.shape}")
        f.write(emb.tobytes())

        # Per-layer weights
        for l in range(n_layers):
            prefix = f"model.layers.{l}"

            # Attention norm
            rms_att = tensors[f"{prefix}.input_layernorm.weight"].astype(np.float32)
            f.write(rms_att.tobytes())

            # Q, K, V, O projections
            wq = tensors[f"{prefix}.self_attn.q_proj.weight"].astype(np.float32)
            wk = tensors[f"{prefix}.self_attn.k_proj.weight"].astype(np.float32)
            wv = tensors[f"{prefix}.self_attn.v_proj.weight"].astype(np.float32)
            wo = tensors[f"{prefix}.self_attn.o_proj.weight"].astype(np.float32)
            f.write(wq.tobytes())
            f.write(wk.tobytes())
            f.write(wv.tobytes())
            f.write(wo.tobytes())

            # Q/K/V biases (Qwen has them); write zeros if a bias is absent
            qb = tensors.get(f"{prefix}.self_attn.q_proj.bias")
            kb = tensors.get(f"{prefix}.self_attn.k_proj.bias")
            vb = tensors.get(f"{prefix}.self_attn.v_proj.bias")
            f.write((qb if qb is not None else np.zeros(wq.shape[0])).astype(np.float32).tobytes())
            f.write((kb if kb is not None else np.zeros(wk.shape[0])).astype(np.float32).tobytes())
            f.write((vb if vb is not None else np.zeros(wv.shape[0])).astype(np.float32).tobytes())

            # FFN norm
            rms_ffn = tensors[f"{prefix}.post_attention_layernorm.weight"].astype(np.float32)
            f.write(rms_ffn.tobytes())

            # FFN: gate, up, down
            w_gate = tensors[f"{prefix}.mlp.gate_proj.weight"].astype(np.float32)
            w_up = tensors[f"{prefix}.mlp.up_proj.weight"].astype(np.float32)
            w_down = tensors[f"{prefix}.mlp.down_proj.weight"].astype(np.float32)
            f.write(w_gate.tobytes())
            f.write(w_up.tobytes())
            f.write(w_down.tobytes())

            print(f"  Layer {l}: Q{wq.shape} K{wk.shape} V{wv.shape} O{wo.shape} "
                  f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape}")

        # Final norm
        rms_final = tensors["model.norm.weight"].astype(np.float32)
        f.write(rms_final.tobytes())

    size_mb = Path(output_path).stat().st_size / 1024 / 1024
    print(f"\nWritten: {output_path} ({size_mb:.0f} MB)")


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python3 convert_weights.py <model_dir> <output.bin>")
        sys.exit(1)
    convert(sys.argv[1], sys.argv[2])

163
inference/main.m Normal file
@@ -0,0 +1,163 @@
// main.m: Qwen2.5-0.5B inference on Apple Neural Engine
// Compiles ANE kernels for all linear projections, runs autoregressive decode.
//
// Build:
// xcrun clang -O2 -framework Foundation -framework IOSurface \
// -framework CoreML -framework Accelerate -ldl -lobjc \
// -o qwen_ane main.m
//
// Run:
// ./qwen_ane qwen05b.bin "151644 8948 198" 20   (space-separated token IDs, optional max new tokens)
//
#import <Foundation/Foundation.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "qwen_ane_infer.h"
int g_fp16_io = 0;
static QwenModel g_model;
static int load_weights(const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
// Read config header
int config[7];
if (fread(config, sizeof(int), 7, f) != 7) {
fprintf(stderr, "Truncated config header in %s\n", path);
fclose(f);
return -1;
}
int dim = config[0], hidden = config[1], n_layers = config[2];
int n_heads = config[3], n_kv_heads = config[4], vocab = config[5];
printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d\n",
dim, hidden, n_layers, n_heads, n_kv_heads, vocab);
int q_dim = n_heads * QWEN_HEAD_DIM;
int kv_dim = n_kv_heads * QWEN_HEAD_DIM;
// Embedding
g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float));
fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f);
// Per-layer
for (int l = 0; l < n_layers; l++) {
g_model.rms_att[l] = (float*)malloc(dim * sizeof(float));
fread(g_model.rms_att[l], sizeof(float), dim, f);
g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f);
g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); // o_proj is [dim, q_dim]
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
// Q/K/V biases
g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float));
g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float));
g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float));
fread(g_model.q_bias[l], sizeof(float), q_dim, f);
fread(g_model.k_bias[l], sizeof(float), kv_dim, f);
fread(g_model.v_bias[l], sizeof(float), kv_dim, f);
g_model.rms_ffn[l] = (float*)malloc(dim * sizeof(float));
fread(g_model.rms_ffn[l], sizeof(float), dim, f);
g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f);
g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f);
g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float));
fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f);
}
g_model.rms_final = (float*)malloc(dim * sizeof(float));
fread(g_model.rms_final, sizeof(float), dim, f);
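long bytes_read = ftell(f); // query the offset before closing; ftell on a closed stream is undefined
fclose(f);
printf("Weights loaded (%.0f MB)\n",
(double)bytes_read / 1024 / 1024);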
return 0;
}
int main(int argc, char **argv) {
@autoreleasepool {
if (argc < 3) {
fprintf(stderr, "Usage: %s <weights.bin> \"<space-separated token IDs>\" [max_tokens]\n", argv[0]);
return 1;
}
printf("=== Qwen2.5-0.5B ANE Inference ===\n\n");
// Load weights
printf("Loading weights...\n");
if (load_weights(argv[1]) != 0) return 1;
// Allocate buffers
qwen_alloc(&g_model);
// Compile ANE kernels
printf("Compiling ANE kernels (169 total)...\n");
struct timespec t0, t1;
clock_gettime(CLOCK_MONOTONIC, &t0);
qwen_compile_kernels(&g_model);
clock_gettime(CLOCK_MONOTONIC, &t1);
double compile_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9;
printf("Compile time: %.1fs\n\n", compile_sec);
// Parse token IDs from argv[2] (space-separated)
// argv[3] = max generation tokens
int max_gen = 50;
if (argc >= 4) max_gen = atoi(argv[3]);
// Parse input token IDs (capped at QWEN_MAX_SEQ, the number of positions in the KV cache)
int prompt_ids[QWEN_MAX_SEQ];
int n_prompt = 0;
char *tok_str = strdup(argv[2]);
char *saveptr;
char *p = strtok_r(tok_str, " ", &saveptr);
while (p && n_prompt < QWEN_MAX_SEQ) {
prompt_ids[n_prompt++] = atoi(p);
p = strtok_r(NULL, " ", &saveptr);
}
free(tok_str);
printf("Prompt: %d tokens, generating up to %d\n", n_prompt, max_gen);
clock_gettime(CLOCK_MONOTONIC, &t0);
// Prefill: feed all prompt tokens
int next = 0;
for (int i = 0; i < n_prompt; i++) {
next = qwen_forward(&g_model, prompt_ids[i]);
}
struct timespec t_prefill;
clock_gettime(CLOCK_MONOTONIC, &t_prefill);
double prefill_sec = (t_prefill.tv_sec - t0.tv_sec) + (t_prefill.tv_nsec - t0.tv_nsec) / 1e9;
printf("Prefill: %d tokens in %.2fs (%.1f t/s)\n", n_prompt, prefill_sec, n_prompt / prefill_sec);
// Generate
int eos = 151645; // <|im_end|>
int eos2 = 151643; // <|endoftext|>
printf("OUT:");
// Stop before the KV cache (QWEN_MAX_SEQ positions) would overflow
for (int i = 0; i < max_gen && g_model.pos < QWEN_MAX_SEQ; i++) {
printf(" %d", next);
fflush(stdout);
if (next == eos || next == eos2) break;
next = qwen_forward(&g_model, next);
}
printf("\n");
clock_gettime(CLOCK_MONOTONIC, &t1);
double gen_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9;
int total_tokens = g_model.pos;
int gen_tokens = total_tokens - n_prompt;
double decode_sec = gen_sec - prefill_sec;
printf("\nTotal: %d tokens in %.2fs\n", total_tokens, gen_sec);
printf("Prefill: %.1f t/s (%d tokens)\n", n_prompt / prefill_sec, n_prompt);
printf("Decode: %.1f t/s (%d tokens)\n",
decode_sec > 0 ? gen_tokens / decode_sec : 0, gen_tokens);
return 0;
}
}

435
inference/qwen_ane_infer.h Normal file
@@ -0,0 +1,435 @@
// qwen_ane_infer.h — Qwen2.5-0.5B inference on Apple Neural Engine
// Linear projections on ANE (baked-weight conv kernels), CPU for element-wise ops.
// Based on maderix/ANE runtime + MIL generation.
#pragma once
#include "../training/ane_runtime.h"
#include "../training/ane_mil_gen.h"
// Compile a matmul kernel: W[out_ch, in_ch] @ x[in_ch] → y[out_ch]
// Uses the two-input matmul MIL variant (weights passed as input, not baked)
static ANEKernel *compile_matmul_kernel(int in_ch, int out_ch) {
NSString *mil = mil_gen_matmul(in_ch, out_ch, 1);
size_t inputSizes[2] = {(size_t)in_ch * 1 * 4, (size_t)out_ch * in_ch * 4};
size_t outBytes = (size_t)out_ch * 1 * 4;
return ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], nil, 2, inputSizes, 1, &outBytes);
}
// Compile a baked-weight conv kernel (from model.h)
static ANEKernel *compile_conv_kernel(const float *weights, int in_ch, int out_ch, int spatial) {
NSData *wb = mil_build_weight_blob(weights, out_ch, in_ch);
NSString *mil = mil_gen_conv(in_ch, out_ch, spatial);
size_t inBytes = (size_t)in_ch * spatial * 4;
size_t outBytes = (size_t)out_ch * spatial * 4;
return ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], wb, 1, &inBytes, 1, &outBytes);
}
#include <math.h>
#include <string.h>
#include <time.h>
// Qwen2.5-0.5B-Instruct architecture
#define QWEN_DIM 896
#define QWEN_HIDDEN 4864
#define QWEN_LAYERS 24
#define QWEN_HEADS 14
#define QWEN_KV_HEADS 2
#define QWEN_HEAD_DIM 64
#define QWEN_VOCAB 151936
#define QWEN_RMS_EPS 1e-6f
#define QWEN_ROPE_THETA 1000000.0f
#define QWEN_MAX_SEQ 512
// GQA: each KV head serves (HEADS / KV_HEADS) query heads
#define QWEN_GQA_FACTOR (QWEN_HEADS / QWEN_KV_HEADS)
// Sizes for GQA projections
#define QWEN_Q_DIM (QWEN_HEADS * QWEN_HEAD_DIM) // 896
#define QWEN_KV_DIM (QWEN_KV_HEADS * QWEN_HEAD_DIM) // 128
typedef struct {
// Weights (f32)
float *embed; // [vocab, dim]
float *rms_att[QWEN_LAYERS]; // [dim]
float *wq[QWEN_LAYERS]; // [q_dim, dim]
float *wk[QWEN_LAYERS]; // [kv_dim, dim]
float *wv[QWEN_LAYERS]; // [kv_dim, dim]
float *wo[QWEN_LAYERS]; // [dim, q_dim]
float *rms_ffn[QWEN_LAYERS]; // [dim]
float *w_gate[QWEN_LAYERS]; // [hidden, dim]
float *w_up[QWEN_LAYERS]; // [hidden, dim]
float *w_down[QWEN_LAYERS]; // [dim, hidden]
float *rms_final; // [dim]
// wcls = embed (tied)
// ANE kernels (one per linear projection per layer)
ANEKernel *k_q[QWEN_LAYERS];
ANEKernel *k_k[QWEN_LAYERS];
ANEKernel *k_v[QWEN_LAYERS];
ANEKernel *k_o[QWEN_LAYERS];
ANEKernel *k_gate[QWEN_LAYERS];
ANEKernel *k_up[QWEN_LAYERS];
ANEKernel *k_down[QWEN_LAYERS];
// LM head chunked: vocab too large for single ANE kernel (max 65536)
#define QWEN_LM_CHUNKS 16
#define QWEN_LM_CHUNK_SIZE 9496 // 151936 / 16
ANEKernel *k_lmhead[QWEN_LM_CHUNKS];
// Q/K/V biases per layer
float *q_bias[QWEN_LAYERS]; // [q_dim]
float *k_bias[QWEN_LAYERS]; // [kv_dim]
float *v_bias[QWEN_LAYERS]; // [kv_dim]
// KV cache [layer][kv_heads * head_dim * max_seq]
float *kv_cache_k[QWEN_LAYERS];
float *kv_cache_v[QWEN_LAYERS];
int pos; // current position in sequence
// Scratch buffers
float *x; // [dim]
float *xb; // [dim]
float *q; // [q_dim]
float *k; // [kv_dim]
float *v; // [kv_dim]
float *att; // [heads * max_seq]
float *hb; // [hidden]
float *hb2; // [hidden]
float *logits; // [vocab]
} QwenModel;
// ── CPU ops ──────────────────────────────────────────────────────────
static void qwen_rmsnorm(float *out, const float *x, const float *w, int D) {
float ss = 0;
for (int i = 0; i < D; i++) ss += x[i] * x[i];
ss = 1.0f / sqrtf(ss / D + QWEN_RMS_EPS);
for (int i = 0; i < D; i++) out[i] = x[i] * ss * w[i];
}
static void qwen_rope(float *q, float *k, int pos, int n_q_heads, int n_kv_heads, int head_dim) {
// Qwen uses rotate_half RoPE (NOT interleaved pairs):
// rotate_half(x) = [-x[dim/2:], x[:dim/2]]
// q_embed = q * cos + rotate_half(q) * sin
// cos/sin have shape [head_dim/2] and are applied to both halves
int half = head_dim / 2;
// Precompute cos/sin for this position (head_dim/2 frequencies)
float cos_v[half], sin_v[half];
for (int i = 0; i < half; i++) {
float freq = 1.0f / powf(QWEN_ROPE_THETA, (float)(2 * i) / head_dim);
float angle = pos * freq;
cos_v[i] = cosf(angle);
sin_v[i] = sinf(angle);
}
// Apply to Q heads
for (int h = 0; h < n_q_heads; h++) {
float *qh = q + h * head_dim;
for (int i = 0; i < half; i++) {
float q_first = qh[i];
float q_second = qh[i + half];
// rotate_half: [-q_second, q_first]
qh[i] = q_first * cos_v[i] + (-q_second) * sin_v[i];
qh[i + half] = q_second * cos_v[i] + q_first * sin_v[i];
}
}
// Apply to K heads
for (int h = 0; h < n_kv_heads; h++) {
float *kh = k + h * head_dim;
for (int i = 0; i < half; i++) {
float k_first = kh[i];
float k_second = kh[i + half];
kh[i] = k_first * cos_v[i] + (-k_second) * sin_v[i];
kh[i + half] = k_second * cos_v[i] + k_first * sin_v[i];
}
}
}
static void qwen_silu(float *x, int n) {
for (int i = 0; i < n; i++)
x[i] = x[i] / (1.0f + expf(-x[i]));
}
// ── ANE projection helper (single token: spatial=1) ─────────────────
static void ane_project(ANEKernel *kernel, const float *in, float *out,
int in_dim, int out_dim) {
// For single-token inference: spatial=1
ane_write_input(kernel, 0, in, in_dim * sizeof(float));
ane_eval(kernel);
ane_read_output(kernel, 0, out, out_dim * sizeof(float));
}
// CPU matmul via Accelerate BLAS: y = W @ x, W[out_dim, in_dim]
#include <Accelerate/Accelerate.h>
static void cpu_project(const float *W, const float *x, float *y, int in_dim, int out_dim) {
// y = W @ x where W is [out_dim, in_dim] row-major
// cblas_sgemv: y = alpha * A * x + beta * y
cblas_sgemv(CblasRowMajor, CblasNoTrans,
out_dim, in_dim,
1.0f, W, in_dim,
x, 1,
0.0f, y, 1);
}
// Toggle: 1 = use ANE for projections, 0 = CPU fallback
#define USE_ANE_PROJECTIONS 0
// ── Forward one token ────────────────────────────────────────────────
static int qwen_forward(QwenModel *m, int token) {
int D = QWEN_DIM, HD = QWEN_HIDDEN;
int pos = m->pos;
// Token embedding
memcpy(m->x, m->embed + token * D, D * sizeof(float));
for (int l = 0; l < QWEN_LAYERS; l++) {
// Attention RMSNorm
qwen_rmsnorm(m->xb, m->x, m->rms_att[l], D);
// Debug: print first layer input/output norms
if (l == 0 && pos == 0) {
float xnorm = 0, qnorm = 0;
for (int i = 0; i < D; i++) xnorm += m->xb[i] * m->xb[i];
printf(" L0 RMSNorm out norm=%.4f (first 4: %.4f %.4f %.4f %.4f)\n",
sqrtf(xnorm), m->xb[0], m->xb[1], m->xb[2], m->xb[3]);
}
// QKV projections (ANE) + bias
#if USE_ANE_PROJECTIONS
ane_project(m->k_q[l], m->xb, m->q, D, QWEN_Q_DIM);
ane_project(m->k_k[l], m->xb, m->k, D, QWEN_KV_DIM);
ane_project(m->k_v[l], m->xb, m->v, D, QWEN_KV_DIM);
#else
cpu_project(m->wq[l], m->xb, m->q, D, QWEN_Q_DIM);
cpu_project(m->wk[l], m->xb, m->k, D, QWEN_KV_DIM);
cpu_project(m->wv[l], m->xb, m->v, D, QWEN_KV_DIM);
#endif
// Apply Q/K/V biases
if (m->q_bias[l]) {
for (int i = 0; i < QWEN_Q_DIM; i++) m->q[i] += m->q_bias[l][i];
}
if (m->k_bias[l]) {
for (int i = 0; i < QWEN_KV_DIM; i++) m->k[i] += m->k_bias[l][i];
}
if (m->v_bias[l]) {
for (int i = 0; i < QWEN_KV_DIM; i++) m->v[i] += m->v_bias[l][i];
}
if (l == 0 && pos == 0) {
float qn = 0;
for (int i = 0; i < QWEN_Q_DIM; i++) qn += m->q[i] * m->q[i];
printf(" L0 ANE Q norm=%.4f (first 4: %.4f %.4f %.4f %.4f)\n",
sqrtf(qn), m->q[0], m->q[1], m->q[2], m->q[3]);
// CPU reference
float cpu_q[4] = {0};
for (int i = 0; i < 4; i++) {
for (int j = 0; j < D; j++)
cpu_q[i] += m->wq[0][i * D + j] * m->xb[j];
cpu_q[i] += m->q_bias[0][i];
}
printf(" L0 CPU Q first 4: %.4f %.4f %.4f %.4f\n",
cpu_q[0], cpu_q[1], cpu_q[2], cpu_q[3]);
}
// RoPE
qwen_rope(m->q, m->k, pos, QWEN_HEADS, QWEN_KV_HEADS, QWEN_HEAD_DIM);
// Store K, V in cache
memcpy(m->kv_cache_k[l] + pos * QWEN_KV_DIM,
m->k, QWEN_KV_DIM * sizeof(float));
memcpy(m->kv_cache_v[l] + pos * QWEN_KV_DIM,
m->v, QWEN_KV_DIM * sizeof(float));
// GQA attention (CPU — element-wise ops)
float scale = 1.0f / sqrtf((float)QWEN_HEAD_DIM);
float *attn_out = m->xb; // reuse buffer
memset(attn_out, 0, QWEN_Q_DIM * sizeof(float));
for (int h = 0; h < QWEN_HEADS; h++) {
int kv_h = h / QWEN_GQA_FACTOR;
float *qh = m->q + h * QWEN_HEAD_DIM;
// Attention scores: Q @ K^T for all positions up to pos
float max_score = -1e9f;
for (int t = 0; t <= pos; t++) {
float *kt = m->kv_cache_k[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM;
// Use BLAS dot product for precision
float score = cblas_sdot(QWEN_HEAD_DIM, qh, 1, kt, 1);
m->att[h * QWEN_MAX_SEQ + t] = score * scale;
if (score * scale > max_score) max_score = score * scale;
}
// Softmax (double accumulation for precision)
double sum = 0;
for (int t = 0; t <= pos; t++) {
m->att[h * QWEN_MAX_SEQ + t] = expf(m->att[h * QWEN_MAX_SEQ + t] - max_score);
sum += (double)m->att[h * QWEN_MAX_SEQ + t];
}
float inv_sum = (float)(1.0 / sum);
for (int t = 0; t <= pos; t++)
m->att[h * QWEN_MAX_SEQ + t] *= inv_sum;
// Weighted sum of V: attn_out[h] += att[t] * V[t] for each t
for (int t = 0; t <= pos; t++) {
float a = m->att[h * QWEN_MAX_SEQ + t];
float *vt = m->kv_cache_v[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM;
cblas_saxpy(QWEN_HEAD_DIM, a, vt, 1,
attn_out + h * QWEN_HEAD_DIM, 1);
}
}
float o_out[QWEN_DIM];
#if USE_ANE_PROJECTIONS
ane_project(m->k_o[l], attn_out, o_out, QWEN_Q_DIM, D);
#else
cpu_project(m->wo[l], attn_out, o_out, QWEN_Q_DIM, D);
#endif
// Residual
for (int i = 0; i < D; i++) m->x[i] += o_out[i];
if (l == 0 && pos == 0) {
float pan = 0;
for (int i = 0; i < D; i++) pan += m->x[i] * m->x[i];
printf(" L0 post-attn norm=%.4f first4=[%.6f, %.6f, %.6f, %.6f]\n",
sqrtf(pan), m->x[0], m->x[1], m->x[2], m->x[3]);
float on = 0;
for (int i = 0; i < D; i++) on += o_out[i] * o_out[i];
printf(" L0 o_proj out norm=%.4f first4=[%.6f, %.6f, %.6f, %.6f]\n",
sqrtf(on), o_out[0], o_out[1], o_out[2], o_out[3]);
}
// FFN RMSNorm
qwen_rmsnorm(m->xb, m->x, m->rms_ffn[l], D);
// SwiGLU FFN
#if USE_ANE_PROJECTIONS
ane_project(m->k_gate[l], m->xb, m->hb, D, HD);
ane_project(m->k_up[l], m->xb, m->hb2, D, HD);
#else
cpu_project(m->w_gate[l], m->xb, m->hb, D, HD);
cpu_project(m->w_up[l], m->xb, m->hb2, D, HD);
#endif
if (l == 0 && pos == 0) {
float gn = 0, un = 0;
for (int i = 0; i < HD; i++) { gn += m->hb[i]*m->hb[i]; un += m->hb2[i]*m->hb2[i]; }
printf(" L0 gate norm=%.4f up norm=%.4f\n", sqrtf(gn), sqrtf(un));
printf(" L0 gate first4=[%.6f, %.6f, %.6f, %.6f]\n",
m->hb[0], m->hb[1], m->hb[2], m->hb[3]);
}
qwen_silu(m->hb, HD);
for (int i = 0; i < HD; i++) m->hb[i] *= m->hb2[i];
float ffn_out[QWEN_DIM];
#if USE_ANE_PROJECTIONS
ane_project(m->k_down[l], m->hb, ffn_out, HD, D);
#else
cpu_project(m->w_down[l], m->hb, ffn_out, HD, D);
#endif
// Residual
for (int i = 0; i < D; i++) m->x[i] += ffn_out[i];
// Debug: hidden state after each layer (first 3 layers, first token only)
if (l < 3 && pos == 0) {
float hn = 0;
for (int i = 0; i < D; i++) hn += m->x[i] * m->x[i];
printf(" C hidden[%d] norm=%.4f first4=[%.4f, %.4f, %.4f, %.4f]\n",
l+1, sqrtf(hn), m->x[0], m->x[1], m->x[2], m->x[3]);
}
}
// Final RMSNorm
qwen_rmsnorm(m->xb, m->x, m->rms_final, D);
// Debug: check final hidden state before LM head
if (m->pos < 2) {
float fn = 0;
for (int i = 0; i < D; i++) fn += m->xb[i] * m->xb[i];
printf(" Final hidden norm=%.4f (first 4: %.6f %.6f %.6f %.6f)\n",
sqrtf(fn), m->xb[0], m->xb[1], m->xb[2], m->xb[3]);
}
// LM head via Accelerate BLAS: logits = embed @ xb
// embed is [vocab, dim] row-major
cblas_sgemv(CblasRowMajor, CblasNoTrans,
QWEN_VOCAB, D,
1.0f, m->embed, D,
m->xb, 1,
0.0f, m->logits, 1);
// Debug: check logits
if (m->pos < 2) {
float lmax = m->logits[0], lmin = m->logits[0];
int nonzero = 0;
for (int i = 0; i < QWEN_VOCAB; i++) {
if (m->logits[i] > lmax) lmax = m->logits[i];
if (m->logits[i] < lmin) lmin = m->logits[i];
if (m->logits[i] != 0.0f) nonzero++;
}
printf(" Logits: min=%.4f max=%.4f nonzero=%d/%d\n", lmin, lmax, nonzero, QWEN_VOCAB);
}
m->pos++;
// Argmax
int max_idx = 0;
float max_val = m->logits[0];
for (int i = 1; i < QWEN_VOCAB; i++) {
if (m->logits[i] > max_val) {
max_val = m->logits[i];
max_idx = i;
}
}
return max_idx;
}
// ── Compile all ANE kernels ──────────────────────────────────────────
static void qwen_compile_kernels(QwenModel *m) {
int D = QWEN_DIM, HD = QWEN_HIDDEN;
// 7 projection kernels per layer plus one kernel per LM head chunk
printf("Compiling %d ANE kernels...\n", QWEN_LAYERS * 7 + QWEN_LM_CHUNKS);
for (int l = 0; l < QWEN_LAYERS; l++) {
m->k_q[l] = compile_conv_kernel(m->wq[l], D, QWEN_Q_DIM, 1);
m->k_k[l] = compile_conv_kernel(m->wk[l], D, QWEN_KV_DIM, 1);
m->k_v[l] = compile_conv_kernel(m->wv[l], D, QWEN_KV_DIM, 1);
m->k_o[l] = compile_conv_kernel(m->wo[l], QWEN_Q_DIM, D, 1);
m->k_gate[l] = compile_conv_kernel(m->w_gate[l], D, HD, 1);
m->k_up[l] = compile_conv_kernel(m->w_up[l], D, HD, 1);
m->k_down[l] = compile_conv_kernel(m->w_down[l], HD, D, 1);
printf(" Layer %d/%d compiled\r", l+1, QWEN_LAYERS);
fflush(stdout);
}
// LM head (tied = embedding, chunked into 16 pieces)
for (int c = 0; c < QWEN_LM_CHUNKS; c++) {
float *chunk_weights = m->embed + c * QWEN_LM_CHUNK_SIZE * D;
m->k_lmhead[c] = compile_conv_kernel(chunk_weights, D, QWEN_LM_CHUNK_SIZE, 1);
if (!m->k_lmhead[c]) {
printf(" LM head chunk %d FAILED to compile\n", c);
}
}
printf("\nAll kernels compiled.\n");
}
// ── Allocate buffers ─────────────────────────────────────────────────
static void qwen_alloc(QwenModel *m) {
m->x = (float*)calloc(QWEN_DIM, sizeof(float));
m->xb = (float*)calloc(QWEN_DIM, sizeof(float));
m->q = (float*)calloc(QWEN_Q_DIM, sizeof(float));
m->k = (float*)calloc(QWEN_KV_DIM, sizeof(float));
m->v = (float*)calloc(QWEN_KV_DIM, sizeof(float));
m->att = (float*)calloc(QWEN_HEADS * QWEN_MAX_SEQ, sizeof(float));
m->hb = (float*)calloc(QWEN_HIDDEN, sizeof(float));
m->hb2 = (float*)calloc(QWEN_HIDDEN, sizeof(float));
m->logits = (float*)calloc(QWEN_VOCAB, sizeof(float));
for (int l = 0; l < QWEN_LAYERS; l++) {
m->kv_cache_k[l] = (float*)calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float));
m->kv_cache_v[l] = (float*)calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float));
}
m->pos = 0;
}

74
inference/run.py Normal file
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
"""Run Qwen2.5-0.5B on ANE with proper tokenization.
Usage:
python3 run.py "Your prompt here" [--max-tokens 50]
"""
import argparse
import ctypes
import struct
import sys
import time
from pathlib import Path
INFERENCE_DIR = Path(__file__).parent
WEIGHTS_PATH = INFERENCE_DIR / "qwen05b.bin"
MODEL_DIR = Path.home() / "models" / "Qwen2.5-0.5B-Instruct"
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--max-tokens", type=int, default=50)
    args = parser.parse_args()
    from transformers import AutoTokenizer
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(str(MODEL_DIR), trust_remote_code=True)
    # Build chat template
    messages = [
        {"role": "system", "content": "You are a helpful assistant. Be concise."},
        {"role": "user", "content": args.prompt},
    ]
    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tok.encode(text)
    print(f"Prompt tokens: {len(input_ids)}")
    # Run the C binary, passing the token IDs as a command-line argument
    import subprocess
    binary = str(INFERENCE_DIR / "qwen_ane")
    # We need to modify the binary to accept token IDs as input
    # For now, print the token IDs so we can verify tokenization
    print(f"First 10 tokens: {input_ids[:10]}")
    print(f"Token text: {[tok.decode([t]) for t in input_ids[:10]]}")
    print(f"\nRunning ANE inference with {len(input_ids)} prompt tokens + {args.max_tokens} generation...")
    # Call the binary with the token IDs joined into one space-separated argument
    result = subprocess.run(
        [binary, str(WEIGHTS_PATH), " ".join(str(t) for t in input_ids),
         str(args.max_tokens)],
        capture_output=True, text=True, timeout=120,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr[:500], file=sys.stderr)
    # Parse output token IDs from binary stdout
    output_ids = []
    for line in result.stdout.split("\n"):
        if line.startswith("OUT:"):
            ids = [int(x) for x in line[4:].split() if x.isdigit()]
            output_ids.extend(ids)
    if output_ids:
        decoded = tok.decode(output_ids, skip_special_tokens=True)
        print(f"\n=== Response ===\n{decoded}")
    else:
        print("\n(No output tokens parsed; binary may need token ID input mode)")
if __name__ == "__main__":
    main()
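run.py hands the prompt to the binary as a single space-separated string of token IDs, and its comments note that the binary still needs a token ID input mode. A minimal sketch of what the C side of that contract might look like; `parse_token_ids` is a hypothetical helper, not a function from this commit.

```c
#include <ctype.h>
#include <stdlib.h>

/* Parse a whitespace-separated list of non-negative decimal token IDs.
 * Returns the number of IDs written to `out`, or -1 if the string is
 * malformed, contains a negative value, or holds more than `max_out` IDs. */
static int parse_token_ids(const char *s, int *out, int max_out) {
    int n = 0;
    while (*s) {
        char *end;
        long v = strtol(s, &end, 10);
        if (end == s) {
            /* Nothing converted: only trailing whitespace is acceptable. */
            while (isspace((unsigned char)*s)) s++;
            return *s ? -1 : n;
        }
        if (v < 0 || n >= max_out) return -1;
        out[n++] = (int)v;
        s = end;
    }
    return n;
}
```

On the output side, the script only looks for lines starting with `OUT:` followed by decimal IDs, so a binary using this parser would print its generated IDs in that form.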


@ -21,6 +21,14 @@ train_large: train_large.m $(HEADERS_LARGE)
train_large_ane: train_large_ane.m $(HEADERS_ANE)
$(CC) $(CFLAGS) -o $@ train_large_ane.m $(LDFLAGS) -framework Accelerate
HEADERS_OPT = $(HEADERS_LARGE) stories_cpu_ops_opt.h
train_opt: train_opt.m $(HEADERS_OPT)
$(CC) $(CFLAGS) -o $@ train_opt.m $(LDFLAGS) -framework Accelerate -framework Metal -framework MetalPerformanceShaders
train_double_buffer: train_double_buffer.m $(HEADERS_LARGE)
$(CC) $(CFLAGS) -o $@ train_double_buffer.m $(LDFLAGS) -framework Accelerate
PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced
test_rmsnorm_bwd: test_rmsnorm_bwd.m $(HEADERS_ANE)
@ -65,7 +73,7 @@ verify-flags:
@xcrun clang --version
clean:
rm -f train train_large train_large_ane $(PROBES) test_rmsnorm_bwd test_classifier
rm -f train train_large train_large_ane train_opt train_double_buffer $(PROBES) test_rmsnorm_bwd test_classifier
.PHONY: clean tokenize probes verify-flags data setup


@ -0,0 +1,110 @@
// stories_cpu_ops_opt.h — Optimized CPU operations: NEON Adam, vectorized embedding
#pragma once
#include "stories_cpu_ops.h"
#include <arm_neon.h>
// ===== NEON-vectorized Adam optimizer =====
// ~3-3.5x faster than scalar version for large param counts
// Uses vrsqrteq_f32 + one Newton-Raphson step for fast reciprocal sqrt
static void adam_update_opt(float *w, const float *g, AdamState *s, int t,
float lr, float b1, float b2, float eps) {
float bc1 = 1.0f - powf(b1, t);
float bc2 = 1.0f - powf(b2, t);
float inv_bc1 = 1.0f / bc1;
float inv_bc2 = 1.0f / bc2;
float one_minus_b1 = 1.0f - b1;
float one_minus_b2 = 1.0f - b2;
float32x4_t vb1 = vdupq_n_f32(b1);
float32x4_t vb2 = vdupq_n_f32(b2);
float32x4_t v1mb1 = vdupq_n_f32(one_minus_b1);
float32x4_t v1mb2 = vdupq_n_f32(one_minus_b2);
float32x4_t vinv_bc1 = vdupq_n_f32(inv_bc1);
float32x4_t vinv_bc2 = vdupq_n_f32(inv_bc2);
float32x4_t vneg_lr = vdupq_n_f32(-lr);
float32x4_t veps = vdupq_n_f32(eps);
size_t n = s->n;
size_t i = 0;
// Process 4 elements at a time
for (; i + 3 < n; i += 4) {
// Load
float32x4_t vm = vld1q_f32(s->m + i);
float32x4_t vv = vld1q_f32(s->v + i);
float32x4_t vg = vld1q_f32(g + i);
float32x4_t vw = vld1q_f32(w + i);
// m = b1*m + (1-b1)*g
vm = vmlaq_f32(vmulq_f32(vb1, vm), v1mb1, vg);
// v = b2*v + (1-b2)*g*g
float32x4_t g2 = vmulq_f32(vg, vg);
vv = vmlaq_f32(vmulq_f32(vb2, vv), v1mb2, g2);
// Store updated m, v
vst1q_f32(s->m + i, vm);
vst1q_f32(s->v + i, vv);
// mhat = m / bc1, vhat = v / bc2
float32x4_t mhat = vmulq_f32(vm, vinv_bc1);
float32x4_t vhat = vmulq_f32(vv, vinv_bc2);
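// Fast reciprocal sqrt: vrsqrteq + one Newton-Raphson iteration
// rsqrt_est ≈ 1/sqrt(vhat); the input is clamped so vhat == 0 yields a finite estimate
// (otherwise rsqrt_est = inf and sqrt_vhat = 0 * inf = NaN below)
float32x4_t rsqrt_est = vrsqrteq_f32(vmaxq_f32(vhat, vdupq_n_f32(1e-30f)));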
// Newton-Raphson: rsqrt *= (3 - vhat * rsqrt^2) / 2
float32x4_t rsqrt_sq = vmulq_f32(rsqrt_est, rsqrt_est);
float32x4_t nr_step = vrsqrtsq_f32(vhat, rsqrt_sq);
rsqrt_est = vmulq_f32(rsqrt_est, nr_step);
// w -= lr * mhat / (sqrt(vhat) + eps)
// = w + (-lr) * mhat * (1/(sqrt(vhat) + eps))
// Compute sqrt(vhat) from rsqrt: sqrt = vhat * rsqrt(vhat) (avoids division)
float32x4_t sqrt_vhat = vmulq_f32(vhat, rsqrt_est);
float32x4_t denom = vaddq_f32(sqrt_vhat, veps);
// Use vdivq_f32 for the final division (accurate, eps-adjusted)
float32x4_t update = vmulq_f32(vneg_lr, vdivq_f32(mhat, denom));
vw = vaddq_f32(vw, update);
vst1q_f32(w + i, vw);
}
// Scalar tail
for (; i < n; i++) {
s->m[i] = b1 * s->m[i] + one_minus_b1 * g[i];
s->v[i] = b2 * s->v[i] + one_minus_b2 * g[i] * g[i];
float mh = s->m[i] * inv_bc1;
float vh = s->v[i] * inv_bc2;
w[i] -= lr * mh / (sqrtf(vh) + eps);
}
}
// ===== Vectorized embedding lookup =====
// Gather rows from [VOCAB, DIM] row-major embed table → x [DIM, SEQ] channel-first
// Strategy: gather token rows into temp buffer [SEQ, DIM], then transpose via vDSP_mtrans
static void embed_lookup_opt(float *x, const float *embed, const uint16_t *tokens,
int dim, int seq, float *tmp) {
// Gather: tmp[t*dim + d] = embed[tokens[t]*dim + d]
for (int t = 0; t < seq; t++) {
memcpy(tmp + t * dim, embed + tokens[t] * dim, dim * sizeof(float));
}
// Transpose [SEQ, DIM] → [DIM, SEQ]: x[d*seq + t] = tmp[t*dim + d]
vDSP_mtrans(tmp, 1, x, 1, (vDSP_Length)dim, (vDSP_Length)seq);
}
// ===== Vectorized embedding backward =====
// Accumulate dE[tok] += dx[:,t] for each position
// Strategy: transpose dx [DIM, SEQ] → tmp [SEQ, DIM], then accumulate rows
static void embed_backward_opt(float *d_embed, const float *dx, const uint16_t *tokens,
int dim, int seq, float *tmp) {
// Transpose [DIM, SEQ] → [SEQ, DIM]: tmp[t*dim + d] = dx[d*seq + t]
vDSP_mtrans(dx, 1, tmp, 1, (vDSP_Length)seq, (vDSP_Length)dim);
// Scatter-add: d_embed[tok*dim .. (tok+1)*dim] += tmp[t*dim .. (t+1)*dim]
for (int t = 0; t < seq; t++) {
vDSP_vadd(tmp + t * dim, 1,
d_embed + tokens[t] * dim, 1,
d_embed + tokens[t] * dim, 1,
(vDSP_Length)dim);
}
}
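The Adam kernel above obtains `sqrt(vhat)` without a square-root instruction: it refines a rough reciprocal-sqrt estimate with one Newton-Raphson step and then multiplies by `vhat`. A portable scalar sketch of that arithmetic, with hypothetical helper names and the initial estimate passed in by the caller (on NEON it comes from `vrsqrteq_f32`):

```c
/* One Newton-Raphson step for y ~ 1/sqrt(x): y' = y * (3 - x*y*y) / 2.
 * This is the scalar form of rsqrt_est * vrsqrtsq_f32(x, rsqrt_est^2). */
static float rsqrt_nr_step(float x, float y) {
    return y * (3.0f - x * y * y) * 0.5f;
}

/* sqrt(x) recovered as x * (1/sqrt(x)), given a rough estimate y0 of 1/sqrt(x).
 * Assumes x > 0 and a finite y0; x == 0 with y0 == inf would produce NaN. */
static float sqrt_via_rsqrt(float x, float y0) {
    return x * rsqrt_nr_step(x, y0);
}
```

Each step roughly squares the relative error of the estimate, which is why a single iteration is often considered enough for an optimizer denominator that also has `eps` added to it.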


@ -85,6 +85,12 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, data, channels * sp);
IOSurfaceUnlock(s, 0, NULL);
}
// Read raw fp16 from IOSurface without conversion (for fp16 activation cache)
static void io_read_raw_fp16(IOSurfaceRef s, _Float16 *data, int ch_off, int channels, int sp) {
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
memcpy(data, (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, channels * sp * sizeof(_Float16));
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
// Kernel compile/eval
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {


@ -0,0 +1,791 @@
// train_double_buffer.m: Double-buffered async ANE training for stories110M
// Based on train_large.m with the key innovation: compile and eval overlap via GCD
// Discovery: probe_v2.m proved ANE compile and eval can run in parallel
// Architecture: two kernel sets (A/B), background compile while active set runs
// 5 weight-bearing ANE kernels per layer × 12 layers = 60 per compile batch
#include <stdatomic.h>
#include "stories_io.h"
#include "stories_mil.h"
#include "stories_cpu_ops.h"
// Double-buffer needs more compile budget than single-buffer
// The original MAX_COMPILES=100 only allows 1 batch per exec() restart
// We push higher to allow initial compile + at least 1 background compile
// If ANE rejects at ~119, the exec() restart will handle it gracefully
#define DB_MAX_COMPILES 250
#define CKPT_PATH_DEFAULT "ane_db_ckpt.bin"
#define MODEL_PATH_DEFAULT "../../assets/models/stories110M.bin"
#define DATA_PATH_DEFAULT "tinystories_data00.bin"
static const char *get_path(const char *env_var, const char *default_val) {
const char *v = getenv(env_var);
return (v && v[0]) ? v : default_val;
}
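The A/B scheme described in the header comment of this file reduces to two pointers and one atomic flag. A minimal portable sketch of the swap protocol, assuming C11 atomics and leaving out GCD and the ANE calls entirely; `KernelSet`, `DoubleBuffer`, and the `db_*` functions are hypothetical names used only for illustration.

```c
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical stand-in for one full set of compiled kernels. */
typedef struct { int version; } KernelSet;

typedef struct {
    KernelSet *active;          /* used by the training loop for evals */
    KernelSet *pending;         /* being rebuilt by the background compiler */
    atomic_bool pending_ready;  /* set once `pending` is fully built */
} DoubleBuffer;

static void db_init(DoubleBuffer *db, KernelSet *a, KernelSet *b) {
    db->active = a;
    db->pending = b;
    atomic_init(&db->pending_ready, false);
}

/* Background side: call after the pending set is completely compiled.
 * The default sequentially consistent store publishes the earlier writes. */
static void db_publish(DoubleBuffer *db) {
    atomic_store(&db->pending_ready, true);
}

/* Training side: call between batches. Swaps only when a finished set
 * is waiting, so the loop never blocks on the compiler. */
static bool db_try_swap(DoubleBuffer *db) {
    if (!atomic_load(&db->pending_ready)) return false;
    KernelSet *old = db->active;
    db->active = db->pending;
    db->pending = old;
    atomic_store(&db->pending_ready, false);
    return true;
}
```

The old active set becomes the next compile target, so a new background compile should start only after the swap, once nothing is still evaluating against it.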
// ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { printf("Cannot open %s\n", path); return false; }
Llama2Config cfg;
if (fread(&cfg, sizeof(cfg), 1, f) != 1) { printf("Cannot read config from %s\n", path); fclose(f); return false; }
printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d\n", DIM, HIDDEN, NLAYERS);
fclose(f); return false;
}
int V = abs(cfg.vocab_size);
bool shared = cfg.vocab_size > 0;
// Read in llama2.c order: embed, rms_att[all], wq[all], wk[all], wv[all], wo[all],
// rms_ffn[all], w1[all], w2[all], w3[all], rms_final, [wcls]
fread(embed, 4, V * DIM, f);
// rms_att weights for all layers (contiguous)
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
// wq for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
// wk for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
// wv for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
// wo for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
// rms_ffn weights for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
// w1 for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
// w2 for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
// w3 for all layers
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
// rms_final
fread(rms_final, 4, DIM, f);
// wcls = embed if shared (we just use embed pointer)
fclose(f);
printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
return true;
}
// ===== Compile one layer's kernels =====
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
@"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
@"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
@"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
@"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
@"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
}), DIM*SEQ*2, 6*DIM*SEQ*2);
lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
@"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
@"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
@"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
@"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
}), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);
lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
@"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
@"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
@"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
}), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);
lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
@"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
}), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
@"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
@"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
@"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
}), 3*DIM*SEQ*2, DIM*SEQ*2);
return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
}
// Compile weight-free sdpaBwd2 (only needs once, no weights)
static Kern *compile_sdpa_bwd2(void) {
return compile_kern_mil_w(gen_sdpa_bwd2(), @{},
(2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
}
static void free_layer_kernels(LayerKernels *lk) {
free_kern(lk->fwdAttn); free_kern(lk->fwdFFN); free_kern(lk->ffnBwd);
free_kern(lk->sdpaBwd1); free_kern(lk->qkvBwd);
// sdpaBwd2 is shared, freed separately
lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
}
// ===== Checkpoint save/load =====
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
double cc, double ct, double cw, int cs, int cb, int adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "wb");
if (!f) { printf("Cannot write checkpoint %s\n", path); return; }
CkptHdr h = {0};
h.magic = 0x424C5A54; h.version = 2;
h.step = step; h.total_steps = total_steps;
h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
h.lr = lr; h.loss = loss;
h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;
fwrite(&h, sizeof(h), 1, f);
// Per-layer weights + adam
for (int L = 0; L < NLAYERS; L++) {
fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
// Adam state
fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
}
fwrite(rms_final,4,DIM,f);
fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
fwrite(embed,4,VOCAB*DIM,f);
fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);
fclose(f);
}
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "rb");
if (!f) return false;
CkptHdr h;
if (fread(&h, sizeof(h), 1, f) != 1 || h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; }
*step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
*cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
*cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
for (int L = 0; L < NLAYERS; L++) {
fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
}
fread(rms_final,4,DIM,f);
fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
fread(embed,4,VOCAB*DIM,f);
fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f);
fclose(f);
return true;
}
// ===== Main =====
int main(int argc, char *argv[]) {
@autoreleasepool {
setbuf(stdout, NULL);
ane_init();
mach_timebase_info(&g_tb);
int total_steps = 10000;
float lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
const char *model_path = get_path("ANE_MODEL_PATH", MODEL_PATH_DEFAULT);
const char *ckpt_path = get_path("ANE_CKPT_PATH", CKPT_PATH_DEFAULT);
const char *data_path = get_path("ANE_DATA_PATH", DATA_PATH_DEFAULT);
bool do_resume = false;
for (int i=1; i<argc; i++) {
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
}
// Allocate per-layer state
LayerWeights lw[NLAYERS];
LayerAdam la[NLAYERS];
LayerActs acts[NLAYERS];
LayerGrads grads[NLAYERS];
// Double-buffer: two sets of kernels
LayerKernels kern_A[NLAYERS], kern_B[NLAYERS];
LayerKernels *kern_active = kern_A; // currently running evals
LayerKernels *kern_pending = kern_B; // being compiled in background
static _Atomic bool pending_ready = false; // signal: pending compile done
static _Atomic bool bg_compile_running = false;
dispatch_queue_t compile_q = dispatch_queue_create("ane.compile.bg", DISPATCH_QUEUE_SERIAL);
// Legacy alias for code that uses kern[L]
#define kern kern_active
for (int L=0; L<NLAYERS; L++) {
lw[L] = layer_weights_alloc();
la[L] = layer_adam_alloc();
acts[L] = layer_acts_alloc();
grads[L] = layer_grads_alloc();
memset(&kern_A[L], 0, sizeof(LayerKernels));
memset(&kern_B[L], 0, sizeof(LayerKernels));
}
// Final RMSNorm + embedding + classifier
float *rms_final = (float*)malloc(DIM*4);
float *embed = (float*)malloc(VOCAB*DIM*4); // [VOCAB, DIM] row-major
float *grms_final = (float*)calloc(DIM, 4);
float *gembed = (float*)calloc(VOCAB*DIM, 4);
AdamState arms_final = adam_alloc(DIM);
AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
double cum_compile=0, cum_train=0, cum_wall=0;
int cum_steps=0, cum_batches=0;
float resume_loss = 0;
bool resuming = false;
if (do_resume) {
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
}
if (!resuming) {
printf("=== ANE Training: Stories110M (12 layers) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
if (!load_pretrained(lw, rms_final, embed, model_path)) {
printf("Pretrained load failed, using random init\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
for (int L=0; L<NLAYERS; L++) {
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
}
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
float escale = 0.02f;
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
}
size_t tp = (size_t)NLAYERS*LAYER_PARAMS + DIM + (size_t)VOCAB*DIM;
double xfmr_params = (double)NLAYERS*LAYER_PARAMS;
double embed_params = (double)VOCAB*DIM;
printf("Params: %.2fM (transformer %.2fM + embed %.2fM)\n", tp/1e6, xfmr_params/1e6, embed_params/1e6);
printf("Kernels: %d (%d weight-bearing + %d static sdpaBwd2)\n",
TOTAL_WEIGHT_KERNELS+NLAYERS, TOTAL_WEIGHT_KERNELS, NLAYERS);
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
double fwd_f = NLAYERS*(4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
double bwd_dx_f = fwd_f, bwd_dw_f = fwd_f;
double sdpa_f = NLAYERS*2.0*HEADS*5*SEQ*SEQ*HD;
double cls_f = 2.0*VOCAB*DIM*SEQ;
double total_f = fwd_f + bwd_dx_f + bwd_dw_f + sdpa_f + cls_f*3;
double ane_f = fwd_f + bwd_dx_f + sdpa_f;
printf("FLOPs/step: fwd=%.0fM bwd_dx=%.0fM bwd_dW=%.0fM sdpa_bwd=%.0fM total=%.0fM\n",
fwd_f/1e6, bwd_dx_f/1e6, bwd_dw_f/1e6, sdpa_f/1e6, total_f/1e6);
printf("ANE FLOPs/step: %.0fM (fwd+bwd_dx+sdpa_bwd) | CPU: dW+cls (cblas)\n\n", ane_f/1e6);
}
// mmap token data (or generate synthetic if not available)
uint16_t *token_data = NULL;
size_t n_tokens = 0;
size_t data_len = 0;
bool synthetic_data = false;
int data_fd = open(data_path, O_RDONLY);
if (data_fd >= 0) {
struct stat st; fstat(data_fd, &st);
data_len = st.st_size;
token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
n_tokens = data_len / 2;
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
} else {
// Synthetic data for double-buffer benchmark
synthetic_data = true;
n_tokens = 100000;
data_len = n_tokens * 2;
token_data = (uint16_t*)malloc(data_len);
srand48(123);
for (size_t i = 0; i < n_tokens; i++)
token_data[i] = (uint16_t)(drand48() * (VOCAB - 1));
printf("[DB] Using synthetic data: %zu tokens (benchmark mode)\n", n_tokens);
}
// Gradient buffers shared across layers (reused each step)
float *dy = (float*)malloc(SEQ*DIM*4); // gradient flowing backward
float *dffn = (float*)malloc(SEQ*DIM*4);
float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
float *dx2 = (float*)malloc(SEQ*DIM*4);
float *do_out_buf = (float*)malloc(SEQ*DIM*4);
float *dq = (float*)malloc(SEQ*DIM*4);
float *dk = (float*)malloc(SEQ*DIM*4);
float *dv = (float*)malloc(SEQ*DIM*4);
float *dx_attn = (float*)malloc(SEQ*DIM*4);
// x buffer for input to each layer (channel-first [DIM, SEQ])
float *x_cur = (float*)malloc(SEQ*DIM*4);
float *x_final = (float*)malloc(SEQ*DIM*4); // after final rmsnorm
float *logits = (float*)malloc(SEQ*VOCAB*4); // [VOCAB, SEQ] for cross-entropy
float *dlogits = (float*)malloc(SEQ*VOCAB*4);
// Compile static sdpaBwd2 kernels (no weights, one per layer)
Kern *sdpaBwd2[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
sdpaBwd2[L] = compile_sdpa_bwd2();
if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
}
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
dispatch_group_t dw_grp = dispatch_group_create();
float last_loss = 999.0f;
double total_compile_ms=0, total_train_ms=0;
int total_steps_done=0, total_batches=0;
uint64_t t_wall_start = mach_absolute_time();
srand48(42 + start_step);
// ===== DOUBLE-BUFFER: Initial synchronous compile (first batch only) =====
printf(" [DB] Initial compile (synchronous)...\n");
{
uint64_t tc = mach_absolute_time();
for (int L=0; L<NLAYERS; L++) {
printf(" Compiling layer %d/%d... (%d compiles)\r", L+1, NLAYERS, g_compile_count);
fflush(stdout);
if (!compile_layer_kernels(&kern_active[L], &lw[L])) {
printf("\nInitial compile failed at layer %d\n", L);
return 1;
}
}
// sdpaBwd2 kernels were compiled above; this only fills any slot that is still NULL
for (int L=0; L<NLAYERS; L++) {
if (!sdpaBwd2[L]) {
sdpaBwd2[L] = compile_sdpa_bwd2();
if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
}
}
double cms = tb_ms(mach_absolute_time() - tc);
total_compile_ms += cms;
printf(" [DB] Initial compile: %d kernels in %.0fms\n", TOTAL_WEIGHT_KERNELS, cms);
}
// Helper block: compile all layers into a kernel set
// Captured by the GCD block for background compilation
void (^compile_into)(LayerKernels *, LayerWeights *) = ^(LayerKernels *target, LayerWeights *weights) {
for (int L=0; L<NLAYERS; L++) {
free_layer_kernels(&target[L]);
if (!compile_layer_kernels(&target[L], &weights[L])) {
printf("\n [DB] Background compile failed at layer %d\n", L);
return;
}
}
};
int step = start_step;
int batches_since_swap = 0;
double total_stall_ms = 0;
while (step < total_steps) {
// Check compile budget
if (g_compile_count + TOTAL_WEIGHT_KERNELS > DB_MAX_COMPILES) {
// Wait for any in-flight background compile
dispatch_sync(compile_q, ^{});
for (int L=0; L<NLAYERS; L++) {
free_layer_kernels(&kern_A[L]);
free_layer_kernels(&kern_B[L]);
free_kern(sdpaBwd2[L]); sdpaBwd2[L] = NULL;
}
#undef kern
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
fflush(stdout);
execl(argv[0], argv[0], "--resume", (char *)NULL);
perror("execl"); return 1;
#define kern kern_active
}
// ===== DOUBLE-BUFFER: Check if pending kernels are ready to swap =====
if (atomic_load(&pending_ready)) {
// Swap: pending becomes active, old active becomes recycle target
LayerKernels *old_active = kern_active;
kern_active = kern_pending;
kern_pending = old_active;
atomic_store(&pending_ready, false);
batches_since_swap = 0;
printf(" [DB] Swapped kernels (stall=0ms)\n");
}
// Re-compile sdpaBwd2 if needed (after exec restart)
for (int L=0; L<NLAYERS; L++) {
if (!sdpaBwd2[L]) {
sdpaBwd2[L] = compile_sdpa_bwd2();
if (!sdpaBwd2[L]) { printf("sdpaBwd2 recompile failed\n"); return 1; }
}
}
// Zero gradient accumulators
for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
memset(grms_final, 0, DIM*4);
memset(gembed, 0, (size_t)VOCAB*DIM*4);
int steps_batch = 0;
uint64_t tt = mach_absolute_time();
double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0,t_cls=0;
for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
uint64_t t0,t1;
// Sample random position in token data
size_t max_pos = n_tokens - SEQ - 1;
size_t pos = (size_t)(drand48() * max_pos);
uint16_t *input_tokens = token_data + pos;
uint16_t *target_tokens = token_data + pos + 1;
// Embedding lookup → x_cur [DIM, SEQ] channel-first
t0=mach_absolute_time();
embed_lookup(x_cur, embed, input_tokens, DIM, SEQ);
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
// ===== FORWARD (12 layers) =====
for (int L=0; L<NLAYERS; L++) {
LayerActs *ac = &acts[L];
// Save layer input for rmsnorm1 backward
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
// Attention forward: x_cur → o_out,Q,K,V,attn_out,xnorm
t0=mach_absolute_time();
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
t1=mach_absolute_time(); t_cblas_wait+=tb_ms(t1-t0); t0=t1;
io_write_fp16(kern[L].fwdAttn->ioIn, x_cur, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
ane_eval(kern[L].fwdAttn);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
io_read_fp16(kern[L].fwdAttn->ioOut, ac->o_out, 0, DIM, SEQ);
io_read_fp16(kern[L].fwdAttn->ioOut, ac->attn_out, 4*DIM, DIM, SEQ);
io_read_fp16(kern[L].fwdAttn->ioOut, ac->xnorm, 5*DIM, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
// FFN forward
io_write_fp16(kern[L].fwdFFN->ioIn, ac->x2, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
ane_eval(kern[L].fwdFFN);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
io_read_fp16(kern[L].fwdFFN->ioOut, ac->ffn_out, 0, DIM, SEQ);
io_read_fp16(kern[L].fwdFFN->ioOut, ac->h1, DIM, HIDDEN, SEQ);
io_read_fp16(kern[L].fwdFFN->ioOut, ac->h3, DIM+HIDDEN, HIDDEN, SEQ);
io_read_fp16(kern[L].fwdFFN->ioOut, ac->silu_out, DIM+2*HIDDEN, HIDDEN, SEQ);
io_read_fp16(kern[L].fwdFFN->ioOut, ac->x2norm, DIM+3*HIDDEN, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
}
// Final RMSNorm (CPU)
t0=mach_absolute_time();
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
// Classifier: logits[VOCAB,SEQ] = embed[VOCAB,DIM] @ x_final[DIM,SEQ]
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
VOCAB, SEQ, DIM, 1.0f,
embed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1;
// Cross-entropy loss
float loss = cross_entropy_loss(dlogits, logits, target_tokens, VOCAB, SEQ);
last_loss = loss;
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
// ===== BACKWARD =====
// dlogits already computed by cross_entropy_loss
// Classifier backward: dx_final = embed^T @ dlogits, dembed += dlogits @ x_final^T
// dx_final[DIM,SEQ] = embed^T[DIM,VOCAB] @ dlogits[VOCAB,SEQ]
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
DIM, SEQ, VOCAB, 1.0f,
embed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
// dembed[VOCAB,DIM] += dlogits[VOCAB,SEQ] @ x_final^T[SEQ,DIM]
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
VOCAB, DIM, SEQ, 1.0f,
dlogits, SEQ, x_final, SEQ, 1.0f, gembed, DIM);
});
// Final RMSNorm backward
float *dx_rms_final = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms_final, grms_final, dy, x_cur, rms_final, DIM, SEQ);
memcpy(dy, dx_rms_final, SEQ*DIM*4);
free(dx_rms_final);
// ===== BACKWARD (12 layers, reverse) =====
for (int L=NLAYERS-1; L>=0; L--) {
LayerActs *ac = &acts[L];
LayerGrads *gr = &grads[L];
// dy is the gradient at the output of this layer
// dffn = dy (residual connection: d(x2 + ffn) = dy for both)
memcpy(dffn, dy, SEQ*DIM*4);
// FFN backward (ANE)
io_write_fp16_at(kern[L].ffnBwd->ioIn, 0, dffn, DIM, SEQ);
io_copy(kern[L].ffnBwd->ioIn, DIM, kern[L].fwdFFN->ioOut, DIM, 2*HIDDEN, SEQ);
ane_eval(kern[L].ffnBwd);
io_read_fp16(kern[L].ffnBwd->ioOut, dx_ffn, 0, DIM, SEQ);
io_read_fp16(kern[L].ffnBwd->ioOut, dh1, DIM, HIDDEN, SEQ);
io_read_fp16(kern[L].ffnBwd->ioOut, dh3, DIM+HIDDEN, HIDDEN, SEQ);
// dW FFN async
float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4);
float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4);
float *capt_dh1 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh1, dh1, SEQ*HIDDEN*4);
float *capt_dh3 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh3, dh3, SEQ*HIDDEN*4);
float *capt_x2n = (float*)malloc(SEQ*DIM*4); memcpy(capt_x2n, ac->x2norm, SEQ*DIM*4);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
1.0f, capt_dffn, SEQ, capt_silu, SEQ, 1.0f, gr->W2, HIDDEN);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh1, SEQ, capt_x2n, SEQ, 1.0f, gr->W1, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh3, SEQ, capt_x2n, SEQ, 1.0f, gr->W3, DIM);
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
});
// RMSNorm2 backward
memset(dx2, 0, SEQ*DIM*4);
rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
// Add residual: dx2 += dy (from skip connection)
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
// dWo async
memcpy(do_out_buf, dx2, SEQ*DIM*4);
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, do_out_buf, SEQ*DIM*4);
float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM);
free(capt_do); free(capt_attn);
});
// SDPA backward (ANE)
io_copy(kern[L].sdpaBwd1->ioIn, 0, kern[L].fwdAttn->ioOut, DIM, 3*DIM, SEQ);
io_write_fp16_at(kern[L].sdpaBwd1->ioIn, 3*DIM, dx2, DIM, SEQ);
ane_eval(kern[L].sdpaBwd1);
io_copy(sdpaBwd2[L]->ioIn, 0, kern[L].sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
io_copy(sdpaBwd2[L]->ioIn, 2*SCORE_CH, kern[L].fwdAttn->ioOut, DIM, 2*DIM, SEQ);
ane_eval(sdpaBwd2[L]);
io_read_fp16(sdpaBwd2[L]->ioOut, dq, 0, DIM, SEQ);
io_read_fp16(sdpaBwd2[L]->ioOut, dk, DIM, DIM, SEQ);
io_read_fp16(kern[L].sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
// dWq/dWk/dWv async
float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk, SEQ*DIM*4);
float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4);
float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM);
free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
});
// QKV backward (ANE)
io_copy(kern[L].qkvBwd->ioIn, 0, sdpaBwd2[L]->ioOut, 0, 2*DIM, SEQ);
io_copy(kern[L].qkvBwd->ioIn, 2*DIM, kern[L].sdpaBwd1->ioOut, 0, DIM, SEQ);
ane_eval(kern[L].qkvBwd);
io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
// RMSNorm1 backward (using saved layer input)
float *dx_rms1 = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms1, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
// Gradient handed to the previous layer (dy), derived from the layer structure.
// Forward: x2 = layer_in + o_out (attention residual), out = x2 + ffn_out (FFN residual).
// Backward:
//   dy      -> (dffn, dy_skip), with dy_skip = dy because of the FFN residual
//   dffn    -> ffnBwd                 -> dx_ffn
//   dx_ffn  -> rmsnorm2_bwd           -> dx_rms2
//   dx2 = dx_rms2 + dy                   (this is d(loss)/d(x2))
//   dx2     -> sdpaBwd (through Wo^T) -> dq, dk, dv
//   dq,dk,dv -> qkvBwd                -> dx_attn
//   dx_attn -> rmsnorm1_bwd           -> dx_rms1
// Since x2 = layer_in + o_out, d(x2)/d(layer_in) = 1, so layer_in receives dx2
// through the skip path plus dx_rms1 through the attention path.
// So: dy for previous layer = dx_rms1 + dx2
for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms1[i] + dx2[i];
free(dx_rms1);
}
// Embedding backward
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
embed_backward(gembed, dy, input_tokens, DIM, SEQ);
steps_batch++;
if (step % 10 == 0 || step == start_step)
printf("step %-4d loss=%.4f\n", step, loss);
// JSON telemetry to stderr
double step_ane = t_ane/steps_batch, step_io = t_io/steps_batch;
double step_cls = t_cls/steps_batch, step_elem = t_elem/steps_batch;
double step_rms = t_rms/steps_batch, step_cbw = t_cblas_wait/steps_batch;
fprintf(stderr, "{\"type\":\"step\",\"step\":%d,\"loss\":%.6f,"
"\"t_ane\":%.3f,\"t_io\":%.3f,\"t_cls\":%.3f,"
"\"t_elem\":%.3f,\"t_rms\":%.3f,\"t_cblas_wait\":%.3f,"
"\"compiles\":%d}\n",
step, loss, step_ane, step_io, step_cls, step_elem, step_rms, step_cbw, g_compile_count);
}
double tms = tb_ms(mach_absolute_time() - tt);
total_train_ms += tms;
total_steps_done += steps_batch;
total_batches++;
// Ensure all async dW finished
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
// Adam update (scale gradients by 1/steps_batch)
float gsc = 1.0f / steps_batch;
adam_t++;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
}
for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
// Scale and update embed
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
// ===== DOUBLE-BUFFER: Start background compile with updated weights =====
batches_since_swap++;
// Only start bg compile if we have budget
if (!atomic_load(&bg_compile_running) &&
g_compile_count + TOTAL_WEIGHT_KERNELS <= DB_MAX_COMPILES) {
atomic_store(&bg_compile_running, true);
// Capture pointers (not stack arrays) for background block
LayerKernels *bg_target = kern_pending;
LayerWeights *bg_weights = lw; // decays to pointer, safe for block
dispatch_async(compile_q, ^{
compile_into(bg_target, bg_weights);
atomic_store(&pending_ready, true);
atomic_store(&bg_compile_running, false);
});
}
double cms = 0; // compile was async, no stall
printf(" [batch %d: compile_stall=0ms train=%.1fms (%.1fms/step) compiles=%d bg=%s]\n",
steps_batch, tms, tms/steps_batch, g_compile_count,
atomic_load(&bg_compile_running) ? "compiling" : "idle");
printf(" ane=%.1f io=%.1f cls=%.1f elem=%.1f rms=%.1f cblas_wait=%.1f ms/step\n",
t_ane/steps_batch, t_io/steps_batch, t_cls/steps_batch, t_elem/steps_batch,
t_rms/steps_batch, t_cblas_wait/steps_batch);
// JSON batch telemetry to stderr
{
double bf = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
double bs = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
double ane_f_batch = (bf*2 + bs) * steps_batch;
double ane_tflops = ane_f_batch / (tms * 1e9);
fprintf(stderr, "{\"type\":\"batch\",\"batch\":%d,\"compile_ms\":%.1f,"
"\"train_ms\":%.1f,\"ms_per_step\":%.1f}\n",
steps_batch, cms, tms, tms/steps_batch);
fprintf(stderr, "{\"type\":\"perf\",\"ane_tflops\":%.3f,\"ane_util_pct\":%.2f}\n",
ane_tflops, 100.0*ane_tflops/15.8);
}
}
// Efficiency report
double wall = tb_ms(mach_absolute_time() - t_wall_start);
total_compile_ms += cum_compile; total_train_ms += cum_train;
wall += cum_wall; total_steps_done += cum_steps; total_batches += cum_batches;
double fwd_flops = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
double sdpa_flops = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
double cls_flops = 2.0*VOCAB*DIM*SEQ;
double total_flops = (fwd_flops*3 + sdpa_flops + cls_flops*3) * total_steps_done;
double ane_flops = (fwd_flops*2 + sdpa_flops) * total_steps_done;
printf("\n=== Efficiency Report ===\n");
printf("Total steps: %d\n", total_steps_done);
printf("Wall time: %.0f ms (%.1f s)\n", wall, wall/1000);
printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100*total_compile_ms/wall);
printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100*total_train_ms/wall);
printf("Avg train: %.1f ms/step\n", total_train_ms/total_steps_done);
printf("ANE TFLOPS: %.2f sustained\n", ane_flops / (total_train_ms * 1e9));
printf("Total TFLOPS: %.2f (ANE+CPU)\n", total_flops / (total_train_ms * 1e9));
printf("ANE utilization: %.1f%% of 15.8 TFLOPS\n", 100*ane_flops/(total_train_ms*1e9)/15.8);
// Wait for any in-flight background compile
dispatch_sync(compile_q, ^{});
// Cleanup
#undef kern
for (int L=0; L<NLAYERS; L++) {
free_layer_kernels(&kern_A[L]);
free_layer_kernels(&kern_B[L]);
free_kern(sdpaBwd2[L]);
layer_weights_free(&lw[L]);
layer_adam_free(&la[L]);
layer_acts_free(&acts[L]);
layer_grads_free(&grads[L]);
}
if (synthetic_data) { free(token_data); }
else { munmap(token_data, data_len); close(data_fd); }
free(rms_final); free(embed); free(grms_final); free(gembed);
adam_free(&arms_final); adam_free(&aembed);
free(dy); free(dffn); free(dh1); free(dh3); free(dx_ffn); free(dx2);
free(do_out_buf); free(dq); free(dk); free(dv); free(dx_attn);
free(x_cur); free(x_final); free(logits); free(dlogits);
}
return 0;
}
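// ----- Illustrative sketch (not part of the upstream PRs) -----
// The efficiency report above converts FLOP counts and elapsed milliseconds
// into TFLOPS and a utilization percentage. The helpers below restate that
// arithmetic in one place. Their names are hypothetical, and the peak figure
// is whatever the caller passes (the report above assumes 15.8 TFLOPS).
#include <assert.h>

// FLOPs executed in `ms` milliseconds, expressed in TFLOPS:
// flops / (ms * 1e-3 s) / 1e12 == flops / (ms * 1e9).
static double sketch_tflops(double flops, double ms) {
    assert(ms > 0.0);
    return flops / (ms * 1e9);
}

// Utilization as a percentage of a caller-supplied peak throughput.
static double sketch_util_pct(double flops, double ms, double peak_tflops) {
    assert(peak_tflops > 0.0);
    return 100.0 * sketch_tflops(flops, ms) / peak_tflops;
}
// Usage: sketch_util_pct(ane_flops, total_train_ms, 15.8) mirrors the
// "ANE utilization" line printed by the report above.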

971
training/train_opt.m Normal file

@ -0,0 +1,971 @@
// train_opt.m: Optimized train_large with:
// Phase 1: NEON Adam, vectorized embed ops, pre-allocated capture buffers
// Phase 2: Concurrent dW dispatch, fp16 activation cache
// Phase 3: Metal GPU for weight gradient computation (dW)
//
// Key perf wins:
// - Pre-allocated LayerCaptures: eliminates ~132 malloc/free per step
// - Concurrent dW queue: individual sgemms run in parallel (was serial)
// - fp16 activation cache: skip fp16->fp32 conversion on main thread for dW-only buffers
// - Metal GPU dW: ~12ms for all weight gradients vs ~435ms serial CPU
// - NEON Adam: ~3x faster optimizer step
// - Vectorized embed: vDSP_mtrans instead of scalar scatter/gather
#include "stories_io.h"
#include "stories_mil.h"
#include "stories_cpu_ops_opt.h"
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
#define MODEL_PATH_DEFAULT "stories110M.bin"
#define DATA_PATH_DEFAULT "tinystories_data00.bin"
static const char *get_path(const char *env_var, const char *default_val) {
const char *v = getenv(env_var);
return (v && v[0]) ? v : default_val;
}
// ===== Pre-allocated capture buffers per layer (Phase 1) =====
// Eliminates malloc/free in dispatch blocks
typedef struct {
// FFN dW captures
float *dffn; // [DIM, SEQ]
float *silu_out; // [HIDDEN, SEQ]
float *dh1; // [HIDDEN, SEQ]
float *dh3; // [HIDDEN, SEQ]
float *x2norm; // [DIM, SEQ]
// Attn dW captures
float *do_buf; // [DIM, SEQ] (for dWo)
float *attn_out; // [DIM, SEQ]
// QKV dW captures
float *dq; // [DIM, SEQ]
float *dk; // [DIM, SEQ]
float *dv; // [DIM, SEQ]
float *xnorm; // [DIM, SEQ]
// fp16 backward gradient cache (read raw from IOSurface, convert in dispatch block)
_Float16 *dh1_fp16; // [HIDDEN, SEQ]
_Float16 *dh3_fp16; // [HIDDEN, SEQ]
_Float16 *dq_fp16; // [DIM, SEQ]
_Float16 *dk_fp16; // [DIM, SEQ]
_Float16 *dv_fp16; // [DIM, SEQ]
} LayerCaptures;
static LayerCaptures layer_captures_alloc(void) {
LayerCaptures c;
c.dffn = (float*)malloc(SEQ * DIM * 4);
c.silu_out = (float*)malloc(SEQ * HIDDEN * 4);
c.dh1 = (float*)malloc(SEQ * HIDDEN * 4);
c.dh3 = (float*)malloc(SEQ * HIDDEN * 4);
c.x2norm = (float*)malloc(SEQ * DIM * 4);
c.do_buf = (float*)malloc(SEQ * DIM * 4);
c.attn_out = (float*)malloc(SEQ * DIM * 4);
c.dq = (float*)malloc(SEQ * DIM * 4);
c.dk = (float*)malloc(SEQ * DIM * 4);
c.dv = (float*)malloc(SEQ * DIM * 4);
c.xnorm = (float*)malloc(SEQ * DIM * 4);
c.dh1_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2);
c.dh3_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2);
c.dq_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
c.dk_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
c.dv_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
return c;
}
static void layer_captures_free(LayerCaptures *c) {
free(c->dffn); free(c->silu_out); free(c->dh1); free(c->dh3);
free(c->x2norm); free(c->do_buf); free(c->attn_out);
free(c->dq); free(c->dk); free(c->dv); free(c->xnorm);
free(c->dh1_fp16); free(c->dh3_fp16);
free(c->dq_fp16); free(c->dk_fp16); free(c->dv_fp16);
}
// ===== fp16 activation cache (Phase 2) =====
// Store activations that are only used for dW as fp16 (skip main-thread conversion)
typedef struct {
_Float16 *xnorm_fp16; // [DIM, SEQ]
_Float16 *attn_out_fp16; // [DIM, SEQ]
_Float16 *x2norm_fp16; // [DIM, SEQ]
_Float16 *silu_out_fp16; // [HIDDEN, SEQ]
} LayerFP16Cache;
static LayerFP16Cache layer_fp16_cache_alloc(void) {
LayerFP16Cache c;
c.xnorm_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
c.attn_out_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
c.x2norm_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
c.silu_out_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2);
return c;
}
static void layer_fp16_cache_free(LayerFP16Cache *c) {
free(c->xnorm_fp16); free(c->attn_out_fp16);
free(c->x2norm_fp16); free(c->silu_out_fp16);
}
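// ----- Illustrative sketch (not part of the upstream PR) -----
// The fp16 cache above defers fp16->fp32 widening to the dW dispatch blocks.
// With clang on Apple Silicon a plain (float) cast of a _Float16 does that
// widening. The hypothetical helpers below spell out the same IEEE 754
// binary16 decoding on raw 16-bit patterns, as a portable reference for what
// the conversion computes. They are an assumption for illustration, not the
// routine the training loop calls.
#include <assert.h>
#include <stdint.h>
#include <string.h>

static float sketch_half_bits_to_float(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;  // sign bit moves to bit 31
    uint32_t exp  = (uint32_t)(h >> 10) & 0x1Fu;    // 5-bit exponent, bias 15
    uint32_t man  = (uint32_t)h & 0x3FFu;           // 10-bit mantissa
    uint32_t bits;
    if (exp == 0) {
        if (man == 0) {
            bits = sign;                            // signed zero
        } else {
            // Subnormal half: shift until the implicit leading 1 appears.
            int e = -1;
            do { e++; man <<= 1; } while ((man & 0x400u) == 0);
            man &= 0x3FFu;
            bits = sign | ((uint32_t)(127 - 15 - e) << 23) | (man << 13);
        }
    } else if (exp == 31) {
        bits = sign | 0x7F800000u | (man << 13);    // infinity or NaN
    } else {
        bits = sign | ((exp + 112u) << 23) | (man << 13);  // rebias 15 -> 127
    }
    float f;
    memcpy(&f, &bits, sizeof f);                    // bit-exact reinterpretation
    return f;
}

// Widen a whole buffer, as a dispatch block would before calling sgemm.
static void sketch_half_to_float_buf(float *dst, const uint16_t *src, size_t n) {
    assert(dst != NULL && src != NULL);
    for (size_t i = 0; i < n; i++) dst[i] = sketch_half_bits_to_float(src[i]);
}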
// ===== Metal GPU dW context (Phase 3) =====
typedef struct {
id<MTLDevice> device;
id<MTLCommandQueue> queue;
// Shared gradient accumulator buffers (one per weight matrix per layer)
id<MTLBuffer> dW_bufs[NLAYERS][9]; // Wq,Wk,Wv,Wo,W1,W2,W3,rms_att,rms_ffn
id<MTLCommandBuffer> lastCmdBuf; // Track last submitted buffer for sync
} MetalDWContext;
// Weight matrix indices for Metal buffers
enum { MW_Q=0, MW_K, MW_V, MW_O, MW_1, MW_2, MW_3, MW_RMSA, MW_RMSF };
static bool metal_dw_init(MetalDWContext *ctx) {
ctx->device = MTLCreateSystemDefaultDevice();
if (!ctx->device) { printf("[Metal] No GPU device\n"); return false; }
ctx->queue = [ctx->device newCommandQueue];
if (!ctx->queue) { printf("[Metal] No command queue\n"); return false; }
// Allocate shared-mode gradient accumulator buffers
size_t sizes[9] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4,
W1_SZ*4, W2_SZ*4, W3_SZ*4, DIM*4, DIM*4};
for (int L = 0; L < NLAYERS; L++) {
for (int w = 0; w < 9; w++) {
ctx->dW_bufs[L][w] = [ctx->device newBufferWithLength:sizes[w]
options:MTLResourceStorageModeShared];
if (!ctx->dW_bufs[L][w]) { printf("[Metal] Buffer alloc failed L=%d w=%d\n", L, w); return false; }
}
}
printf("[Metal] GPU: %s\n", [[ctx->device name] UTF8String]);
return true;
}
static void metal_dw_zero(MetalDWContext *ctx) {
size_t sizes[9] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4,
W1_SZ*4, W2_SZ*4, W3_SZ*4, DIM*4, DIM*4};
for (int L = 0; L < NLAYERS; L++) {
for (int w = 0; w < 9; w++) {
memset([ctx->dW_bufs[L][w] contents], 0, sizes[w]);
}
}
}
// Encode a single dW sgemm to Metal command buffer using MPS
// C[M,N] += A[M,K] @ B^T[N,K] (i.e., C += A @ B^T, accumulating into C)
static void metal_encode_dw_sgemm(id<MTLCommandBuffer> cmdBuf,
id<MTLDevice> device,
const float *a_data, int M, int K,
const float *b_data, int N,
id<MTLBuffer> c_buf) {
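// Create temporary input buffers (shared mode = zero-copy on Apple Silicon).
// Note: newBufferWithBytesNoCopy expects a page-aligned pointer, which plain malloc'd
// capture buffers do not guarantee; callers should pass page-aligned allocations.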
id<MTLBuffer> aBuf = [device newBufferWithBytesNoCopy:(void*)a_data
length:M * K * sizeof(float)
options:MTLResourceStorageModeShared
deallocator:nil];
id<MTLBuffer> bBuf = [device newBufferWithBytesNoCopy:(void*)b_data
length:N * K * sizeof(float)
options:MTLResourceStorageModeShared
deallocator:nil];
// A is [M, K] row-major, B is [N, K] row-major
// We want C += A @ B^T, i.e., C[M, N] += A[M, K] @ (B[N, K])^T
// MPS uses row-major by default
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor matrixDescriptorWithRows:M
columns:K rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32];
MPSMatrixDescriptor *descB = [MPSMatrixDescriptor matrixDescriptorWithRows:N
columns:K rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32];
MPSMatrixDescriptor *descC = [MPSMatrixDescriptor matrixDescriptorWithRows:M
columns:N rowBytes:N * sizeof(float) dataType:MPSDataTypeFloat32];
MPSMatrix *matA = [[MPSMatrix alloc] initWithBuffer:aBuf descriptor:descA];
MPSMatrix *matB = [[MPSMatrix alloc] initWithBuffer:bBuf descriptor:descB];
MPSMatrix *matC = [[MPSMatrix alloc] initWithBuffer:c_buf descriptor:descC];
MPSMatrixMultiplication *mm = [[MPSMatrixMultiplication alloc]
initWithDevice:device transposeLeft:NO transposeRight:YES
resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:1.0];
[mm encodeToCommandBuffer:cmdBuf leftMatrix:matA rightMatrix:matB resultMatrix:matC];
}
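// ----- Illustrative sketch (not part of the upstream PR) -----
// The MPS call above is configured with transposeRight:YES and beta:1.0, i.e.
// it accumulates C[M,N] += A[M,K] @ (B[N,K])^T on row-major data. The
// hypothetical scalar reference below computes the same thing on the CPU and
// could serve as a correctness check for the GPU path on small shapes. It is
// a sketch under that reading of the call, not code the training loop uses.
#include <assert.h>

static void sketch_ref_dw_accum(float *C, const float *A, const float *B,
                                int M, int N, int K) {
    assert(C != NULL && A != NULL && B != NULL);
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float acc = 0.0f;
            // Row m of A dotted with row n of B (B is used transposed).
            for (int k = 0; k < K; k++) acc += A[m*K + k] * B[n*K + k];
            C[m*N + n] += acc;  // beta = 1: keep what was already accumulated
        }
    }
}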
// ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { printf("Cannot open %s\n", path); return false; }
Llama2Config cfg;
if (fread(&cfg, sizeof(cfg), 1, f) != 1) { printf("Cannot read config from %s\n", path); fclose(f); return false; }
printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d\n", DIM, HIDDEN, NLAYERS);
fclose(f); return false;
}
int V = abs(cfg.vocab_size);
bool shared = cfg.vocab_size > 0;
(void)V; (void)shared;
fread(embed, 4, VOCAB * DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
fread(rms_final, 4, DIM, f);
fclose(f);
printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
return true;
}
// ===== Compile one layer's kernels =====
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
@"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
@"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
@"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
@"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
@"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
}), DIM*SEQ*2, 6*DIM*SEQ*2);
lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
@"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
@"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
@"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
@"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
}), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);
lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
@"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
@"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
@"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
}), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);
lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
@"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
}), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
@"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
@"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
@"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
}), 3*DIM*SEQ*2, DIM*SEQ*2);
return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
}
static Kern *compile_sdpa_bwd2(void) {
return compile_kern_mil_w(gen_sdpa_bwd2(), @{},
(2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
}
static void free_layer_kernels(LayerKernels *lk) {
free_kern(lk->fwdAttn); free_kern(lk->fwdFFN); free_kern(lk->ffnBwd);
free_kern(lk->sdpaBwd1); free_kern(lk->qkvBwd);
lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
}
// ===== Checkpoint save/load =====
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
double cc, double ct, double cw, int cs, int cb, int adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "wb");
if (!f) { printf("Cannot open %s for writing\n", path); return; }
CkptHdr h = {0};
h.magic = 0x424C5A54; h.version = 2;
h.step = step; h.total_steps = total_steps;
h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
h.lr = lr; h.loss = loss;
h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;
fwrite(&h, sizeof(h), 1, f);
for (int L = 0; L < NLAYERS; L++) {
fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
}
fwrite(rms_final,4,DIM,f);
fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
fwrite(embed,4,VOCAB*DIM,f);
fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);
fclose(f);
}
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "rb");
if (!f) return false;
CkptHdr h;
// Reject truncated files as well as wrong magic/version
if (fread(&h, sizeof(h), 1, f) != 1 || h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; }
*step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
*cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
*cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
for (int L = 0; L < NLAYERS; L++) {
fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
}
fread(rms_final,4,DIM,f);
fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
fread(embed,4,VOCAB*DIM,f);
fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f);
fclose(f);
return true;
}
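// ----- Illustrative sketch (not part of the upstream PR) -----
// The checkpoint reader above issues many fread calls whose return values
// are not inspected, so a truncated file would load silently. One possible
// hardening, assuming a hypothetical helper name, is to route each read
// through a function that reports short reads:
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>

// Read exactly `count` floats; returns false on a short read or I/O error.
static bool sketch_read_floats(FILE *f, float *dst, size_t count) {
    assert(f != NULL && dst != NULL);
    return fread(dst, sizeof(float), count, f) == count;
}
// Usage: if (!sketch_read_floats(f, lw[L].Wq, WQ_SZ)) { fclose(f); return false; }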
// ===== Main =====
int main(int argc, char *argv[]) {
@autoreleasepool {
setbuf(stdout, NULL);
// Phase 2: Limit BLAS thread count to prevent oversubscription with concurrent dispatch
setenv("VECLIB_MAXIMUM_THREADS", "2", 1);
ane_init();
mach_timebase_info(&g_tb);
int total_steps = 10000;
float lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
// Parse args
const char *model_path = get_path("ANE_MODEL_PATH", MODEL_PATH_DEFAULT);
const char *ckpt_path = get_path("ANE_CKPT_PATH", CKPT_PATH_DEFAULT);
const char *data_path = get_path("ANE_DATA_PATH", DATA_PATH_DEFAULT);
bool do_resume = false;
bool use_metal = false; // default off: Metal dW contends with ANE for memory bandwidth
int pos = 0;
for (int i=1; i<argc; i++) {
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
else if (strcmp(argv[i], "--no-metal") == 0) use_metal = false;
else if (strcmp(argv[i], "--metal") == 0) use_metal = true;
else if (argv[i][0] != '-') {
if (pos == 0) model_path = argv[i];
pos++;
}
}
// Allocate per-layer state
LayerWeights lw[NLAYERS];
LayerAdam la[NLAYERS];
LayerActs acts[NLAYERS];
LayerGrads grads[NLAYERS];
LayerKernels kern[NLAYERS];
LayerCaptures caps[NLAYERS]; // Phase 1: pre-allocated captures
LayerFP16Cache fp16cache[NLAYERS]; // Phase 2: fp16 activation cache
for (int L=0; L<NLAYERS; L++) {
lw[L] = layer_weights_alloc();
la[L] = layer_adam_alloc();
acts[L] = layer_acts_alloc();
grads[L] = layer_grads_alloc();
memset(&kern[L], 0, sizeof(LayerKernels));
caps[L] = layer_captures_alloc();
fp16cache[L] = layer_fp16_cache_alloc();
}
// Final RMSNorm + embedding + classifier
float *rms_final = (float*)malloc(DIM*4);
float *embed = (float*)malloc(VOCAB*DIM*4);
float *grms_final = (float*)calloc(DIM, 4);
float *gembed = (float*)calloc(VOCAB*DIM, 4);
AdamState arms_final = adam_alloc(DIM);
AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
// Phase 1: Pre-allocate dx_rms scratch (was calloc/free per step)
float *dx_rms_scratch = (float*)malloc(SEQ*DIM*4);
// Phase 1: Pre-allocate embed temp buffer for vectorized ops
float *embed_tmp = (float*)malloc(SEQ*DIM*4);
// Phase 3: Metal GPU for dW
MetalDWContext metal_ctx;
bool metal_ok = false;
if (use_metal) {
metal_ok = metal_dw_init(&metal_ctx);
if (!metal_ok) printf("[Metal] GPU init failed, falling back to CPU cblas\n");
}
// Classifier dW capture buffers (pre-allocated, Phase 1)
float *capt_dlogits = (float*)malloc(SEQ*VOCAB*4);
float *capt_xfinal = (float*)malloc(SEQ*DIM*4);
double cum_compile=0, cum_train=0, cum_wall=0;
int cum_steps=0, cum_batches=0;
float resume_loss = 0;
bool resuming = false;
if (do_resume) {
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
}
if (!resuming) {
printf("=== ANE Training (OPTIMIZED): Stories110M (12 layers) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
printf("Optimizations: NEON-Adam, vec-embed, pre-alloc, concurrent-dW, fp16-cache%s\n",
metal_ok ? ", Metal-GPU-dW" : "");
if (!load_pretrained(lw, rms_final, embed, model_path)) {
printf("Pretrained load failed, using random init\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
for (int L=0; L<NLAYERS; L++) {
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
}
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
float escale = 0.02f;
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
}
size_t tp = (size_t)NLAYERS*LAYER_PARAMS + DIM + (size_t)VOCAB*DIM;
double xfmr_params = (double)NLAYERS*LAYER_PARAMS;
double embed_params = (double)VOCAB*DIM;
printf("Params: %.2fM (transformer %.2fM + embed %.2fM)\n", tp/1e6, xfmr_params/1e6, embed_params/1e6);
printf("Kernels: %d (%d weight-bearing + %d static sdpaBwd2)\n",
TOTAL_WEIGHT_KERNELS+NLAYERS, TOTAL_WEIGHT_KERNELS, NLAYERS);
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
double fwd_f = NLAYERS*(4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
double bwd_dx_f = fwd_f, bwd_dw_f = fwd_f;
double sdpa_f = NLAYERS*2.0*HEADS*5*SEQ*SEQ*HD;
double cls_f = 2.0*VOCAB*DIM*SEQ;
double total_f = fwd_f + bwd_dx_f + bwd_dw_f + sdpa_f + cls_f*3;
double ane_f = fwd_f + bwd_dx_f + sdpa_f;
printf("FLOPs/step: fwd=%.0fM bwd_dx=%.0fM bwd_dW=%.0fM sdpa_bwd=%.0fM total=%.0fM\n",
fwd_f/1e6, bwd_dx_f/1e6, bwd_dw_f/1e6, sdpa_f/1e6, total_f/1e6);
printf("ANE FLOPs/step: %.0fM (fwd+bwd_dx+sdpa_bwd) | %s: dW+cls\n\n",
ane_f/1e6, metal_ok ? "GPU" : "CPU cblas");
}
// mmap token data
int data_fd = open(data_path, O_RDONLY);
if (data_fd < 0) { printf("Cannot open %s\n", data_path); return 1; }
struct stat st; fstat(data_fd, &st);
size_t data_len = st.st_size;
uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
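size_t n_tokens = data_len / 2;
// Guard against underflow in max_pos = n_tokens - SEQ - 1 below
if (n_tokens < (size_t)SEQ + 1) { printf("Token data too small: need at least %d tokens\n", SEQ + 1); return 1; }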
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
// Gradient buffers shared across layers (reused each step)
float *dy = (float*)malloc(SEQ*DIM*4);
float *dffn = (float*)malloc(SEQ*DIM*4);
float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
float *dx2 = (float*)malloc(SEQ*DIM*4);
float *do_out_buf = (float*)malloc(SEQ*DIM*4);
float *dq = (float*)malloc(SEQ*DIM*4);
float *dk = (float*)malloc(SEQ*DIM*4);
float *dv = (float*)malloc(SEQ*DIM*4);
float *dx_attn = (float*)malloc(SEQ*DIM*4);
float *x_cur = (float*)malloc(SEQ*DIM*4);
float *x_final = (float*)malloc(SEQ*DIM*4);
float *logits = (float*)malloc(SEQ*VOCAB*4);
float *dlogits = (float*)malloc(SEQ*VOCAB*4);
// Compile static sdpaBwd2 kernels
Kern *sdpaBwd2[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
sdpaBwd2[L] = compile_sdpa_bwd2();
if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
}
// Phase 2: Concurrent dW dispatch queue (was DISPATCH_QUEUE_SERIAL)
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_CONCURRENT);
dispatch_group_t dw_grp = dispatch_group_create();
float last_loss = 999.0f;
double total_compile_ms=0, total_train_ms=0;
int total_steps_done=0, total_batches=0;
uint64_t t_wall_start = mach_absolute_time();
srand48(42 + start_step);
int step = start_step;
while (step < total_steps) {
// Check compile budget
if (g_compile_count + TOTAL_WEIGHT_KERNELS > MAX_COMPILES) {
for (int L=0; L<NLAYERS; L++) { free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]); }
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
fflush(stdout);
// Preserve --metal flag across restarts (default is off)
if (use_metal) execl(argv[0], argv[0], "--resume", "--metal", NULL);
else execl(argv[0], argv[0], "--resume", NULL);
perror("execl"); return 1;
}
// Compile all layers' weight-bearing kernels
uint64_t tc = mach_absolute_time();
for (int L=0; L<NLAYERS; L++) free_layer_kernels(&kern[L]);
bool compile_ok = true;
for (int L=0; L<NLAYERS; L++) {
printf(" Compiling layer %d/%d... (%d compiles)\r", L+1, NLAYERS, g_compile_count);
fflush(stdout);
if (!compile_layer_kernels(&kern[L], &lw[L])) {
printf("\nCompile failed at layer %d, restart\n", L);
compile_ok = false; break;
}
}
if (!compile_ok) { g_compile_count = MAX_COMPILES; continue; }
for (int L=0; L<NLAYERS; L++) {
if (!sdpaBwd2[L]) {
sdpaBwd2[L] = compile_sdpa_bwd2();
if (!sdpaBwd2[L]) { printf("sdpaBwd2 recompile failed\n"); return 1; }
}
}
double cms = tb_ms(mach_absolute_time() - tc);
total_compile_ms += cms;
printf(" Compiled %d kernels in %.0fms \n", TOTAL_WEIGHT_KERNELS, cms);
// Zero gradient accumulators
for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
memset(grms_final, 0, DIM*4);
memset(gembed, 0, (size_t)VOCAB*DIM*4);
if (metal_ok) metal_dw_zero(&metal_ctx);
int steps_batch = 0;
uint64_t tt = mach_absolute_time();
double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0,t_cls=0,t_metal=0,t_bwd=0;
for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
uint64_t t0,t1;
size_t max_pos = n_tokens - SEQ - 1;
size_t pos = (size_t)(drand48() * max_pos);
uint16_t *input_tokens = token_data + pos;
uint16_t *target_tokens = token_data + pos + 1;
// Phase 1: Vectorized embedding lookup
t0=mach_absolute_time();
embed_lookup_opt(x_cur, embed, input_tokens, DIM, SEQ, embed_tmp);
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
// ===== FORWARD (NLAYERS layers) =====
for (int L=0; L<NLAYERS; L++) {
LayerActs *ac = &acts[L];
LayerFP16Cache *fc = &fp16cache[L];
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
// Attention forward
t0=mach_absolute_time();
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
t1=mach_absolute_time(); t_cblas_wait+=tb_ms(t1-t0); t0=t1;
io_write_fp16(kern[L].fwdAttn->ioIn, x_cur, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
ane_eval(kern[L].fwdAttn);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
// Read o_out (needed on main thread for residual)
io_read_fp16(kern[L].fwdAttn->ioOut, ac->o_out, 0, DIM, SEQ);
// Phase 2: Read dW-only activations as raw fp16 (skip conversion on main thread)
io_read_raw_fp16(kern[L].fwdAttn->ioOut, fc->attn_out_fp16, 4*DIM, DIM, SEQ);
io_read_raw_fp16(kern[L].fwdAttn->ioOut, fc->xnorm_fp16, 5*DIM, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
// FFN forward
io_write_fp16(kern[L].fwdFFN->ioIn, ac->x2, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
ane_eval(kern[L].fwdFFN);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
// Read ffn_out (needed on main thread for residual)
io_read_fp16(kern[L].fwdFFN->ioOut, ac->ffn_out, 0, DIM, SEQ);
// h1, h3 NOT read here; backward uses io_copy from fwdFFN->ioOut directly
// silu_out and x2norm are dW-only; read as raw fp16
io_read_raw_fp16(kern[L].fwdFFN->ioOut, fc->silu_out_fp16, DIM+2*HIDDEN, HIDDEN, SEQ);
io_read_raw_fp16(kern[L].fwdFFN->ioOut, fc->x2norm_fp16, DIM+3*HIDDEN, DIM, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
}
// Final RMSNorm (CPU)
t0=mach_absolute_time();
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
// Classifier
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
VOCAB, SEQ, DIM, 1.0f,
embed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1;
float loss = cross_entropy_loss(dlogits, logits, target_tokens, VOCAB, SEQ);
last_loss = loss;
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
// ===== BACKWARD =====
uint64_t t_bwd_start = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
DIM, SEQ, VOCAB, 1.0f,
embed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
// dW embed (classifier): async on dW queue
memcpy(capt_dlogits, dlogits, SEQ*VOCAB*4);
memcpy(capt_xfinal, x_final, SEQ*DIM*4);
// Classifier dW on CPU (gembed is CPU-side accumulator, not Metal buffer)
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
VOCAB, DIM, SEQ, 1.0f,
capt_dlogits, SEQ, capt_xfinal, SEQ, 1.0f, gembed, DIM);
});
// Final RMSNorm backward (using pre-allocated scratch)
memset(dx_rms_scratch, 0, SEQ*DIM*4);
rmsnorm_bwd(dx_rms_scratch, grms_final, dy, x_cur, rms_final, DIM, SEQ);
memcpy(dy, dx_rms_scratch, SEQ*DIM*4);
// ===== BACKWARD (NLAYERS layers, reverse) =====
for (int L=NLAYERS-1; L>=0; L--) {
LayerActs *ac = &acts[L];
LayerGrads *gr = &grads[L];
LayerCaptures *cp = &caps[L];
LayerFP16Cache *fc = &fp16cache[L];
memcpy(dffn, dy, SEQ*DIM*4);
// FFN backward (ANE)
io_write_fp16_at(kern[L].ffnBwd->ioIn, 0, dffn, DIM, SEQ);
io_copy(kern[L].ffnBwd->ioIn, DIM, kern[L].fwdFFN->ioOut, DIM, 2*HIDDEN, SEQ);
ane_eval(kern[L].ffnBwd);
io_read_fp16(kern[L].ffnBwd->ioOut, dx_ffn, 0, DIM, SEQ);
// dh1, dh3: only used for dW captures; read as raw fp16
io_read_raw_fp16(kern[L].ffnBwd->ioOut, cp->dh1_fp16, DIM, HIDDEN, SEQ);
io_read_raw_fp16(kern[L].ffnBwd->ioOut, cp->dh3_fp16, DIM+HIDDEN, HIDDEN, SEQ);
memcpy(cp->dffn, dffn, SEQ*DIM*4);
if (metal_ok) {
// Metal path: convert all on main thread for GPU buffers
cvt_f16_f32(cp->dh1, cp->dh1_fp16, SEQ*HIDDEN);
cvt_f16_f32(cp->dh3, cp->dh3_fp16, SEQ*HIDDEN);
cvt_f16_f32(cp->silu_out, fc->silu_out_fp16, SEQ*HIDDEN);
cvt_f16_f32(cp->x2norm, fc->x2norm_fp16, SEQ*DIM);
@autoreleasepool {
id<MTLCommandBuffer> cmdBuf = [metal_ctx.queue commandBuffer];
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
cp->dffn, DIM, SEQ, cp->silu_out, HIDDEN,
metal_ctx.dW_bufs[L][MW_2]);
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
cp->dh1, HIDDEN, SEQ, cp->x2norm, DIM,
metal_ctx.dW_bufs[L][MW_1]);
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
cp->dh3, HIDDEN, SEQ, cp->x2norm, DIM,
metal_ctx.dW_bufs[L][MW_3]);
[cmdBuf commit];
metal_ctx.lastCmdBuf = cmdBuf;
}
} else {
// CPU: concurrent dispatch, convert fp16->fp32 in each block
_Float16 *fc_silu = fc->silu_out_fp16;
_Float16 *cp_dh1_f16 = cp->dh1_fp16, *cp_dh3_f16 = cp->dh3_fp16;
float *cp_dffn = cp->dffn, *cp_silu = cp->silu_out;
float *cp_dh1 = cp->dh1, *cp_dh3 = cp->dh3, *cp_x2n = cp->x2norm;
float *gr_W2 = gr->W2, *gr_W1 = gr->W1, *gr_W3 = gr->W3;
// Convert shared x2norm on main thread (W1+W3 blocks read concurrently)
cvt_f16_f32(cp_x2n, fc->x2norm_fp16, SEQ*DIM);
dispatch_group_async(dw_grp, dw_q, ^{
cvt_f16_f32(cp_silu, fc_silu, SEQ*HIDDEN);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
1.0f, cp_dffn, SEQ, cp_silu, SEQ, 1.0f, gr_W2, HIDDEN);
});
dispatch_group_async(dw_grp, dw_q, ^{
cvt_f16_f32(cp_dh1, cp_dh1_f16, SEQ*HIDDEN);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, cp_dh1, SEQ, cp_x2n, SEQ, 1.0f, gr_W1, DIM);
});
dispatch_group_async(dw_grp, dw_q, ^{
cvt_f16_f32(cp_dh3, cp_dh3_f16, SEQ*HIDDEN);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, cp_dh3, SEQ, cp_x2n, SEQ, 1.0f, gr_W3, DIM);
});
}
// RMSNorm2 backward
memset(dx2, 0, SEQ*DIM*4);
rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
// dWo async
memcpy(cp->do_buf, dx2, SEQ*DIM*4);
if (metal_ok) {
cvt_f16_f32(cp->attn_out, fc->attn_out_fp16, SEQ*DIM);
@autoreleasepool {
id<MTLCommandBuffer> cmdBuf = [metal_ctx.queue commandBuffer];
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
cp->do_buf, DIM, SEQ, cp->attn_out, DIM,
metal_ctx.dW_bufs[L][MW_O]);
[cmdBuf commit];
metal_ctx.lastCmdBuf = cmdBuf;
}
} else {
_Float16 *fc_attn = fc->attn_out_fp16;
float *cp_do = cp->do_buf, *cp_attn = cp->attn_out;
float *gr_Wo = gr->Wo;
dispatch_group_async(dw_grp, dw_q, ^{
cvt_f16_f32(cp_attn, fc_attn, SEQ*DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, cp_do, SEQ, cp_attn, SEQ, 1.0f, gr_Wo, DIM);
});
}
// SDPA backward (ANE)
io_copy(kern[L].sdpaBwd1->ioIn, 0, kern[L].fwdAttn->ioOut, DIM, 3*DIM, SEQ);
io_write_fp16_at(kern[L].sdpaBwd1->ioIn, 3*DIM, dx2, DIM, SEQ);
ane_eval(kern[L].sdpaBwd1);
io_copy(sdpaBwd2[L]->ioIn, 0, kern[L].sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
io_copy(sdpaBwd2[L]->ioIn, 2*SCORE_CH, kern[L].fwdAttn->ioOut, DIM, 2*DIM, SEQ);
ane_eval(sdpaBwd2[L]);
// dq, dk, dv: only used for dW captures; read as raw fp16
io_read_raw_fp16(sdpaBwd2[L]->ioOut, cp->dq_fp16, 0, DIM, SEQ);
io_read_raw_fp16(sdpaBwd2[L]->ioOut, cp->dk_fp16, DIM, DIM, SEQ);
io_read_raw_fp16(kern[L].sdpaBwd1->ioOut, cp->dv_fp16, 0, DIM, SEQ);
if (metal_ok) {
// Metal path: convert all on main thread for GPU buffers
cvt_f16_f32(cp->dq, cp->dq_fp16, SEQ*DIM);
cvt_f16_f32(cp->dk, cp->dk_fp16, SEQ*DIM);
cvt_f16_f32(cp->dv, cp->dv_fp16, SEQ*DIM);
cvt_f16_f32(cp->xnorm, fc->xnorm_fp16, SEQ*DIM);
@autoreleasepool {
id<MTLCommandBuffer> cmdBuf = [metal_ctx.queue commandBuffer];
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
cp->dq, DIM, SEQ, cp->xnorm, DIM,
metal_ctx.dW_bufs[L][MW_Q]);
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
cp->dk, DIM, SEQ, cp->xnorm, DIM,
metal_ctx.dW_bufs[L][MW_K]);
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
cp->dv, DIM, SEQ, cp->xnorm, DIM,
metal_ctx.dW_bufs[L][MW_V]);
[cmdBuf commit];
metal_ctx.lastCmdBuf = cmdBuf;
}
} else {
_Float16 *cp_dq_f16 = cp->dq_fp16, *cp_dk_f16 = cp->dk_fp16, *cp_dv_f16 = cp->dv_fp16;
float *cp_dq = cp->dq, *cp_dk = cp->dk, *cp_dv = cp->dv, *cp_xn = cp->xnorm;
float *gr_Wq = gr->Wq, *gr_Wk = gr->Wk, *gr_Wv = gr->Wv;
// Convert shared xnorm on main thread (all 3 blocks read concurrently)
cvt_f16_f32(cp_xn, fc->xnorm_fp16, SEQ*DIM);
dispatch_group_async(dw_grp, dw_q, ^{
cvt_f16_f32(cp_dq, cp_dq_f16, SEQ*DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, cp_dq, SEQ, cp_xn, SEQ, 1.0f, gr_Wq, DIM);
});
dispatch_group_async(dw_grp, dw_q, ^{
cvt_f16_f32(cp_dk, cp_dk_f16, SEQ*DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, cp_dk, SEQ, cp_xn, SEQ, 1.0f, gr_Wk, DIM);
});
dispatch_group_async(dw_grp, dw_q, ^{
cvt_f16_f32(cp_dv, cp_dv_f16, SEQ*DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, cp_dv, SEQ, cp_xn, SEQ, 1.0f, gr_Wv, DIM);
});
}
// QKV backward (ANE)
io_copy(kern[L].qkvBwd->ioIn, 0, sdpaBwd2[L]->ioOut, 0, 2*DIM, SEQ);
io_copy(kern[L].qkvBwd->ioIn, 2*DIM, kern[L].sdpaBwd1->ioOut, 0, DIM, SEQ);
ane_eval(kern[L].qkvBwd);
io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
// RMSNorm1 backward
memset(dx_rms_scratch, 0, SEQ*DIM*4);
rmsnorm_bwd(dx_rms_scratch, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms_scratch[i] + dx2[i];
}
// Embedding backward
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
// Phase 1: Vectorized embed backward
embed_backward_opt(gembed, dy, input_tokens, DIM, SEQ, embed_tmp);
t_bwd += tb_ms(mach_absolute_time() - t_bwd_start);
steps_batch++;
if (step % 10 == 0 || step == start_step)
printf("step %-4d loss=%.4f\n", step, loss);
// JSON telemetry to stderr
double step_ane = t_ane/steps_batch, step_io = t_io/steps_batch;
double step_cls = t_cls/steps_batch, step_elem = t_elem/steps_batch;
double step_rms = t_rms/steps_batch, step_cbw = t_cblas_wait/steps_batch;
fprintf(stderr, "{\"type\":\"step\",\"step\":%d,\"loss\":%.6f,"
"\"t_ane\":%.3f,\"t_io\":%.3f,\"t_cls\":%.3f,"
"\"t_elem\":%.3f,\"t_rms\":%.3f,\"t_cblas_wait\":%.3f,"
"\"t_bwd\":%.3f,\"t_metal\":%.3f,\"compiles\":%d}\n",
step, loss, step_ane, step_io, step_cls, step_elem, step_rms, step_cbw,
t_bwd/steps_batch, t_metal/steps_batch, g_compile_count);
}
double tms = tb_ms(mach_absolute_time() - tt);
total_train_ms += tms;
total_steps_done += steps_batch;
total_batches++;
// Ensure all async CPU dW blocks (cblas) have finished; Metal command buffers are awaited below
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
// Phase 3: If Metal, wait for GPU then copy gradient accumulators to CPU grads
if (metal_ok) {
// Must wait for all GPU command buffers to complete before reading
if (metal_ctx.lastCmdBuf) {
[metal_ctx.lastCmdBuf waitUntilCompleted];
metal_ctx.lastCmdBuf = nil;
}
for (int L = 0; L < NLAYERS; L++) {
float *gpu_ptrs[7];
float *cpu_ptrs[7];
size_t sizes[7] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4,
W1_SZ*4, W2_SZ*4, W3_SZ*4};
int indices[7] = {MW_Q, MW_K, MW_V, MW_O, MW_1, MW_2, MW_3};
cpu_ptrs[0]=grads[L].Wq; cpu_ptrs[1]=grads[L].Wk; cpu_ptrs[2]=grads[L].Wv;
cpu_ptrs[3]=grads[L].Wo; cpu_ptrs[4]=grads[L].W1; cpu_ptrs[5]=grads[L].W2;
cpu_ptrs[6]=grads[L].W3;
for (int w = 0; w < 7; w++) {
gpu_ptrs[w] = (float*)[metal_ctx.dW_bufs[L][indices[w]] contents];
// Accumulate GPU gradients into CPU accumulators
vDSP_vadd(gpu_ptrs[w], 1, cpu_ptrs[w], 1, cpu_ptrs[w], 1,
(vDSP_Length)(sizes[w]/4));
}
}
}
// Adam update (scale gradients by 1/steps_batch)
float gsc = 1.0f / steps_batch;
adam_t++;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
// Phase 1: NEON Adam
adam_update_opt(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update_opt(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
}
for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
adam_update_opt(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
adam_update_opt(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
printf(" [batch %d: compile=%.0fms train=%.1fms (%.1fms/step) compiles=%d]\n",
steps_batch, cms, tms, tms/steps_batch, g_compile_count);
printf(" fwd: ane=%.1f io=%.1f cls=%.1f elem=%.1f rms=%.1f | bwd=%.1f | cblas_wait=%.1f ms/step\n",
t_ane/steps_batch, t_io/steps_batch, t_cls/steps_batch, t_elem/steps_batch,
t_rms/steps_batch, t_bwd/steps_batch, t_cblas_wait/steps_batch);
// JSON batch telemetry to stderr
{
double bf = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
double bs = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
double ane_f_batch = (bf*2 + bs) * steps_batch;
double ane_tflops = ane_f_batch / (tms * 1e9);
fprintf(stderr, "{\"type\":\"batch\",\"batch\":%d,\"compile_ms\":%.1f,"
"\"train_ms\":%.1f,\"ms_per_step\":%.1f}\n",
steps_batch, cms, tms, tms/steps_batch);
fprintf(stderr, "{\"type\":\"perf\",\"ane_tflops\":%.3f,\"ane_util_pct\":%.2f,"
"\"metal_dw\":%s}\n",
ane_tflops, 100.0*ane_tflops/15.8, metal_ok ? "true" : "false");
}
}
// Efficiency report
double wall = tb_ms(mach_absolute_time() - t_wall_start);
total_compile_ms += cum_compile; total_train_ms += cum_train;
wall += cum_wall; total_steps_done += cum_steps; total_batches += cum_batches;
double fwd_flops = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
double sdpa_flops = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
double cls_flops = 2.0*VOCAB*DIM*SEQ;
double total_flops = (fwd_flops*3 + sdpa_flops + cls_flops*3) * total_steps_done;
double ane_flops = (fwd_flops*2 + sdpa_flops) * total_steps_done;
printf("\n=== Efficiency Report (OPTIMIZED) ===\n");
printf("Total steps: %d\n", total_steps_done);
printf("Wall time: %.0f ms (%.1f s)\n", wall, wall/1000);
printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100*total_compile_ms/wall);
printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100*total_train_ms/wall);
printf("Avg train: %.1f ms/step\n", total_train_ms/total_steps_done);
printf("ANE TFLOPS: %.2f sustained\n", ane_flops / (total_train_ms * 1e9));
printf("Total TFLOPS: %.2f (ANE+%s)\n", total_flops / (total_train_ms * 1e9),
metal_ok ? "GPU" : "CPU");
printf("ANE utilization: %.1f%% of 15.8 TFLOPS\n", 100*ane_flops/(total_train_ms*1e9)/15.8);
printf("Metal GPU dW: %s\n", metal_ok ? "ENABLED" : "disabled");
// Cleanup
for (int L=0; L<NLAYERS; L++) {
free_layer_kernels(&kern[L]);
free_kern(sdpaBwd2[L]);
layer_weights_free(&lw[L]);
layer_adam_free(&la[L]);
layer_acts_free(&acts[L]);
layer_grads_free(&grads[L]);
layer_captures_free(&caps[L]);
layer_fp16_cache_free(&fp16cache[L]);
}
munmap(token_data, data_len);
close(data_fd);
free(rms_final); free(embed); free(grms_final); free(gembed);
adam_free(&arms_final); adam_free(&aembed);
free(dy); free(dffn); free(dh1); free(dh3); free(dx_ffn); free(dx2);
free(do_out_buf); free(dq); free(dk); free(dv); free(dx_attn);
free(x_cur); free(x_final); free(logits); free(dlogits);
free(dx_rms_scratch); free(embed_tmp);
free(capt_dlogits); free(capt_xfinal);
}
return 0;
}