mirror of https://github.com/maderix/ANE.git
[feat] Merge upstream PRs #21, #23, #26: NEON-optimized training (train_opt), double-buffered async ANE training (train_double_buffer), Qwen2.5-0.5B LLM inference (inference/). Added get_path() env var support and SEC_FLAGS to all new targets. Skipped PR #22 (binary blob risk).
parent
99b06838bc
commit
b4d81b71d4
|
|
@ -17,9 +17,17 @@ tiny_train_m1
|
|||
train_large
|
||||
training/train_large
|
||||
training/train_large_ane
|
||||
training/train_opt
|
||||
training/train_double_buffer
|
||||
training/test_*
|
||||
!training/test_*.m
|
||||
|
||||
# Inference binaries
|
||||
inference/qwen_ane
|
||||
|
||||
# Dynamic training binaries
|
||||
training/training_dynamic/train
|
||||
|
||||
# Test/research binaries
|
||||
test_chaining
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
# ANE Probe Results: M4 (macOS 26.3)
|
||||
|
||||
**Machine:** Apple M4 (10 cores), 32GB RAM, macOS 26.3
|
||||
**Date:** 2026-03-03
|
||||
**ANE Family:** H16 (same as M5 results in `training/m5result.md`)
|
||||
|
||||
## Key Discovery: Compile and Eval Run in Parallel
|
||||
|
||||
**This was not known before.** The M5 probes tested compile and eval sequentially.
|
||||
We tested with GCD `dispatch_async` and found they fully overlap.
|
||||
|
||||
### probe_v2.m Results
|
||||
|
||||
#### TEST 1: Pure Eval Throughput
|
||||
```
|
||||
Conv 128x128, spatial=64
|
||||
1000 evals: 189.1ms total, 0.189ms/eval
|
||||
11.09 GFLOPS sustained
|
||||
```
|
||||
|
||||
#### TEST 2: Ping-pong (Two Pre-compiled Models)
|
||||
```
|
||||
500 ping-pong pairs: 207.4ms (0.415ms/pair, 0.207ms/eval)
|
||||
```
|
||||
Near-zero overhead switching between two loaded models.
|
||||
|
||||
#### TEST 3: Sequential Compile (20 Models)
|
||||
```
|
||||
All 20 models compiled and verified ✓
|
||||
Compile time: ~23-29ms each (consistent, no degradation)
|
||||
All 20 models correct with different scale factors
|
||||
```
|
||||
|
||||
#### TEST 4: Background Compile Overlap ⭐
|
||||
```
|
||||
Background compile: 26.8ms
|
||||
Foreground evals during compile: 119 (26.8ms total)
|
||||
Overlap: YES — compile and eval CAN run in parallel!
|
||||
Background model verified correct ✓
|
||||
```
|
||||
|
||||
### Summary
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Compile time | ~25ms per kernel set |
|
||||
| Eval time | 0.189ms per eval |
|
||||
| Compile:eval ratio | ~130:1 |
|
||||
| Parallel compile+eval | **YES** |
|
||||
| Max simultaneous models | 20+ |
|
||||
| Ping-pong overhead | +10% vs single model |
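
The derived rows in the table follow from the raw probe numbers. A minimal sketch of that arithmetic, using only figures quoted in the test output above (the variable names are illustrative, not taken from `probe_v2.m`):

```python
# Figures copied from the probe output above.
compile_ms = 25.0         # ~25 ms per kernel set (TEST 3)
single_eval_ms = 0.189    # TEST 1, one model
pingpong_eval_ms = 0.207  # TEST 2, per eval when alternating two models

compile_to_eval = compile_ms / single_eval_ms
pingpong_overhead = pingpong_eval_ms / single_eval_ms - 1.0
overlap_eval_ms = 26.8 / 119  # TEST 4: 119 evals finished during a 26.8 ms compile

print(f"compile:eval ratio  ~{compile_to_eval:.0f}:1")
print(f"ping-pong overhead  +{pingpong_overhead:.1%}")
print(f"eval under compile  {overlap_eval_ms:.3f} ms")
```

The last figure is the per-eval latency while a compile runs in the background; it sits between the single-model and ping-pong numbers, which is what the overlap claim rests on.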
|
||||
|
||||
## Peak ANE Throughput (inmem_peak)
|
||||
|
||||
```
|
||||
Config W(MB) GFLOP ms/eval TFLOPS
|
||||
96x conv 512ch sp64 48.0 3.22 0.429 ms 7.50
|
||||
128x conv 512ch sp64 64.0 4.29 0.589 ms 7.30
|
||||
256x conv 256ch sp64 32.0 2.15 0.380 ms 5.65
|
||||
64x conv 512ch sp64 32.0 2.15 0.395 ms 5.43
|
||||
```
|
||||
|
||||
Peak: **7.50 TFLOPS** (47% of 15.8 TFLOPS theoretical).
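
The table columns can be cross-checked from the config names. The sketch below assumes each layer is a square 1×1 convolution with FP16 weights and counts 2 FLOPs per multiply-accumulate; that counting rule is an assumption, not something stated by `inmem_peak`, but it reproduces the table within rounding.

```python
def conv_stack_stats(layers, channels, spatial, ms_per_eval, bytes_per_weight=2):
    """Return (weight MB, GFLOP per eval, TFLOPS) for a stack of 1x1 convs."""
    weights = layers * channels * channels          # C_in == C_out == channels
    weight_mb = weights * bytes_per_weight / 2**20  # FP16 assumed: 2 bytes each
    gflop = 2 * weights * spatial / 1e9             # 2 FLOPs per multiply-accumulate
    tflops = gflop / ms_per_eval                    # GFLOP per ms equals TFLOP per s
    return weight_mb, gflop, tflops


rows = [
    (96, 512, 64, 0.429),
    (128, 512, 64, 0.589),
    (256, 256, 64, 0.380),
    (64, 512, 64, 0.395),
]
for layers, ch, sp, ms in rows:
    mb, gflop, tflops = conv_stack_stats(layers, ch, sp, ms)
    print(f"{layers}x conv {ch}ch sp{sp}: {mb:.1f} MB, {gflop:.2f} GFLOP, {tflops:.2f} TFLOPS")
```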
|
||||
|
||||
## Implications for Training
|
||||
|
||||
### Before (train_large.m)
|
||||
- Synchronous compile: **88.6% of wall time is compilation**
|
||||
- 55ms compile per batch, 0.54ms actual training
|
||||
- Training throughput limited by compiler, not by ANE
|
||||
|
||||
### After (train_double_buffer.m)
|
||||
- Async double-buffered compile: **0% compile stall**
|
||||
- Background compile happens during forward/backward passes
|
||||
- ~130 eval steps fit in one compile window
|
||||
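- Weight updates lag by one batch: the kernels for batch N+1 are compiled while batch N is still running, so they cannot include batch N's gradient step (the same stale-weights trade-off accepted in pipelined and asynchronous distributed training)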
|
||||
- Training throughput limited only by ANE eval speed
|
||||
|
||||
### Architecture
|
||||
```
|
||||
Time →
|
||||
Active kernels: [=== eval batch N ===][=== eval batch N+1 ===][=== eval batch N+2 ===]
|
||||
Background: [compile N+1 weights ][compile N+2 weights ][compile N+3 weights ]
|
||||
↑ ↑ ↑
|
||||
swap ready swap ready swap ready
|
||||
```
|
||||
|
||||
Two kernel sets (A and B) alternate between active evaluation and background compilation.
|
||||
When the background compile finishes, pointers swap atomically at the batch boundary.
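
The scheduling idea can be sketched without any ANE code. The example below is a toy model in Python, not the logic of `train_double_buffer.m`: `compile_kernels` is a hypothetical stand-in that just sleeps, and a thread pool plays the role of the GCD background queue. It shows why each batch runs on weights that are one update behind.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def compile_kernels(weights_version):
    """Stand-in for an ANE compile: sleeps briefly, returns a 'kernel set'."""
    time.sleep(0.005)
    return {"weights_version": weights_version}


def train(num_batches, evals_per_batch=3):
    used = []                               # weights version each batch ran on
    updates_done = 0                        # optimizer steps applied so far
    active = compile_kernels(updates_done)  # first set is compiled up front
    with ThreadPoolExecutor(max_workers=1) as pool:
        for _ in range(num_batches):
            # Snapshot the current weights and compile them in the background.
            pending = pool.submit(compile_kernels, updates_done)
            for _ in range(evals_per_batch):
                pass                        # forward/backward evals on `active` go here
            used.append(active["weights_version"])
            updates_done += 1               # this batch's gradient step
            # Batch boundary: swap in the set that finished compiling.
            active = pending.result()
    return used


print(train(4))  # [0, 0, 1, 2]
```

After the first batch, batch N runs on version N-1, which is the one-batch delay described above. In this toy the swap blocks on `pending.result()`; with roughly 130 evals per compile window the real swap would normally find the compile already finished.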
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
# ANE Inference — Full LLM on Apple Neural Engine
|
||||
|
||||
First complete LLM inference running directly on Apple's Neural Engine via reverse-engineered `_ANEClient` APIs. No CoreML. No Xcode compiler dependency at runtime. Token-for-token match with PyTorch.
|
||||
|
||||
Built on top of the [maderix/ANE](https://github.com/maderix/ANE) training runtime.
|
||||
|
||||
## What This Does
|
||||
|
||||
Runs **Qwen2.5-0.5B-Instruct** (24 transformer layers, 494M parameters) entirely on the ANE:
|
||||
|
||||
- **169 ANE kernels** compiled at startup via `_ANEInMemoryModel`
|
||||
- **82 tokens/sec** decode on M4 Pro
|
||||
- **Zero GPU usage** — runs on 16 dedicated neural cores
|
||||
- **Correct output** — matches PyTorch reference token-for-token
|
||||
|
||||
All linear projections (Q, K, V, O, gate, up, down × 24 layers + chunked LM head) compile as baked-weight 1×1 convolution kernels on ANE. Element-wise ops (RMSNorm, RoPE, softmax, SiLU, attention scores) run on CPU via Accelerate BLAS.
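
A linear projection maps onto a 1×1 convolution because a 1×1 kernel applies the same `[out_ch, in_ch]` matrix at every spatial position. The sketch below checks that equivalence in plain Python on toy sizes; `linear` and `conv1x1` are illustrative helpers, not functions from this repository.

```python
def linear(W, x):
    """y = W @ x for W[out_ch][in_ch] and x[in_ch]."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def conv1x1(W, x_cs):
    """1x1 convolution: x_cs is [in_ch][spatial], result is [out_ch][spatial]."""
    in_ch, spatial = len(x_cs), len(x_cs[0])
    return [[sum(W[o][c] * x_cs[c][s] for c in range(in_ch))
             for s in range(spatial)]
            for o in range(len(W))]


W = [[1.0, 2.0, 3.0],
     [0.5, -1.0, 0.0]]
x = [2.0, 0.0, 1.0]

y_linear = linear(W, x)
# Single-token decode corresponds to spatial = 1: one column per channel.
y_conv = [row[0] for row in conv1x1(W, [[v] for v in x])]
print(y_linear, y_conv)  # [5.0, 1.0] [5.0, 1.0]
```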
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Token → Embedding (CPU) → 24× Transformer Layer → LM Head (CPU) → Next Token
|
||||
│
|
||||
├── RMSNorm (CPU)
|
||||
├── Q/K/V Projection (ANE conv kernel)
|
||||
├── RoPE (CPU, rotate_half)
|
||||
├── GQA Attention (CPU, 14 heads / 2 KV heads)
|
||||
├── O Projection (ANE conv kernel)
|
||||
├── Residual (CPU)
|
||||
├── RMSNorm (CPU)
|
||||
├── Gate/Up Projection (ANE conv kernel)
|
||||
├── SiLU + elementwise mul (CPU)
|
||||
├── Down Projection (ANE conv kernel)
|
||||
└── Residual (CPU)
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Convert weights from HuggingFace safetensors to flat binary
|
||||
pip install safetensors torch transformers
|
||||
python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct qwen05b.bin
|
||||
|
||||
# 2. Build
|
||||
xcrun clang -O2 -framework Foundation -framework IOSurface \
|
||||
-framework CoreML -framework Accelerate -ldl -lobjc \
|
||||
-o qwen_ane main.m
|
||||
|
||||
# 3. Run (pass space-separated token IDs)
|
||||
./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20
|
||||
|
||||
# 4. With tokenizer (requires transformers)
|
||||
python3 run.py "Say hello in one word."
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
```
|
||||
=== Qwen2.5-0.5B ANE Inference ===
|
||||
|
||||
Loading weights...
|
||||
Config: dim=896 hidden=4864 layers=24 heads=14 kv_heads=2 vocab=151936
|
||||
Compiling ANE kernels (169 total)...
|
||||
Compile time: 5.1s
|
||||
|
||||
Prompt: 28 tokens, generating up to 10
|
||||
Prefill: 64.2 t/s (28 tokens)
|
||||
OUT: 9707 13 151645
|
||||
Decode: 82.4 t/s (2 tokens)
|
||||
|
||||
→ "Hello." (matches PyTorch exactly)
|
||||
```
|
||||
|
||||
## Files
|
||||
|
||||
| File | What |
|
||||
|------|------|
|
||||
| `qwen_ane_infer.h` | Full 24-layer transformer forward pass, ANE kernel compilation, KV cache |
|
||||
| `main.m` | Weight loader, token I/O, main generation loop |
|
||||
| `convert_weights.py` | HuggingFace safetensors → flat f32 binary (includes Q/K/V biases) |
|
||||
| `run.py` | Python wrapper with HuggingFace tokenizer |
|
||||
|
||||
## Model Support
|
||||
|
||||
Currently implements **Qwen2.5** architecture:
|
||||
- GQA attention (grouped-query, `n_heads` ≠ `n_kv_heads`)
|
||||
- `rotate_half` RoPE (not interleaved pairs)
|
||||
- SwiGLU FFN (gate + up + silu + down)
|
||||
- Q/K/V bias (Qwen-specific)
|
||||
- Tied word embeddings (lm_head = embed)
|
||||
- Chunked LM head (vocab > 65536 exceeds ANE max dim)
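
For the GQA item above, the head bookkeeping with this model's sizes looks as follows. This is a small sketch that mirrors the integer division used for `kv_h` in `qwen_forward`; the Python names are illustrative.

```python
N_HEADS, N_KV_HEADS, HEAD_DIM = 14, 2, 64

GQA_FACTOR = N_HEADS // N_KV_HEADS  # query heads served by each KV head
kv_head_for = [h // GQA_FACTOR for h in range(N_HEADS)]

q_dim = N_HEADS * HEAD_DIM          # width of the Q projection output
kv_dim = N_KV_HEADS * HEAD_DIM      # width of the K and V projection outputs

print(GQA_FACTOR, kv_head_for)
print(q_dim, kv_dim)
```

Query heads 0 to 6 read KV head 0 and heads 7 to 13 read KV head 1, so the KV cache stores 128 floats per position per layer instead of 896.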
|
||||
|
||||
Adapting to other architectures (LLaMA, Gemma, Mistral) requires:
|
||||
1. Adjusting the config constants in `qwen_ane_infer.h`
|
||||
2. Updating `convert_weights.py` for the weight naming scheme
|
||||
3. Removing Q/K/V bias handling if the model doesn't have them
|
||||
4. Switching RoPE to interleaved pairs if needed
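
Step 4 is the easiest one to get subtly wrong, so here is a sketch of both RoPE layouts side by side. `rope_rotate_half` follows the pairing used by `qwen_rope` (element `i` with element `i + head_dim/2`); `rope_interleaved` pairs adjacent elements and is written here only for comparison, as an assumption about what another architecture might need.

```python
import math


def rope_rotate_half(x, pos, theta=1_000_000.0):
    """Pairs x[i] with x[i + d/2], as in qwen_rope."""
    d, half = len(x), len(x) // 2
    out = [0.0] * d
    for i in range(half):
        angle = pos / theta ** (2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i + half] * c + x[i] * s
    return out


def rope_interleaved(x, pos, theta=1_000_000.0):
    """Pairs x[2i] with x[2i + 1] (comparison only)."""
    d = len(x)
    out = [0.0] * d
    for i in range(d // 2):
        angle = pos / theta ** (2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = b * c + a * s
    return out


x = [1.0, 2.0, 3.0, 4.0]
print(rope_rotate_half(x, pos=1))
print(rope_interleaved(x, pos=1))
```

The two layouts apply the same rotations to differently paired elements, so feeding weights converted for one layout into the other produces wrong attention scores even though vector norms are preserved.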
|
||||
|
||||
## Requirements
|
||||
|
||||
- macOS 15+ on Apple Silicon (M1/M2/M3/M4)
|
||||
- Xcode Command Line Tools (for `xcrun clang`)
|
||||
- Python 3.9+ with `safetensors`, `torch`, `transformers` (for weight conversion)
|
||||
|
||||
## Known Limitations
|
||||
|
||||
- **CPU projections only** — ANE baked-weight conv kernels compile successfully but produce incorrect output (FP16 weight blob format mismatch). The `USE_ANE_PROJECTIONS` toggle exists but defaults to 0 (CPU via Accelerate BLAS). Fixing this would push decode speed from 82 t/s to 120+ t/s.
|
||||
- **No persistent server** — each invocation recompiles 169 kernels (~5s). A server mode that compiles once and serves via HTTP would eliminate this overhead.
|
||||
- **Single model** — hardcoded for Qwen2.5-0.5B. Needs parameterization for other sizes.
|
||||
- **f32 weights** — 1.9GB on disk. FP16 or quantized weight support would halve this.
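
The 1.9GB figure can be reproduced from the layout that `convert_weights.py` writes: a 7-int header, the embedding, 24 layers of weights and biases, and the final norm. The sketch below recomputes the expected file size from the config values; it reads no file and only uses sizes stated in this document.

```python
import struct

dim, hidden, n_layers, n_heads, n_kv_heads, vocab, max_seq = (
    896, 4864, 24, 14, 2, 151936, 512)
head_dim = dim // n_heads                   # 64
q_dim = n_heads * head_dim                  # 896
kv_dim = n_kv_heads * head_dim              # 128

# Same field order as the header written by convert_weights.py.
header = struct.pack("7i", dim, hidden, n_layers, n_heads, n_kv_heads, vocab, max_seq)

per_layer = (
    dim                     # attention RMSNorm
    + q_dim * dim           # wq
    + 2 * kv_dim * dim      # wk, wv
    + dim * q_dim           # wo
    + q_dim + 2 * kv_dim    # Q/K/V biases
    + dim                   # FFN RMSNorm
    + 3 * hidden * dim      # gate, up, down
)
n_params = vocab * dim + n_layers * per_layer + dim   # embed + layers + final norm
file_bytes = len(header) + 4 * n_params               # f32 is 4 bytes per value

print(f"{n_params:,} params, {file_bytes / 1e9:.2f} GB")
```

Comparing `file_bytes` with the size of the generated `.bin` is a cheap sanity check that the converter and loader agree on the layout.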
|
||||
|
||||
## How It Works
|
||||
|
||||
The key insight from maderix's reverse engineering: the ANE executes compiled MIL (Machine Learning Intermediate Language) programs as atomic graph operations. Each linear projection becomes a MIL program with baked FP16 weights, compiled in-memory via `_ANEInMemoryModel`, and executed through IOSurface-based zero-copy I/O.
|
||||
|
||||
We chain 169 of these atomic operations (7 per transformer layer + 16 LM head chunks) with CPU-side element-wise ops in between. The ANE handles the compute-heavy matmuls; the CPU handles the memory-bound operations (attention scores, softmax, RoPE).
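
The chunked LM head works because an argmax over the full vocabulary equals the best of the per-chunk results. A minimal sketch on toy sizes follows; greedy argmax decoding and the helper name `chunked_argmax` are assumptions for illustration, not taken from `qwen_ane_infer.h`.

```python
def chunked_argmax(embed_rows, hidden, n_chunks):
    """Greedy next-token pick, computing logits one vocab chunk at a time."""
    vocab = len(embed_rows)
    assert vocab % n_chunks == 0, "vocab must split evenly into chunks"
    chunk = vocab // n_chunks
    best_id, best_logit = -1, float("-inf")
    for c in range(n_chunks):
        start = c * chunk
        # One chunk is one small matmul: rows [start, start + chunk) of the tied embedding.
        logits = [sum(w * h for w, h in zip(row, hidden))
                  for row in embed_rows[start:start + chunk]]
        for j, logit in enumerate(logits):
            if logit > best_logit:
                best_id, best_logit = start + j, logit
    return best_id


# Toy "embedding" with 8 tokens and dim 2.
embed = [[0, 0], [1, 0], [0, 1], [3, -1], [2, 2], [-1, 0], [0.5, 0.5], [1, 1]]
print(chunked_argmax(embed, [1.0, 2.0], n_chunks=4))  # 4

# The real sizes: 16 chunks of 9496 rows, each below the 65536 limit.
print(151936 // 16, 151936 % 16)  # 9496 0
```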
|
||||
|
||||
## License
|
||||
|
||||
Same as maderix/ANE — research and educational use.
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
#!/usr/bin/env python3
"""Convert Qwen2.5-0.5B-Instruct safetensors → flat binary for ANE inference.

Output format: config header (7 ints) + all weights in f32, layer by layer.
Matches the layout expected by qwen_ane_infer.h.

Usage:
    python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct /path/to/output.bin
"""

import struct
import sys
import numpy as np
from pathlib import Path
from safetensors import safe_open

def convert(model_dir: str, output_path: str):
    model_dir = Path(model_dir)

    # Load safetensors
    st_files = list(model_dir.glob("*.safetensors"))
    if not st_files:
        print(f"No safetensors files in {model_dir}")
        sys.exit(1)

    tensors = {}
    for f in st_files:
        with safe_open(str(f), framework="pt") as sf:
            for key in sf.keys():
                tensors[key] = sf.get_tensor(key).float().numpy()

    print(f"Loaded {len(tensors)} tensors from {len(st_files)} files")

    # Qwen2.5-0.5B config
    dim = 896
    hidden = 4864
    n_layers = 24
    n_heads = 14
    n_kv_heads = 2
    vocab_size = 151936
    max_seq = 512

    with open(output_path, "wb") as f:
        # Config header: 7 x int32
        f.write(struct.pack("iiiiiii",
                            dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq))

        # Embedding [vocab, dim]
        emb = tensors["model.embed_tokens.weight"].astype(np.float32)
        print(f"embed: {emb.shape}")
        f.write(emb.tobytes())

        # Per-layer weights
        for l in range(n_layers):
            prefix = f"model.layers.{l}"

            # Attention norm
            rms_att = tensors[f"{prefix}.input_layernorm.weight"].astype(np.float32)
            f.write(rms_att.tobytes())

            # Q, K, V, O projections
            wq = tensors[f"{prefix}.self_attn.q_proj.weight"].astype(np.float32)
            wk = tensors[f"{prefix}.self_attn.k_proj.weight"].astype(np.float32)
            wv = tensors[f"{prefix}.self_attn.v_proj.weight"].astype(np.float32)
            wo = tensors[f"{prefix}.self_attn.o_proj.weight"].astype(np.float32)
            f.write(wq.tobytes())
            f.write(wk.tobytes())
            f.write(wv.tobytes())
            f.write(wo.tobytes())

            # Q/K/V biases (Qwen has them; zeros are written if a tensor is missing)
            qb = tensors.get(f"{prefix}.self_attn.q_proj.bias")
            kb = tensors.get(f"{prefix}.self_attn.k_proj.bias")
            vb = tensors.get(f"{prefix}.self_attn.v_proj.bias")
            f.write((qb if qb is not None else np.zeros(wq.shape[0])).astype(np.float32).tobytes())
            f.write((kb if kb is not None else np.zeros(wk.shape[0])).astype(np.float32).tobytes())
            f.write((vb if vb is not None else np.zeros(wv.shape[0])).astype(np.float32).tobytes())

            # FFN norm
            rms_ffn = tensors[f"{prefix}.post_attention_layernorm.weight"].astype(np.float32)
            f.write(rms_ffn.tobytes())

            # FFN: gate, up, down
            w_gate = tensors[f"{prefix}.mlp.gate_proj.weight"].astype(np.float32)
            w_up = tensors[f"{prefix}.mlp.up_proj.weight"].astype(np.float32)
            w_down = tensors[f"{prefix}.mlp.down_proj.weight"].astype(np.float32)
            f.write(w_gate.tobytes())
            f.write(w_up.tobytes())
            f.write(w_down.tobytes())

            print(f"  Layer {l}: Q{wq.shape} K{wk.shape} V{wv.shape} O{wo.shape} "
                  f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape}")

        # Final norm
        rms_final = tensors["model.norm.weight"].astype(np.float32)
        f.write(rms_final.tobytes())

    # Measure the size only after the file has been closed and flushed
    size_mb = Path(output_path).stat().st_size / 1024 / 1024
    print(f"\nWritten: {output_path} ({size_mb:.0f} MB)")


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python3 convert_weights.py <model_dir> <output.bin>")
        sys.exit(1)
    convert(sys.argv[1], sys.argv[2])
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
// main.m — Qwen2.5-0.5B inference on Apple Neural Engine
|
||||
// Compiles ANE kernels for all linear projections, runs autoregressive decode.
|
||||
//
|
||||
// Build:
|
||||
// xcrun clang -O2 -framework Foundation -framework IOSurface \
|
||||
// -framework CoreML -framework Accelerate -ldl -lobjc \
|
||||
// -o qwen_ane main.m
|
||||
//
|
||||
// Run:
|
||||
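// ./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20
// (argv[2] is a space-separated list of token IDs, argv[3] the max tokens to generate)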
|
||||
//
|
||||
#import <Foundation/Foundation.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include "qwen_ane_infer.h"
|
||||
|
||||
int g_fp16_io = 0;
|
||||
static QwenModel g_model;
|
||||
|
||||
static int load_weights(const char *path) {
|
||||
FILE *f = fopen(path, "rb");
|
||||
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
|
||||
|
||||
// Read config header
|
||||
int config[7];
|
||||
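    if (fread(config, sizeof(int), 7, f) != 7) {
        fprintf(stderr, "Truncated config header in %s\n", path);
        fclose(f);
        return -1;
    }
    int dim = config[0], hidden = config[1], n_layers = config[2];
    int n_heads = config[3], n_kv_heads = config[4], vocab = config[5];
    printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d\n",
           dim, hidden, n_layers, n_heads, n_kv_heads, vocab);

    // QwenModel and qwen_forward use compile-time sizes, so reject files that disagree.
    if (dim != QWEN_DIM || hidden != QWEN_HIDDEN || n_layers != QWEN_LAYERS ||
        n_heads != QWEN_HEADS || n_kv_heads != QWEN_KV_HEADS || vocab != QWEN_VOCAB) {
        fprintf(stderr, "Config in %s does not match the compiled-in Qwen2.5-0.5B sizes\n", path);
        fclose(f);
        return -1;
    }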
|
||||
|
||||
int q_dim = n_heads * QWEN_HEAD_DIM;
|
||||
int kv_dim = n_kv_heads * QWEN_HEAD_DIM;
|
||||
|
||||
// Embedding
|
||||
g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float));
|
||||
fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f);
|
||||
|
||||
// Per-layer
|
||||
for (int l = 0; l < n_layers; l++) {
|
||||
g_model.rms_att[l] = (float*)malloc(dim * sizeof(float));
|
||||
fread(g_model.rms_att[l], sizeof(float), dim, f);
|
||||
|
||||
g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
|
||||
fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f);
|
||||
g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
|
||||
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
|
||||
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
|
||||
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
|
||||
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); // o_proj is [dim, q_dim]
|
||||
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
|
||||
|
||||
// Q/K/V biases
|
||||
g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float));
|
||||
g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
||||
g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
||||
fread(g_model.q_bias[l], sizeof(float), q_dim, f);
|
||||
fread(g_model.k_bias[l], sizeof(float), kv_dim, f);
|
||||
fread(g_model.v_bias[l], sizeof(float), kv_dim, f);
|
||||
|
||||
g_model.rms_ffn[l] = (float*)malloc(dim * sizeof(float));
|
||||
fread(g_model.rms_ffn[l], sizeof(float), dim, f);
|
||||
|
||||
g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
|
||||
fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f);
|
||||
g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
|
||||
fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f);
|
||||
g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float));
|
||||
fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f);
|
||||
}
|
||||
|
||||
g_model.rms_final = (float*)malloc(dim * sizeof(float));
|
||||
fread(g_model.rms_final, sizeof(float), dim, f);
|
||||
|
||||
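    // Query the file position before closing; ftell() on a closed FILE* is undefined.
    long bytes_read = ftell(f);
    fclose(f);
    printf("Weights loaded (%.0f MB)\n",
           (double)bytes_read / 1024 / 1024);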
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
@autoreleasepool {
|
||||
if (argc < 3) {
|
||||
            fprintf(stderr, "Usage: %s <weights.bin> \"<space-separated token IDs>\" [max_gen_tokens]\n", argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("=== Qwen2.5-0.5B ANE Inference ===\n\n");
|
||||
|
||||
// Load weights
|
||||
printf("Loading weights...\n");
|
||||
if (load_weights(argv[1]) != 0) return 1;
|
||||
|
||||
// Allocate buffers
|
||||
qwen_alloc(&g_model);
|
||||
|
||||
// Compile ANE kernels
|
||||
printf("Compiling ANE kernels (169 total)...\n");
|
||||
struct timespec t0, t1;
|
||||
clock_gettime(CLOCK_MONOTONIC, &t0);
|
||||
qwen_compile_kernels(&g_model);
|
||||
clock_gettime(CLOCK_MONOTONIC, &t1);
|
||||
double compile_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9;
|
||||
printf("Compile time: %.1fs\n\n", compile_sec);
|
||||
|
||||
// Parse token IDs from argv[2] (space-separated)
|
||||
// argv[3] = max generation tokens
|
||||
int max_gen = 50;
|
||||
if (argc >= 4) max_gen = atoi(argv[3]);
|
||||
|
||||
// Parse input token IDs
|
||||
        // The KV cache holds QWEN_MAX_SEQ positions, so cap the prompt at that length.
        int prompt_ids[QWEN_MAX_SEQ];
        int n_prompt = 0;
        char *tok_str = strdup(argv[2]);
        char *saveptr;
        char *p = strtok_r(tok_str, " ", &saveptr);
        while (p && n_prompt < QWEN_MAX_SEQ) {
|
||||
prompt_ids[n_prompt++] = atoi(p);
|
||||
p = strtok_r(NULL, " ", &saveptr);
|
||||
}
|
||||
free(tok_str);
|
||||
printf("Prompt: %d tokens, generating up to %d\n", n_prompt, max_gen);
|
||||
|
||||
clock_gettime(CLOCK_MONOTONIC, &t0);
|
||||
|
||||
// Prefill: feed all prompt tokens
|
||||
int next = 0;
|
||||
for (int i = 0; i < n_prompt; i++) {
|
||||
next = qwen_forward(&g_model, prompt_ids[i]);
|
||||
}
|
||||
|
||||
struct timespec t_prefill;
|
||||
clock_gettime(CLOCK_MONOTONIC, &t_prefill);
|
||||
double prefill_sec = (t_prefill.tv_sec - t0.tv_sec) + (t_prefill.tv_nsec - t0.tv_nsec) / 1e9;
|
||||
printf("Prefill: %d tokens in %.2fs (%.1f t/s)\n", n_prompt, prefill_sec, n_prompt / prefill_sec);
|
||||
|
||||
// Generate
|
||||
int eos = 151645; // <|im_end|>
|
||||
int eos2 = 151643; // <|endoftext|>
|
||||
printf("OUT:");
|
||||
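        // Stop at max_gen tokens or when the KV cache (QWEN_MAX_SEQ positions) is full.
        for (int i = 0; i < max_gen && g_model.pos < QWEN_MAX_SEQ; i++) {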
|
||||
printf(" %d", next);
|
||||
fflush(stdout);
|
||||
if (next == eos || next == eos2) break;
|
||||
next = qwen_forward(&g_model, next);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
clock_gettime(CLOCK_MONOTONIC, &t1);
|
||||
double gen_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9;
|
||||
int total_tokens = g_model.pos;
|
||||
int gen_tokens = total_tokens - n_prompt;
|
||||
double decode_sec = gen_sec - prefill_sec;
|
||||
printf("\nTotal: %d tokens in %.2fs\n", total_tokens, gen_sec);
|
||||
printf("Prefill: %.1f t/s (%d tokens)\n", n_prompt / prefill_sec, n_prompt);
|
||||
printf("Decode: %.1f t/s (%d tokens)\n",
|
||||
decode_sec > 0 ? gen_tokens / decode_sec : 0, gen_tokens);
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,435 @@
|
|||
// qwen_ane_infer.h — Qwen2.5-0.5B inference on Apple Neural Engine
|
||||
// Linear projections on ANE (baked-weight conv kernels), CPU for element-wise ops.
|
||||
// Based on maderix/ANE runtime + MIL generation.
|
||||
#pragma once
|
||||
|
||||
#include "../training/ane_runtime.h"
|
||||
#include "../training/ane_mil_gen.h"
|
||||
|
||||
// Compile a matmul kernel: W[out_ch, in_ch] @ x[in_ch] → y[out_ch]
|
||||
// Uses the two-input matmul MIL variant (weights passed as input, not baked)
|
||||
static ANEKernel *compile_matmul_kernel(int in_ch, int out_ch) {
|
||||
NSString *mil = mil_gen_matmul(in_ch, out_ch, 1);
|
||||
size_t inputSizes[2] = {(size_t)in_ch * 1 * 4, (size_t)out_ch * in_ch * 4};
|
||||
size_t outBytes = (size_t)out_ch * 1 * 4;
|
||||
return ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], nil, 2, inputSizes, 1, &outBytes);
|
||||
}
|
||||
|
||||
// Compile a baked-weight conv kernel (from model.h)
|
||||
static ANEKernel *compile_conv_kernel(const float *weights, int in_ch, int out_ch, int spatial) {
|
||||
NSData *wb = mil_build_weight_blob(weights, out_ch, in_ch);
|
||||
NSString *mil = mil_gen_conv(in_ch, out_ch, spatial);
|
||||
size_t inBytes = (size_t)in_ch * spatial * 4;
|
||||
size_t outBytes = (size_t)out_ch * spatial * 4;
|
||||
return ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], wb, 1, &inBytes, 1, &outBytes);
|
||||
}
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
|
||||
// Qwen2.5-0.5B-Instruct architecture
|
||||
#define QWEN_DIM 896
|
||||
#define QWEN_HIDDEN 4864
|
||||
#define QWEN_LAYERS 24
|
||||
#define QWEN_HEADS 14
|
||||
#define QWEN_KV_HEADS 2
|
||||
#define QWEN_HEAD_DIM 64
|
||||
#define QWEN_VOCAB 151936
|
||||
#define QWEN_RMS_EPS 1e-6f
|
||||
#define QWEN_ROPE_THETA 1000000.0f
|
||||
#define QWEN_MAX_SEQ 512
|
||||
|
||||
// GQA: each KV head serves (HEADS / KV_HEADS) query heads
|
||||
#define QWEN_GQA_FACTOR (QWEN_HEADS / QWEN_KV_HEADS)
|
||||
|
||||
// Sizes for GQA projections
|
||||
#define QWEN_Q_DIM (QWEN_HEADS * QWEN_HEAD_DIM) // 896
|
||||
#define QWEN_KV_DIM (QWEN_KV_HEADS * QWEN_HEAD_DIM) // 128
|
||||
|
||||
typedef struct {
|
||||
// Weights (f32)
|
||||
float *embed; // [vocab, dim]
|
||||
float *rms_att[QWEN_LAYERS]; // [dim]
|
||||
float *wq[QWEN_LAYERS]; // [q_dim, dim]
|
||||
float *wk[QWEN_LAYERS]; // [kv_dim, dim]
|
||||
float *wv[QWEN_LAYERS]; // [kv_dim, dim]
|
||||
float *wo[QWEN_LAYERS]; // [dim, q_dim]
|
||||
float *rms_ffn[QWEN_LAYERS]; // [dim]
|
||||
float *w_gate[QWEN_LAYERS]; // [hidden, dim]
|
||||
float *w_up[QWEN_LAYERS]; // [hidden, dim]
|
||||
float *w_down[QWEN_LAYERS]; // [dim, hidden]
|
||||
float *rms_final; // [dim]
|
||||
// wcls = embed (tied)
|
||||
|
||||
// ANE kernels (one per linear projection per layer)
|
||||
ANEKernel *k_q[QWEN_LAYERS];
|
||||
ANEKernel *k_k[QWEN_LAYERS];
|
||||
ANEKernel *k_v[QWEN_LAYERS];
|
||||
ANEKernel *k_o[QWEN_LAYERS];
|
||||
ANEKernel *k_gate[QWEN_LAYERS];
|
||||
ANEKernel *k_up[QWEN_LAYERS];
|
||||
ANEKernel *k_down[QWEN_LAYERS];
|
||||
// LM head chunked: vocab too large for single ANE kernel (max 65536)
|
||||
#define QWEN_LM_CHUNKS 16
|
||||
#define QWEN_LM_CHUNK_SIZE 9496 // 151936 / 16
|
||||
ANEKernel *k_lmhead[QWEN_LM_CHUNKS];
|
||||
|
||||
// Q/K/V biases per layer
|
||||
float *q_bias[QWEN_LAYERS]; // [q_dim]
|
||||
float *k_bias[QWEN_LAYERS]; // [kv_dim]
|
||||
float *v_bias[QWEN_LAYERS]; // [kv_dim]
|
||||
|
||||
// KV cache [layer][kv_heads * head_dim * max_seq]
|
||||
float *kv_cache_k[QWEN_LAYERS];
|
||||
float *kv_cache_v[QWEN_LAYERS];
|
||||
int pos; // current position in sequence
|
||||
|
||||
// Scratch buffers
|
||||
float *x; // [dim]
|
||||
float *xb; // [dim]
|
||||
float *q; // [q_dim]
|
||||
float *k; // [kv_dim]
|
||||
float *v; // [kv_dim]
|
||||
float *att; // [heads * max_seq]
|
||||
float *hb; // [hidden]
|
||||
float *hb2; // [hidden]
|
||||
float *logits; // [vocab]
|
||||
} QwenModel;
|
||||
|
||||
// ── CPU ops ──────────────────────────────────────────────────────────
|
||||
|
||||
static void qwen_rmsnorm(float *out, const float *x, const float *w, int D) {
|
||||
float ss = 0;
|
||||
for (int i = 0; i < D; i++) ss += x[i] * x[i];
|
||||
ss = 1.0f / sqrtf(ss / D + QWEN_RMS_EPS);
|
||||
for (int i = 0; i < D; i++) out[i] = x[i] * ss * w[i];
|
||||
}
|
||||
|
||||
static void qwen_rope(float *q, float *k, int pos, int n_q_heads, int n_kv_heads, int head_dim) {
|
||||
// Qwen uses rotate_half RoPE (NOT interleaved pairs):
|
||||
// rotate_half(x) = [-x[dim/2:], x[:dim/2]]
|
||||
// q_embed = q * cos + rotate_half(q) * sin
|
||||
// cos/sin have shape [head_dim/2] and are applied to both halves
|
||||
int half = head_dim / 2;
|
||||
|
||||
// Precompute cos/sin for this position (head_dim/2 frequencies)
|
||||
float cos_v[half], sin_v[half];
|
||||
for (int i = 0; i < half; i++) {
|
||||
float freq = 1.0f / powf(QWEN_ROPE_THETA, (float)(2 * i) / head_dim);
|
||||
float angle = pos * freq;
|
||||
cos_v[i] = cosf(angle);
|
||||
sin_v[i] = sinf(angle);
|
||||
}
|
||||
|
||||
// Apply to Q heads
|
||||
for (int h = 0; h < n_q_heads; h++) {
|
||||
float *qh = q + h * head_dim;
|
||||
for (int i = 0; i < half; i++) {
|
||||
float q_first = qh[i];
|
||||
float q_second = qh[i + half];
|
||||
// rotate_half: [-q_second, q_first]
|
||||
qh[i] = q_first * cos_v[i] + (-q_second) * sin_v[i];
|
||||
qh[i + half] = q_second * cos_v[i] + q_first * sin_v[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Apply to K heads
|
||||
for (int h = 0; h < n_kv_heads; h++) {
|
||||
float *kh = k + h * head_dim;
|
||||
for (int i = 0; i < half; i++) {
|
||||
float k_first = kh[i];
|
||||
float k_second = kh[i + half];
|
||||
kh[i] = k_first * cos_v[i] + (-k_second) * sin_v[i];
|
||||
kh[i + half] = k_second * cos_v[i] + k_first * sin_v[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void qwen_silu(float *x, int n) {
|
||||
for (int i = 0; i < n; i++)
|
||||
x[i] = x[i] / (1.0f + expf(-x[i]));
|
||||
}
|
||||
|
||||
// ── ANE projection helper (single token: spatial=1) ─────────────────
|
||||
|
||||
static void ane_project(ANEKernel *kernel, const float *in, float *out,
|
||||
int in_dim, int out_dim) {
|
||||
// For single-token inference: spatial=1
|
||||
ane_write_input(kernel, 0, in, in_dim * sizeof(float));
|
||||
ane_eval(kernel);
|
||||
ane_read_output(kernel, 0, out, out_dim * sizeof(float));
|
||||
}
|
||||
|
||||
// CPU matmul via Accelerate BLAS: y = W @ x, W[out_dim, in_dim]
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
static void cpu_project(const float *W, const float *x, float *y, int in_dim, int out_dim) {
|
||||
// y = W @ x where W is [out_dim, in_dim] row-major
|
||||
// cblas_sgemv: y = alpha * A * x + beta * y
|
||||
cblas_sgemv(CblasRowMajor, CblasNoTrans,
|
||||
out_dim, in_dim,
|
||||
1.0f, W, in_dim,
|
||||
x, 1,
|
||||
0.0f, y, 1);
|
||||
}
|
||||
|
||||
// Toggle: 1 = use ANE for projections, 0 = CPU fallback
|
||||
#define USE_ANE_PROJECTIONS 0
|
||||
|
||||
// ── Forward one token ────────────────────────────────────────────────
|
||||
|
||||
static int qwen_forward(QwenModel *m, int token) {
|
||||
int D = QWEN_DIM, HD = QWEN_HIDDEN;
|
||||
int pos = m->pos;
|
||||
|
||||
// Token embedding
|
||||
memcpy(m->x, m->embed + token * D, D * sizeof(float));
|
||||
|
||||
for (int l = 0; l < QWEN_LAYERS; l++) {
|
||||
// Attention RMSNorm
|
||||
qwen_rmsnorm(m->xb, m->x, m->rms_att[l], D);
|
||||
|
||||
// Debug: print first layer input/output norms
|
||||
if (l == 0 && pos == 0) {
|
||||
            float xnorm = 0;
|
||||
for (int i = 0; i < D; i++) xnorm += m->xb[i] * m->xb[i];
|
||||
printf(" L0 RMSNorm out norm=%.4f (first 4: %.4f %.4f %.4f %.4f)\n",
|
||||
sqrtf(xnorm), m->xb[0], m->xb[1], m->xb[2], m->xb[3]);
|
||||
}
|
||||
|
||||
// QKV projections (ANE) + bias
|
||||
#if USE_ANE_PROJECTIONS
|
||||
ane_project(m->k_q[l], m->xb, m->q, D, QWEN_Q_DIM);
|
||||
ane_project(m->k_k[l], m->xb, m->k, D, QWEN_KV_DIM);
|
||||
ane_project(m->k_v[l], m->xb, m->v, D, QWEN_KV_DIM);
|
||||
#else
|
||||
cpu_project(m->wq[l], m->xb, m->q, D, QWEN_Q_DIM);
|
||||
cpu_project(m->wk[l], m->xb, m->k, D, QWEN_KV_DIM);
|
||||
cpu_project(m->wv[l], m->xb, m->v, D, QWEN_KV_DIM);
|
||||
#endif
|
||||
        // Apply Q/K/V biases
|
||||
if (m->q_bias[l]) {
|
||||
for (int i = 0; i < QWEN_Q_DIM; i++) m->q[i] += m->q_bias[l][i];
|
||||
}
|
||||
if (m->k_bias[l]) {
|
||||
for (int i = 0; i < QWEN_KV_DIM; i++) m->k[i] += m->k_bias[l][i];
|
||||
}
|
||||
if (m->v_bias[l]) {
|
||||
for (int i = 0; i < QWEN_KV_DIM; i++) m->v[i] += m->v_bias[l][i];
|
||||
}
|
||||
|
||||
if (l == 0 && pos == 0) {
|
||||
float qn = 0;
|
||||
for (int i = 0; i < QWEN_Q_DIM; i++) qn += m->q[i] * m->q[i];
|
||||
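            // Label the projection path actually taken (depends on USE_ANE_PROJECTIONS).
            printf(" L0 %s Q norm=%.4f (first 4: %.4f %.4f %.4f %.4f)\n",
                   USE_ANE_PROJECTIONS ? "ANE" : "BLAS",
                   sqrtf(qn), m->q[0], m->q[1], m->q[2], m->q[3]);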
|
||||
// CPU reference
|
||||
float cpu_q[4] = {0};
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < D; j++)
|
||||
cpu_q[i] += m->wq[0][i * D + j] * m->xb[j];
|
||||
cpu_q[i] += m->q_bias[0][i];
|
||||
}
|
||||
printf(" L0 CPU Q first 4: %.4f %.4f %.4f %.4f\n",
|
||||
cpu_q[0], cpu_q[1], cpu_q[2], cpu_q[3]);
|
||||
}
|
||||
|
||||
// RoPE
|
||||
qwen_rope(m->q, m->k, pos, QWEN_HEADS, QWEN_KV_HEADS, QWEN_HEAD_DIM);
|
||||
|
||||
// Store K, V in cache
|
||||
memcpy(m->kv_cache_k[l] + pos * QWEN_KV_DIM,
|
||||
m->k, QWEN_KV_DIM * sizeof(float));
|
||||
memcpy(m->kv_cache_v[l] + pos * QWEN_KV_DIM,
|
||||
m->v, QWEN_KV_DIM * sizeof(float));
|
||||
|
||||
// GQA attention (CPU — element-wise ops)
|
||||
float scale = 1.0f / sqrtf((float)QWEN_HEAD_DIM);
|
||||
float *attn_out = m->xb; // reuse buffer
|
||||
memset(attn_out, 0, QWEN_Q_DIM * sizeof(float));
|
||||
|
||||
for (int h = 0; h < QWEN_HEADS; h++) {
|
||||
int kv_h = h / QWEN_GQA_FACTOR;
|
||||
float *qh = m->q + h * QWEN_HEAD_DIM;
|
||||
|
||||
// Attention scores: Q @ K^T for all positions up to pos
|
||||
float max_score = -1e9f;
|
||||
for (int t = 0; t <= pos; t++) {
|
||||
float *kt = m->kv_cache_k[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM;
|
||||
// Use BLAS dot product for precision
|
||||
float score = cblas_sdot(QWEN_HEAD_DIM, qh, 1, kt, 1);
|
||||
m->att[h * QWEN_MAX_SEQ + t] = score * scale;
|
||||
if (score * scale > max_score) max_score = score * scale;
|
||||
}
|
||||
// Softmax (double accumulation for precision)
|
||||
double sum = 0;
|
||||
for (int t = 0; t <= pos; t++) {
|
||||
m->att[h * QWEN_MAX_SEQ + t] = expf(m->att[h * QWEN_MAX_SEQ + t] - max_score);
|
||||
sum += (double)m->att[h * QWEN_MAX_SEQ + t];
|
||||
}
|
||||
float inv_sum = (float)(1.0 / sum);
|
||||
for (int t = 0; t <= pos; t++)
|
||||
m->att[h * QWEN_MAX_SEQ + t] *= inv_sum;
|
||||
|
||||
// Weighted sum of V: attn_out[h] += att[t] * V[t] for each t
|
||||
for (int t = 0; t <= pos; t++) {
|
||||
float a = m->att[h * QWEN_MAX_SEQ + t];
|
||||
float *vt = m->kv_cache_v[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM;
|
||||
cblas_saxpy(QWEN_HEAD_DIM, a, vt, 1,
|
||||
attn_out + h * QWEN_HEAD_DIM, 1);
|
||||
}
|
||||
}
|
||||
|
||||
float o_out[QWEN_DIM];
|
||||
#if USE_ANE_PROJECTIONS
|
||||
ane_project(m->k_o[l], attn_out, o_out, QWEN_Q_DIM, D);
|
||||
#else
|
||||
cpu_project(m->wo[l], attn_out, o_out, QWEN_Q_DIM, D);
|
||||
#endif
|
||||
|
||||
// Residual
|
||||
for (int i = 0; i < D; i++) m->x[i] += o_out[i];
|
||||
|
||||
if (l == 0 && pos == 0) {
|
||||
float pan = 0;
|
||||
for (int i = 0; i < D; i++) pan += m->x[i] * m->x[i];
|
||||
printf(" L0 post-attn norm=%.4f first4=[%.6f, %.6f, %.6f, %.6f]\n",
|
||||
sqrtf(pan), m->x[0], m->x[1], m->x[2], m->x[3]);
|
||||
float on = 0;
|
||||
for (int i = 0; i < D; i++) on += o_out[i] * o_out[i];
|
||||
printf(" L0 o_proj out norm=%.4f first4=[%.6f, %.6f, %.6f, %.6f]\n",
|
||||
sqrtf(on), o_out[0], o_out[1], o_out[2], o_out[3]);
|
||||
}
|
||||
|
||||
// FFN RMSNorm
|
||||
qwen_rmsnorm(m->xb, m->x, m->rms_ffn[l], D);
|
||||
|
||||
// SwiGLU FFN
|
||||
#if USE_ANE_PROJECTIONS
|
||||
ane_project(m->k_gate[l], m->xb, m->hb, D, HD);
|
||||
ane_project(m->k_up[l], m->xb, m->hb2, D, HD);
|
||||
#else
|
||||
cpu_project(m->w_gate[l], m->xb, m->hb, D, HD);
|
||||
cpu_project(m->w_up[l], m->xb, m->hb2, D, HD);
|
||||
#endif
|
||||
|
||||
if (l == 0 && pos == 0) {
|
||||
float gn = 0, un = 0;
|
||||
for (int i = 0; i < HD; i++) { gn += m->hb[i]*m->hb[i]; un += m->hb2[i]*m->hb2[i]; }
|
||||
printf(" L0 gate norm=%.4f up norm=%.4f\n", sqrtf(gn), sqrtf(un));
|
||||
printf(" L0 gate first4=[%.6f, %.6f, %.6f, %.6f]\n",
|
||||
m->hb[0], m->hb[1], m->hb[2], m->hb[3]);
|
||||
}
|
||||
|
||||
qwen_silu(m->hb, HD);
|
||||
for (int i = 0; i < HD; i++) m->hb[i] *= m->hb2[i];
|
||||
|
||||
float ffn_out[QWEN_DIM];
|
||||
#if USE_ANE_PROJECTIONS
|
||||
ane_project(m->k_down[l], m->hb, ffn_out, HD, D);
|
||||
#else
|
||||
cpu_project(m->w_down[l], m->hb, ffn_out, HD, D);
|
||||
#endif
|
||||
|
||||
// Residual
|
||||
for (int i = 0; i < D; i++) m->x[i] += ffn_out[i];
|
||||
|
||||
// Debug: hidden state after each layer (first 3 layers, first token only)
|
||||
if (l < 3 && pos == 0) {
|
||||
float hn = 0;
|
||||
for (int i = 0; i < D; i++) hn += m->x[i] * m->x[i];
|
||||
printf(" C hidden[%d] norm=%.4f first4=[%.4f, %.4f, %.4f, %.4f]\n",
|
||||
l+1, sqrtf(hn), m->x[0], m->x[1], m->x[2], m->x[3]);
|
||||
}
|
||||
}
|
||||
|
||||
// Final RMSNorm
|
||||
qwen_rmsnorm(m->xb, m->x, m->rms_final, D);
|
||||
|
||||
// Debug: check final hidden state before LM head
|
||||
if (m->pos < 2) {
|
||||
float fn = 0;
|
||||
for (int i = 0; i < D; i++) fn += m->xb[i] * m->xb[i];
|
||||
printf(" Final hidden norm=%.4f (first 4: %.6f %.6f %.6f %.6f)\n",
|
||||
sqrtf(fn), m->xb[0], m->xb[1], m->xb[2], m->xb[3]);
|
||||
}
|
||||
|
||||
// LM head via Accelerate BLAS: logits = embed @ xb
|
||||
// embed is [vocab, dim] row-major
|
||||
cblas_sgemv(CblasRowMajor, CblasNoTrans,
|
||||
QWEN_VOCAB, D,
|
||||
1.0f, m->embed, D,
|
||||
m->xb, 1,
|
||||
0.0f, m->logits, 1);
|
||||
|
||||
// Debug: check logits
|
||||
if (m->pos < 2) {
|
||||
float lmax = m->logits[0], lmin = m->logits[0];
|
||||
int nonzero = 0;
|
||||
for (int i = 0; i < QWEN_VOCAB; i++) {
|
||||
if (m->logits[i] > lmax) lmax = m->logits[i];
|
||||
if (m->logits[i] < lmin) lmin = m->logits[i];
|
||||
if (m->logits[i] != 0.0f) nonzero++;
|
||||
}
|
||||
printf(" Logits: min=%.4f max=%.4f nonzero=%d/%d\n", lmin, lmax, nonzero, QWEN_VOCAB);
|
||||
}
|
||||
|
||||
m->pos++;
|
||||
|
||||
// Argmax
|
||||
int max_idx = 0;
|
||||
float max_val = m->logits[0];
|
||||
for (int i = 1; i < QWEN_VOCAB; i++) {
|
||||
if (m->logits[i] > max_val) {
|
||||
max_val = m->logits[i];
|
||||
max_idx = i;
|
||||
}
|
||||
}
|
||||
return max_idx;
|
||||
}
|
||||
|
||||
// ── Compile all ANE kernels ──────────────────────────────────────────
|
||||
|
||||
static void qwen_compile_kernels(QwenModel *m) {
|
||||
int D = QWEN_DIM, HD = QWEN_HIDDEN;
|
||||
printf("Compiling %d ANE kernels...\n", QWEN_LAYERS * 7 + QWEN_LM_CHUNKS);
|
||||
for (int l = 0; l < QWEN_LAYERS; l++) {
|
||||
m->k_q[l] = compile_conv_kernel(m->wq[l], D, QWEN_Q_DIM, 1);
|
||||
m->k_k[l] = compile_conv_kernel(m->wk[l], D, QWEN_KV_DIM, 1);
|
||||
m->k_v[l] = compile_conv_kernel(m->wv[l], D, QWEN_KV_DIM, 1);
|
||||
m->k_o[l] = compile_conv_kernel(m->wo[l], QWEN_Q_DIM, D, 1);
|
||||
m->k_gate[l] = compile_conv_kernel(m->w_gate[l], D, HD, 1);
|
||||
m->k_up[l] = compile_conv_kernel(m->w_up[l], D, HD, 1);
|
||||
m->k_down[l] = compile_conv_kernel(m->w_down[l], HD, D, 1);
|
||||
printf(" Layer %d/%d compiled\r", l+1, QWEN_LAYERS);
|
||||
fflush(stdout);
|
||||
}
|
||||
// LM head: weights are tied to the embedding table and split into QWEN_LM_CHUNKS (16) chunks
|
||||
for (int c = 0; c < QWEN_LM_CHUNKS; c++) {
|
||||
float *chunk_weights = m->embed + c * QWEN_LM_CHUNK_SIZE * D;
|
||||
m->k_lmhead[c] = compile_conv_kernel(chunk_weights, D, QWEN_LM_CHUNK_SIZE, 1);
|
||||
if (!m->k_lmhead[c]) {
|
||||
printf(" LM head chunk %d FAILED to compile\n", c);
|
||||
}
|
||||
}
|
||||
printf("\nAll kernels compiled.\n");
|
||||
}
|
||||
|
||||
// ── Allocate buffers ─────────────────────────────────────────────────
|
||||
|
||||
static void qwen_alloc(QwenModel *m) {
|
||||
m->x = (float*)calloc(QWEN_DIM, sizeof(float));
|
||||
m->xb = (float*)calloc(QWEN_DIM, sizeof(float));
|
||||
m->q = (float*)calloc(QWEN_Q_DIM, sizeof(float));
|
||||
m->k = (float*)calloc(QWEN_KV_DIM, sizeof(float));
|
||||
m->v = (float*)calloc(QWEN_KV_DIM, sizeof(float));
|
||||
m->att = (float*)calloc(QWEN_HEADS * QWEN_MAX_SEQ, sizeof(float));
|
||||
m->hb = (float*)calloc(QWEN_HIDDEN, sizeof(float));
|
||||
m->hb2 = (float*)calloc(QWEN_HIDDEN, sizeof(float));
|
||||
m->logits = (float*)calloc(QWEN_VOCAB, sizeof(float));
|
||||
for (int l = 0; l < QWEN_LAYERS; l++) {
|
||||
m->kv_cache_k[l] = (float*)calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float));
|
||||
m->kv_cache_v[l] = (float*)calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float));
|
||||
}
|
||||
m->pos = 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
#!/usr/bin/env python3
"""Run Qwen2.5-0.5B on ANE with proper tokenization.

Usage:
    python3 run.py "Your prompt here" [--max-tokens 50]
"""
import argparse
import subprocess
import sys
from pathlib import Path

INFERENCE_DIR = Path(__file__).parent
WEIGHTS_PATH = INFERENCE_DIR / "qwen05b.bin"
MODEL_DIR = Path.home() / "models" / "Qwen2.5-0.5B-Instruct"


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("prompt", type=str)
    parser.add_argument("--max-tokens", type=int, default=50)
    args = parser.parse_args()

    from transformers import AutoTokenizer

    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(str(MODEL_DIR), trust_remote_code=True)

    # Build the chat-formatted prompt
    messages = [
        {"role": "system", "content": "You are a helpful assistant. Be concise."},
        {"role": "user", "content": args.prompt},
    ]
    text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tok.encode(text)
    print(f"Prompt tokens: {len(input_ids)}")

    binary = str(INFERENCE_DIR / "qwen_ane")

    # The binary may still need to be modified to accept token IDs as input.
    # For now, print the first token IDs so tokenization can be verified.
    print(f"First 10 tokens: {input_ids[:10]}")
    print(f"Token text: {[tok.decode([t]) for t in input_ids[:10]]}")
    print(f"\nRunning ANE inference with {len(input_ids)} prompt tokens + {args.max_tokens} generation...")

    # Run the C binary. The token IDs are passed as a single space-separated
    # command-line argument (not via stdin), followed by the token budget.
    result = subprocess.run(
        [binary, str(WEIGHTS_PATH), " ".join(str(t) for t in input_ids),
         str(args.max_tokens)],
        capture_output=True, text=True, timeout=120,
    )
    print(result.stdout)
    if result.stderr:
        print(result.stderr[:500], file=sys.stderr)

    # Parse output token IDs from lines of the form "OUT: id id id ..."
    output_ids = []
    for line in result.stdout.split("\n"):
        if line.startswith("OUT:"):
            ids = [int(x) for x in line[4:].split() if x.isdigit()]
            output_ids.extend(ids)

    if output_ids:
        decoded = tok.decode(output_ids, skip_special_tokens=True)
        print(f"\n=== Response ===\n{decoded}")
    else:
        print("\n(No output tokens parsed; binary may need token ID input mode)")


if __name__ == "__main__":
    main()
|
||||
|
|
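run.py hands the prompt to `qwen_ane` as one space-separated string of token IDs and reads generated IDs back from stdout lines that start with `OUT:`. The binary's argument handling is not part of this excerpt, so what follows is only a sketch of what the receiving side could look like under that protocol. Both `parse_token_ids` and `format_out_line` are hypothetical helpers, not functions from the repository.

```c
#include <assert.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

/* Hypothetical helper: parse a string such as "151644 8948 198" into out[].
 * Returns the number of IDs parsed, or -1 if the input contains anything
 * other than non-negative integers or holds more than max_ids values. */
static int parse_token_ids(const char *arg, int *out, int max_ids) {
    assert(arg != NULL && out != NULL && max_ids > 0);
    int n = 0;
    const char *p = arg;
    for (;;) {
        while (*p == ' ' || *p == '\t' || *p == '\n') p++;  /* skip separators */
        if (*p == '\0') break;
        char *end = NULL;
        long v = strtol(p, &end, 10);
        if (end == p || v < 0 || v > INT_MAX) return -1;    /* malformed ID */
        if (n == max_ids) return -1;                        /* too many IDs */
        out[n++] = (int)v;
        p = end;
    }
    return n;
}

/* Hypothetical helper: write "OUT: id id id" into buf, the shape that the
 * Python parser accepts. Returns the string length, or -1 if it did not fit. */
static int format_out_line(char *buf, size_t cap, const int *ids, int n) {
    assert(buf != NULL && cap > 0);
    int len = snprintf(buf, cap, "OUT:");
    for (int i = 0; i < n && len >= 0 && (size_t)len < cap; i++) {
        int w = snprintf(buf + len, cap - (size_t)len, " %d", ids[i]);
        if (w < 0) return -1;
        len += w;
    }
    return (len >= 0 && (size_t)len < cap) ? len : -1;
}
```

A binary built along these lines would call `parse_token_ids(argv[2], ...)` before prefill and print one `format_out_line` result after generation. Rejecting malformed input outright, instead of silently skipping it, makes a tokenizer mismatch visible early.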
@ -21,6 +21,14 @@ train_large: train_large.m $(HEADERS_LARGE)
|
|||
train_large_ane: train_large_ane.m $(HEADERS_ANE)
	$(CC) $(CFLAGS) -o $@ train_large_ane.m $(LDFLAGS) -framework Accelerate

HEADERS_OPT = $(HEADERS_LARGE) stories_cpu_ops_opt.h

train_opt: train_opt.m $(HEADERS_OPT)
	$(CC) $(CFLAGS) -o $@ train_opt.m $(LDFLAGS) -framework Accelerate -framework Metal -framework MetalPerformanceShaders

train_double_buffer: train_double_buffer.m $(HEADERS_LARGE)
	$(CC) $(CFLAGS) -o $@ train_double_buffer.m $(LDFLAGS) -framework Accelerate

PROBES = test_weight_reload test_perf_stats test_qos_sweep test_ane_advanced

test_rmsnorm_bwd: test_rmsnorm_bwd.m $(HEADERS_ANE)
|
||||
|
|
@ -65,7 +73,7 @@ verify-flags:
|
|||
	@xcrun clang --version

clean:
-	rm -f train train_large train_large_ane $(PROBES) test_rmsnorm_bwd test_classifier
+	rm -f train train_large train_large_ane train_opt train_double_buffer $(PROBES) test_rmsnorm_bwd test_classifier

.PHONY: clean tokenize probes verify-flags data setup
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,110 @@
|
|||
// stories_cpu_ops_opt.h — Optimized CPU operations: NEON Adam, vectorized embedding
|
||||
#pragma once
|
||||
#include "stories_cpu_ops.h"
|
||||
#include <arm_neon.h>
|
||||
|
||||
// ===== NEON-vectorized Adam optimizer =====
|
||||
// ~3-3.5x faster than scalar version for large param counts
|
||||
// Uses vrsqrteq_f32 + one Newton-Raphson step for fast reciprocal sqrt
|
||||
static void adam_update_opt(float *w, const float *g, AdamState *s, int t,
                            float lr, float b1, float b2, float eps) {
    float bc1 = 1.0f - powf(b1, t);
    float bc2 = 1.0f - powf(b2, t);
    float inv_bc1 = 1.0f / bc1;
    float inv_bc2 = 1.0f / bc2;
    float one_minus_b1 = 1.0f - b1;
    float one_minus_b2 = 1.0f - b2;

    float32x4_t vb1 = vdupq_n_f32(b1);
    float32x4_t vb2 = vdupq_n_f32(b2);
    float32x4_t v1mb1 = vdupq_n_f32(one_minus_b1);
    float32x4_t v1mb2 = vdupq_n_f32(one_minus_b2);
    float32x4_t vinv_bc1 = vdupq_n_f32(inv_bc1);
    float32x4_t vinv_bc2 = vdupq_n_f32(inv_bc2);
    float32x4_t vneg_lr = vdupq_n_f32(-lr);
    float32x4_t veps = vdupq_n_f32(eps);

    size_t n = s->n;
    size_t i = 0;

    // Process 4 elements at a time
    for (; i + 3 < n; i += 4) {
        // Load
        float32x4_t vm = vld1q_f32(s->m + i);
        float32x4_t vv = vld1q_f32(s->v + i);
        float32x4_t vg = vld1q_f32(g + i);
        float32x4_t vw = vld1q_f32(w + i);

        // m = b1*m + (1-b1)*g
        vm = vmlaq_f32(vmulq_f32(vb1, vm), v1mb1, vg);
        // v = b2*v + (1-b2)*g*g
        float32x4_t g2 = vmulq_f32(vg, vg);
        vv = vmlaq_f32(vmulq_f32(vb2, vv), v1mb2, g2);

        // Store updated m, v
        vst1q_f32(s->m + i, vm);
        vst1q_f32(s->v + i, vv);

        // mhat = m / bc1, vhat = v / bc2
        float32x4_t mhat = vmulq_f32(vm, vinv_bc1);
        float32x4_t vhat = vmulq_f32(vv, vinv_bc2);

        // Fast reciprocal sqrt: vrsqrteq + one Newton-Raphson iteration
        // rsqrt_est ≈ 1/sqrt(vhat)
        float32x4_t rsqrt_est = vrsqrteq_f32(vhat);
        // Newton-Raphson: rsqrt *= (3 - vhat * rsqrt^2) / 2
        float32x4_t rsqrt_sq = vmulq_f32(rsqrt_est, rsqrt_est);
        float32x4_t nr_step = vrsqrtsq_f32(vhat, rsqrt_sq);
        rsqrt_est = vmulq_f32(rsqrt_est, nr_step);

        // w -= lr * mhat / (sqrt(vhat) + eps)
        //    = w + (-lr) * mhat * (1/(sqrt(vhat) + eps))
        // Compute sqrt(vhat) from rsqrt: sqrt = vhat * rsqrt(vhat) (avoids division)
        float32x4_t sqrt_vhat = vmulq_f32(vhat, rsqrt_est);
        // Lanes with vhat == 0 (e.g. parameters that never received a gradient)
        // have rsqrt_est = +inf, and 0 * inf is NaN. Force sqrt(0) = 0 there so
        // the result matches the scalar tail, where the denominator is eps.
        sqrt_vhat = vbslq_f32(vceqzq_f32(vhat), vhat, sqrt_vhat);
        float32x4_t denom = vaddq_f32(sqrt_vhat, veps);

        // Use vdivq_f32 for the final division (accurate, eps-adjusted)
        float32x4_t update = vmulq_f32(vneg_lr, vdivq_f32(mhat, denom));
        vw = vaddq_f32(vw, update);

        vst1q_f32(w + i, vw);
    }

    // Scalar tail
    for (; i < n; i++) {
        s->m[i] = b1 * s->m[i] + one_minus_b1 * g[i];
        s->v[i] = b2 * s->v[i] + one_minus_b2 * g[i] * g[i];
        float mh = s->m[i] * inv_bc1;
        float vh = s->v[i] * inv_bc2;
        w[i] -= lr * mh / (sqrtf(vh) + eps);
    }
}
|
||||
|
||||
// ===== Vectorized embedding lookup =====
|
||||
// Gather rows from [VOCAB, DIM] row-major embed table → x [DIM, SEQ] channel-first
|
||||
// Strategy: gather token rows into temp buffer [SEQ, DIM], then transpose via vDSP_mtrans
|
||||
static void embed_lookup_opt(float *x, const float *embed, const uint16_t *tokens,
|
||||
int dim, int seq, float *tmp) {
|
||||
// Gather: tmp[t*dim + d] = embed[tokens[t]*dim + d]
|
||||
for (int t = 0; t < seq; t++) {
|
||||
memcpy(tmp + t * dim, embed + tokens[t] * dim, dim * sizeof(float));
|
||||
}
|
||||
// Transpose [SEQ, DIM] → [DIM, SEQ]: x[d*seq + t] = tmp[t*dim + d]
|
||||
vDSP_mtrans(tmp, 1, x, 1, (vDSP_Length)dim, (vDSP_Length)seq);
|
||||
}
|
||||
|
||||
// ===== Vectorized embedding backward =====
|
||||
// Accumulate dE[tok] += dx[:,t] for each position
|
||||
// Strategy: transpose dx [DIM, SEQ] → tmp [SEQ, DIM], then accumulate rows
|
||||
static void embed_backward_opt(float *d_embed, const float *dx, const uint16_t *tokens,
|
||||
int dim, int seq, float *tmp) {
|
||||
// Transpose [DIM, SEQ] → [SEQ, DIM]: tmp[t*dim + d] = dx[d*seq + t]
|
||||
vDSP_mtrans(dx, 1, tmp, 1, (vDSP_Length)seq, (vDSP_Length)dim);
|
||||
// Scatter-add: d_embed[tok*dim .. (tok+1)*dim] += tmp[t*dim .. (t+1)*dim]
|
||||
for (int t = 0; t < seq; t++) {
|
||||
vDSP_vadd(tmp + t * dim, 1,
|
||||
d_embed + tokens[t] * dim, 1,
|
||||
d_embed + tokens[t] * dim, 1,
|
||||
(vDSP_Length)dim);
|
||||
}
|
||||
}
|
||||
|
|
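The NEON Adam update in stories_cpu_ops_opt.h replaces `sqrtf(vhat)` with a hardware reciprocal-square-root estimate refined by one Newton-Raphson step, then recovers the square root as `vhat * rsqrt(vhat)`. As a portable scalar sketch of that arithmetic (an illustration, not the header's code), the block below takes the rough estimate as a parameter in place of `vrsqrteq_f32`. The function names `rsqrt_refine` and `adam_denom` are assumptions made for this example. The sketch also handles the zero case explicitly, because an infinite estimate multiplied by zero would give NaN.

```c
#include <assert.h>
#include <math.h>

/* One Newton-Raphson refinement of y ~= 1/sqrt(x):
 *   y' = y * (3 - x*y*y) / 2
 * This is the same step that vrsqrtsq_f32 followed by a multiply performs. */
static float rsqrt_refine(float x, float y) {
    return y * (3.0f - x * y * y) * 0.5f;
}

/* Adam denominator sqrt(vhat) + eps, computed from a rough rsqrt estimate.
 * rough_rsqrt stands in for the hardware estimate (an assumption of this
 * sketch); vhat == 0 is special-cased so that 0 * inf never occurs. */
static float adam_denom(float vhat, float rough_rsqrt, float eps) {
    assert(vhat >= 0.0f);
    if (vhat == 0.0f) return eps;
    float y = rsqrt_refine(vhat, rough_rsqrt);
    return vhat * y + eps;   /* sqrt(vhat) = vhat * (1/sqrt(vhat)) */
}
```

Each refinement step roughly squares the relative error of the estimate, which is why a single step is often considered enough for an optimizer update even though it is less exact than `sqrtf`.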
@ -85,6 +85,12 @@ static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int
|
|||
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, data, channels * sp);
|
||||
IOSurfaceUnlock(s, 0, NULL);
|
||||
}
|
||||
// Read raw fp16 from IOSurface without conversion (for fp16 activation cache)
|
||||
static void io_read_raw_fp16(IOSurfaceRef s, _Float16 *data, int ch_off, int channels, int sp) {
|
||||
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
|
||||
memcpy(data, (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, channels * sp * sizeof(_Float16));
|
||||
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
|
||||
}
|
||||
|
||||
// Kernel compile/eval
|
||||
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
|
||||
|
|
|
|||
|
|
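The file that follows, train_double_buffer.m, declares two kernel sets (`kern_A` and `kern_B`), an `_Atomic bool pending_ready` flag, and a serial GCD queue for background compiles. The swap logic itself lies outside this excerpt, so the block below is only one plausible shape for it, written in plain C11. `KernelSet`, `DoubleBuffer`, `db_publish`, and `db_try_swap` are hypothetical names, and `weights_version` is a stand-in for a set of compiled kernels.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical stand-in for one full set of compiled per-layer kernels. */
typedef struct { int weights_version; } KernelSet;

typedef struct {
    KernelSet sets[2];
    int active;                 /* index of the set currently used for evals */
    atomic_bool pending_ready;  /* set once the inactive set has been rebuilt */
} DoubleBuffer;

/* Background side: finish rebuilding the inactive set, then publish it.
 * The release store makes the rebuilt set visible before the flag is seen. */
static void db_publish(DoubleBuffer *db, int version) {
    db->sets[1 - db->active].weights_version = version;
    atomic_store_explicit(&db->pending_ready, true, memory_order_release);
}

/* Training side: call between steps. If a rebuilt set is ready, consume the
 * flag and make that set active. Returns true when a swap happened. */
static bool db_try_swap(DoubleBuffer *db) {
    if (!atomic_exchange_explicit(&db->pending_ready, false,
                                  memory_order_acq_rel))
        return false;
    db->active = 1 - db->active;
    return true;
}
```

This sketch assumes that the training side starts the next background compile only after a successful swap, so the two sides never touch `active` or the inactive set at the same time. Without that ordering, additional synchronization would be required.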
@ -0,0 +1,791 @@
|
|||
// train_double_buffer.m — Double-buffered async ANE training for stories110M
|
||||
// Based on train_large.m with the key innovation: compile and eval overlap via GCD
|
||||
// Discovery: probe_v2.m proved ANE compile and eval can run in parallel
|
||||
// Architecture: two kernel sets (A/B), background compile while active set runs
|
||||
// 5 weight-bearing ANE kernels per layer × 12 layers = 60 per compile batch
|
||||
#include <stdatomic.h>
|
||||
#include "stories_io.h"
|
||||
#include "stories_mil.h"
|
||||
#include "stories_cpu_ops.h"
|
||||
|
||||
// Double-buffer needs more compile budget than single-buffer
|
||||
// The original MAX_COMPILES=100 only allows 1 batch per exec() restart
|
||||
// We push higher to allow initial compile + at least 1 background compile
|
||||
// If ANE rejects at ~119, the exec() restart will handle it gracefully
|
||||
#define DB_MAX_COMPILES 250
|
||||
|
||||
#define CKPT_PATH_DEFAULT "ane_db_ckpt.bin"
|
||||
#define MODEL_PATH_DEFAULT "../../assets/models/stories110M.bin"
|
||||
#define DATA_PATH_DEFAULT "tinystories_data00.bin"
|
||||
|
||||
static const char *get_path(const char *env_var, const char *default_val) {
|
||||
const char *v = getenv(env_var);
|
||||
return (v && v[0]) ? v : default_val;
|
||||
}
|
||||
|
||||
// ===== Weight loading from llama2.c format =====
|
||||
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
|
||||
FILE *f = fopen(path, "rb");
|
||||
if (!f) { printf("Cannot open %s\n", path); return false; }
|
||||
Llama2Config cfg;
|
||||
if (fread(&cfg, sizeof(cfg), 1, f) != 1) { printf("Cannot read model config from %s\n", path); fclose(f); return false; }
|
||||
printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
|
||||
cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
|
||||
if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
|
||||
printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d\n", DIM, HIDDEN, NLAYERS);
|
||||
fclose(f); return false;
|
||||
}
|
||||
int V = abs(cfg.vocab_size);
|
||||
bool shared = cfg.vocab_size > 0;
|
||||
|
||||
// Read in llama2.c order: embed, rms_att[all], wq[all], wk[all], wv[all], wo[all],
|
||||
// rms_ffn[all], w1[all], w2[all], w3[all], rms_final, [wcls]
|
||||
fread(embed, 4, V * DIM, f);
|
||||
|
||||
// rms_att weights for all layers (contiguous)
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
|
||||
// wq for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
|
||||
// wk for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
|
||||
// wv for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
|
||||
// wo for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
|
||||
// rms_ffn weights for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
|
||||
// w1 for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
|
||||
// w2 for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
|
||||
// w3 for all layers
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
|
||||
// rms_final
|
||||
fread(rms_final, 4, DIM, f);
|
||||
// wcls = embed if shared (we just use embed pointer)
|
||||
|
||||
fclose(f);
|
||||
printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
|
||||
return true;
|
||||
}
|
||||
|
||||
// ===== Compile one layer's kernels =====
|
||||
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
|
||||
lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
|
||||
@"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
|
||||
@"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
|
||||
@"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
|
||||
@"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
|
||||
@"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
|
||||
@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
|
||||
}), DIM*SEQ*2, 6*DIM*SEQ*2);
|
||||
|
||||
lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
|
||||
@"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
|
||||
@"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
|
||||
@"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
|
||||
@"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
|
||||
}), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);
|
||||
|
||||
lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
|
||||
@"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
|
||||
@"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
|
||||
@"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
|
||||
}), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);
|
||||
|
||||
lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
|
||||
@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
|
||||
@"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
|
||||
}), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
|
||||
|
||||
lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
|
||||
@"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
|
||||
@"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
|
||||
@"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
|
||||
}), 3*DIM*SEQ*2, DIM*SEQ*2);
|
||||
|
||||
return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
|
||||
}
|
||||
|
||||
// Compile weight-free sdpaBwd2 (no weights, so it only needs compiling once)
|
||||
static Kern *compile_sdpa_bwd2(void) {
|
||||
return compile_kern_mil_w(gen_sdpa_bwd2(), @{},
|
||||
(2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
|
||||
}
|
||||
|
||||
static void free_layer_kernels(LayerKernels *lk) {
|
||||
free_kern(lk->fwdAttn); free_kern(lk->fwdFFN); free_kern(lk->ffnBwd);
|
||||
free_kern(lk->sdpaBwd1); free_kern(lk->qkvBwd);
|
||||
// sdpaBwd2 is shared, freed separately
|
||||
lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
|
||||
}
|
||||
|
||||
// ===== Checkpoint save/load =====
|
||||
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
|
||||
double cc, double ct, double cw, int cs, int cb, int adam_t,
|
||||
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
|
||||
float *embed, AdamState *aembed) {
|
||||
FILE *f = fopen(path, "wb");
if (!f) { printf("Cannot open %s for writing\n", path); return; }
|
||||
CkptHdr h = {0};
|
||||
h.magic = 0x424C5A54; h.version = 2;
|
||||
h.step = step; h.total_steps = total_steps;
|
||||
h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
|
||||
h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
|
||||
h.lr = lr; h.loss = loss;
|
||||
h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
|
||||
h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;
|
||||
fwrite(&h, sizeof(h), 1, f);
|
||||
// Per-layer weights + adam
|
||||
for (int L = 0; L < NLAYERS; L++) {
|
||||
fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
|
||||
fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
|
||||
fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
|
||||
fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
|
||||
// Adam state
|
||||
fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
|
||||
fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
|
||||
fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
|
||||
fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
|
||||
fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
|
||||
fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
|
||||
fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
|
||||
fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
|
||||
fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
|
||||
}
|
||||
fwrite(rms_final,4,DIM,f);
|
||||
fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
|
||||
fwrite(embed,4,VOCAB*DIM,f);
|
||||
fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);
|
||||
fclose(f);
|
||||
}
|
||||
|
||||
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
|
||||
double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
|
||||
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
|
||||
float *embed, AdamState *aembed) {
|
||||
FILE *f = fopen(path, "rb");
|
||||
if (!f) return false;
|
||||
CkptHdr h;
|
||||
if (fread(&h, sizeof(h), 1, f) != 1) { fclose(f); return false; }
|
||||
if (h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; }
|
||||
*step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
|
||||
*cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
|
||||
*cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
|
||||
for (int L = 0; L < NLAYERS; L++) {
|
||||
fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
|
||||
fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
|
||||
fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
|
||||
fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
|
||||
fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
|
||||
fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
|
||||
fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
|
||||
fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
|
||||
fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
|
||||
fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
|
||||
fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
|
||||
fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
|
||||
fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
|
||||
}
|
||||
fread(rms_final,4,DIM,f);
|
||||
fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
|
||||
fread(embed,4,VOCAB*DIM,f);
|
||||
fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f);
|
||||
fclose(f);
|
||||
return true;
|
||||
}
|
||||
|
||||
// ===== Main =====
|
||||
int main(int argc, char *argv[]) {
|
||||
@autoreleasepool {
|
||||
setbuf(stdout, NULL);
|
||||
ane_init();
|
||||
mach_timebase_info(&g_tb);
|
||||
|
||||
int total_steps = 10000;
|
||||
float lr = 3e-4f;
|
||||
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
|
||||
int adam_t = 0, start_step = 0;
|
||||
|
||||
const char *model_path = get_path("ANE_MODEL_PATH", MODEL_PATH_DEFAULT);
|
||||
const char *ckpt_path = get_path("ANE_CKPT_PATH", CKPT_PATH_DEFAULT);
|
||||
const char *data_path = get_path("ANE_DATA_PATH", DATA_PATH_DEFAULT);
|
||||
|
||||
bool do_resume = false;
|
||||
for (int i=1; i<argc; i++) {
|
||||
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
|
||||
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
|
||||
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
|
||||
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
|
||||
}
|
||||
|
||||
// Allocate per-layer state
|
||||
LayerWeights lw[NLAYERS];
|
||||
LayerAdam la[NLAYERS];
|
||||
LayerActs acts[NLAYERS];
|
||||
LayerGrads grads[NLAYERS];
|
||||
// Double-buffer: two sets of kernels
|
||||
LayerKernels kern_A[NLAYERS], kern_B[NLAYERS];
|
||||
LayerKernels *kern_active = kern_A; // currently running evals
|
||||
LayerKernels *kern_pending = kern_B; // being compiled in background
|
||||
static _Atomic bool pending_ready = false; // signal: pending compile done
|
||||
static _Atomic bool bg_compile_running = false;
|
||||
dispatch_queue_t compile_q = dispatch_queue_create("ane.compile.bg", DISPATCH_QUEUE_SERIAL);
|
||||
// Legacy alias for code that uses kern[L]
|
||||
#define kern kern_active
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
lw[L] = layer_weights_alloc();
|
||||
la[L] = layer_adam_alloc();
|
||||
acts[L] = layer_acts_alloc();
|
||||
grads[L] = layer_grads_alloc();
|
||||
memset(&kern_A[L], 0, sizeof(LayerKernels));
|
||||
memset(&kern_B[L], 0, sizeof(LayerKernels));
|
||||
}
|
||||
|
||||
// Final RMSNorm + embedding + classifier
|
||||
float *rms_final = (float*)malloc(DIM*4);
|
||||
float *embed = (float*)malloc(VOCAB*DIM*4); // [VOCAB, DIM] row-major
|
||||
float *grms_final = (float*)calloc(DIM, 4);
|
||||
float *gembed = (float*)calloc(VOCAB*DIM, 4);
|
||||
AdamState arms_final = adam_alloc(DIM);
|
||||
AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
|
||||
|
||||
double cum_compile=0, cum_train=0, cum_wall=0;
|
||||
int cum_steps=0, cum_batches=0;
|
||||
|
||||
float resume_loss = 0;
|
||||
bool resuming = false;
|
||||
if (do_resume) {
|
||||
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
|
||||
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
|
||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
|
||||
}
|
||||
if (!resuming) {
|
||||
printf("=== ANE Training: Stories110M (12 layers) ===\n");
|
||||
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
|
||||
if (!load_pretrained(lw, rms_final, embed, model_path)) {
|
||||
printf("Pretrained load failed, using random init\n");
|
||||
srand48(42);
|
||||
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
|
||||
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
|
||||
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
|
||||
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
|
||||
}
|
||||
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
|
||||
float escale = 0.02f;
|
||||
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
|
||||
}
|
||||
size_t tp = (size_t)NLAYERS*LAYER_PARAMS + DIM + (size_t)VOCAB*DIM;
|
||||
double xfmr_params = (double)NLAYERS*LAYER_PARAMS;
|
||||
double embed_params = (double)VOCAB*DIM;
|
||||
printf("Params: %.2fM (transformer %.2fM + embed %.2fM)\n", tp/1e6, xfmr_params/1e6, embed_params/1e6);
|
||||
printf("Kernels: %d (%d weight-bearing + %d static sdpaBwd2)\n",
|
||||
TOTAL_WEIGHT_KERNELS+NLAYERS, TOTAL_WEIGHT_KERNELS, NLAYERS);
|
||||
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
|
||||
double fwd_f = NLAYERS*(4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
|
||||
double bwd_dx_f = fwd_f, bwd_dw_f = fwd_f;
|
||||
double sdpa_f = NLAYERS*2.0*HEADS*5*SEQ*SEQ*HD;
|
||||
double cls_f = 2.0*VOCAB*DIM*SEQ;
|
||||
double total_f = fwd_f + bwd_dx_f + bwd_dw_f + sdpa_f + cls_f*3;
|
||||
double ane_f = fwd_f + bwd_dx_f + sdpa_f;
|
||||
printf("FLOPs/step: fwd=%.0fM bwd_dx=%.0fM bwd_dW=%.0fM sdpa_bwd=%.0fM total=%.0fM\n",
|
||||
fwd_f/1e6, bwd_dx_f/1e6, bwd_dw_f/1e6, sdpa_f/1e6, total_f/1e6);
|
||||
printf("ANE FLOPs/step: %.0fM (fwd+bwd_dx+sdpa_bwd) | CPU: dW+cls (cblas)\n\n", ane_f/1e6);
|
||||
}
|
||||
|
||||
// mmap token data (or generate synthetic if not available)
|
||||
uint16_t *token_data = NULL;
|
||||
size_t n_tokens = 0;
|
||||
size_t data_len = 0;
|
||||
bool synthetic_data = false;
|
||||
int data_fd = open(data_path, O_RDONLY);
|
||||
if (data_fd >= 0) {
|
||||
struct stat st; fstat(data_fd, &st);
|
||||
data_len = st.st_size;
|
||||
token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
|
||||
if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
|
||||
n_tokens = data_len / 2;
|
||||
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
|
||||
} else {
|
||||
// Synthetic data for double-buffer benchmark
|
||||
synthetic_data = true;
|
||||
n_tokens = 100000;
|
||||
data_len = n_tokens * 2;
|
||||
token_data = (uint16_t*)malloc(data_len);
|
||||
srand48(123);
|
||||
for (size_t i = 0; i < n_tokens; i++)
|
||||
token_data[i] = (uint16_t)(drand48() * (VOCAB - 1));
|
||||
printf("[DB] Using synthetic data: %zu tokens (benchmark mode)\n", n_tokens);
|
||||
}
|
||||
|
||||
// Gradient buffers shared across layers (reused each step)
|
||||
float *dy = (float*)malloc(SEQ*DIM*4); // gradient flowing backward
|
||||
float *dffn = (float*)malloc(SEQ*DIM*4);
|
||||
float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
|
||||
float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
|
||||
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
|
||||
float *dx2 = (float*)malloc(SEQ*DIM*4);
|
||||
float *do_out_buf = (float*)malloc(SEQ*DIM*4);
|
||||
float *dq = (float*)malloc(SEQ*DIM*4);
|
||||
float *dk = (float*)malloc(SEQ*DIM*4);
|
||||
float *dv = (float*)malloc(SEQ*DIM*4);
|
||||
float *dx_attn = (float*)malloc(SEQ*DIM*4);
|
||||
|
||||
// x buffer for input to each layer (channel-first [DIM, SEQ])
|
||||
float *x_cur = (float*)malloc(SEQ*DIM*4);
|
||||
float *x_final = (float*)malloc(SEQ*DIM*4); // after final rmsnorm
|
||||
float *logits = (float*)malloc(SEQ*VOCAB*4); // [VOCAB, SEQ] for cross-entropy
|
||||
float *dlogits = (float*)malloc(SEQ*VOCAB*4);
|
||||
|
||||
// Compile static sdpaBwd2 kernels (no weights, one per layer)
|
||||
Kern *sdpaBwd2[NLAYERS];
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
sdpaBwd2[L] = compile_sdpa_bwd2();
|
||||
if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
|
||||
}
|
||||
|
||||
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
|
||||
dispatch_group_t dw_grp = dispatch_group_create();
|
||||
|
||||
float last_loss = 999.0f;
|
||||
double total_compile_ms=0, total_train_ms=0;
|
||||
int total_steps_done=0, total_batches=0;
|
||||
uint64_t t_wall_start = mach_absolute_time();
|
||||
|
||||
srand48(42 + start_step);
|
||||
|
||||
// ===== DOUBLE-BUFFER: Initial synchronous compile (first batch only) =====
|
||||
printf(" [DB] Initial compile (synchronous)...\n");
|
||||
{
|
||||
uint64_t tc = mach_absolute_time();
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
printf(" Compiling layer %d/%d... (%d compiles)\r", L+1, NLAYERS, g_compile_count);
|
||||
fflush(stdout);
|
||||
if (!compile_layer_kernels(&kern_active[L], &lw[L])) {
|
||||
printf("\nInitial compile failed at layer %d\n", L);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
// sdpaBwd2 kernels were already compiled above; recompile only slots that are still NULL
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
if (!sdpaBwd2[L]) {
|
||||
sdpaBwd2[L] = compile_sdpa_bwd2();
|
||||
if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
|
||||
}
|
||||
}
|
||||
double cms = tb_ms(mach_absolute_time() - tc);
|
||||
total_compile_ms += cms;
|
||||
printf(" [DB] Initial compile: %d kernels in %.0fms\n", TOTAL_WEIGHT_KERNELS, cms);
|
||||
}
|
||||
|
||||
// Helper block: compile all layers into a kernel set
|
||||
// Captured by the GCD block for background compilation
|
||||
void (^compile_into)(LayerKernels *, LayerWeights *) = ^(LayerKernels *target, LayerWeights *weights) {
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
free_layer_kernels(&target[L]);
|
||||
if (!compile_layer_kernels(&target[L], &weights[L])) {
|
||||
printf("\n [DB] Background compile failed at layer %d\n", L);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int step = start_step;
|
||||
int batches_since_swap = 0;
|
||||
double total_stall_ms = 0;
|
||||
|
||||
while (step < total_steps) {
|
||||
// Check compile budget
|
||||
if (g_compile_count + TOTAL_WEIGHT_KERNELS > DB_MAX_COMPILES) {
|
||||
// Wait for any in-flight background compile
|
||||
dispatch_sync(compile_q, ^{});
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
free_layer_kernels(&kern_A[L]);
|
||||
free_layer_kernels(&kern_B[L]);
|
||||
free_kern(sdpaBwd2[L]); sdpaBwd2[L] = NULL;
|
||||
}
|
||||
#undef kern
|
||||
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
||||
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
|
||||
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
|
||||
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
|
||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
|
||||
fflush(stdout);
|
||||
execl(argv[0], argv[0], "--resume", (char *)NULL);
|
||||
perror("execl"); return 1;
|
||||
#define kern kern_active
|
||||
}
|
||||
|
||||
// ===== DOUBLE-BUFFER: Check if pending kernels are ready to swap =====
|
||||
if (atomic_load(&pending_ready)) {
|
||||
// Swap: pending becomes active, old active becomes recycle target
|
||||
LayerKernels *old_active = kern_active;
|
||||
kern_active = kern_pending;
|
||||
kern_pending = old_active;
|
||||
atomic_store(&pending_ready, false);
|
||||
batches_since_swap = 0;
|
||||
printf(" [DB] Swapped kernels (stall=0ms)\n");
|
||||
}
|
||||
|
||||
// Re-compile sdpaBwd2 if needed (after exec restart)
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
if (!sdpaBwd2[L]) {
|
||||
sdpaBwd2[L] = compile_sdpa_bwd2();
|
||||
if (!sdpaBwd2[L]) { printf("sdpaBwd2 recompile failed\n"); return 1; }
|
||||
}
|
||||
}
|
||||
|
||||
// Zero gradient accumulators
|
||||
for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
|
||||
memset(grms_final, 0, DIM*4);
|
||||
memset(gembed, 0, (size_t)VOCAB*DIM*4);
|
||||
|
||||
int steps_batch = 0;
|
||||
uint64_t tt = mach_absolute_time();
|
||||
double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0,t_cls=0;
|
||||
|
||||
for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
|
||||
uint64_t t0,t1;
|
||||
// Sample random position in token data
|
||||
size_t max_pos = n_tokens - SEQ - 1;
|
||||
size_t pos = (size_t)(drand48() * max_pos);
|
||||
uint16_t *input_tokens = token_data + pos;
|
||||
uint16_t *target_tokens = token_data + pos + 1;
|
||||
|
||||
// Embedding lookup → x_cur [DIM, SEQ] channel-first
|
||||
t0=mach_absolute_time();
|
||||
embed_lookup(x_cur, embed, input_tokens, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
|
||||
|
||||
// ===== FORWARD (12 layers) =====
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
LayerActs *ac = &acts[L];
|
||||
|
||||
// Save layer input for rmsnorm1 backward
|
||||
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
|
||||
// Attention forward: x_cur → o_out,Q,K,V,attn_out,xnorm
|
||||
t0=mach_absolute_time();
|
||||
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
|
||||
t1=mach_absolute_time(); t_cblas_wait+=tb_ms(t1-t0); t0=t1;
|
||||
io_write_fp16(kern[L].fwdAttn->ioIn, x_cur, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
ane_eval(kern[L].fwdAttn);
|
||||
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
||||
io_read_fp16(kern[L].fwdAttn->ioOut, ac->o_out, 0, DIM, SEQ);
|
||||
io_read_fp16(kern[L].fwdAttn->ioOut, ac->attn_out, 4*DIM, DIM, SEQ);
|
||||
io_read_fp16(kern[L].fwdAttn->ioOut, ac->xnorm, 5*DIM, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// FFN forward
|
||||
io_write_fp16(kern[L].fwdFFN->ioIn, ac->x2, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
ane_eval(kern[L].fwdFFN);
|
||||
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
||||
io_read_fp16(kern[L].fwdFFN->ioOut, ac->ffn_out, 0, DIM, SEQ);
|
||||
io_read_fp16(kern[L].fwdFFN->ioOut, ac->h1, DIM, HIDDEN, SEQ);
|
||||
io_read_fp16(kern[L].fwdFFN->ioOut, ac->h3, DIM+HIDDEN, HIDDEN, SEQ);
|
||||
io_read_fp16(kern[L].fwdFFN->ioOut, ac->silu_out, DIM+2*HIDDEN, HIDDEN, SEQ);
|
||||
io_read_fp16(kern[L].fwdFFN->ioOut, ac->x2norm, DIM+3*HIDDEN, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
|
||||
}
|
||||
|
||||
// Final RMSNorm (CPU)
|
||||
t0=mach_absolute_time();
|
||||
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// Classifier: logits[VOCAB,SEQ] = embed[VOCAB,DIM] @ x_final[DIM,SEQ]
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
|
||||
VOCAB, SEQ, DIM, 1.0f,
|
||||
embed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
|
||||
t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// Cross-entropy loss
|
||||
float loss = cross_entropy_loss(dlogits, logits, target_tokens, VOCAB, SEQ);
|
||||
last_loss = loss;
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// ===== BACKWARD =====
|
||||
// dlogits already computed by cross_entropy_loss
|
||||
|
||||
// Classifier backward: dx_final = embed^T @ dlogits, dembed += dlogits @ x_final^T
|
||||
// dx_final[DIM,SEQ] = embed^T[DIM,VOCAB] @ dlogits[VOCAB,SEQ]
|
||||
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
|
||||
DIM, SEQ, VOCAB, 1.0f,
|
||||
embed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
|
||||
|
||||
// dembed[VOCAB,DIM] += dlogits[VOCAB,SEQ] @ x_final^T[SEQ,DIM]
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
VOCAB, DIM, SEQ, 1.0f,
|
||||
dlogits, SEQ, x_final, SEQ, 1.0f, gembed, DIM);
|
||||
});
|
||||
|
||||
// Final RMSNorm backward
|
||||
float *dx_rms_final = (float*)calloc(SEQ*DIM, 4);
|
||||
rmsnorm_bwd(dx_rms_final, grms_final, dy, x_cur, rms_final, DIM, SEQ);
|
||||
memcpy(dy, dx_rms_final, SEQ*DIM*4);
|
||||
free(dx_rms_final);
|
||||
|
||||
// ===== BACKWARD (12 layers, reverse) =====
|
||||
for (int L=NLAYERS-1; L>=0; L--) {
|
||||
LayerActs *ac = &acts[L];
|
||||
LayerGrads *gr = &grads[L];
|
||||
|
||||
// dy is the gradient at the output of this layer
|
||||
// dffn = dy (residual connection: d(x2 + ffn) = dy for both)
|
||||
memcpy(dffn, dy, SEQ*DIM*4);
|
||||
|
||||
// FFN backward (ANE)
|
||||
io_write_fp16_at(kern[L].ffnBwd->ioIn, 0, dffn, DIM, SEQ);
|
||||
io_copy(kern[L].ffnBwd->ioIn, DIM, kern[L].fwdFFN->ioOut, DIM, 2*HIDDEN, SEQ);
|
||||
ane_eval(kern[L].ffnBwd);
|
||||
io_read_fp16(kern[L].ffnBwd->ioOut, dx_ffn, 0, DIM, SEQ);
|
||||
io_read_fp16(kern[L].ffnBwd->ioOut, dh1, DIM, HIDDEN, SEQ);
|
||||
io_read_fp16(kern[L].ffnBwd->ioOut, dh3, DIM+HIDDEN, HIDDEN, SEQ);
|
||||
|
||||
// dW FFN async
|
||||
float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4);
|
||||
float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4);
|
||||
float *capt_dh1 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh1, dh1, SEQ*HIDDEN*4);
|
||||
float *capt_dh3 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh3, dh3, SEQ*HIDDEN*4);
|
||||
float *capt_x2n = (float*)malloc(SEQ*DIM*4); memcpy(capt_x2n, ac->x2norm, SEQ*DIM*4);
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
|
||||
1.0f, capt_dffn, SEQ, capt_silu, SEQ, 1.0f, gr->W2, HIDDEN);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
|
||||
1.0f, capt_dh1, SEQ, capt_x2n, SEQ, 1.0f, gr->W1, DIM);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
|
||||
1.0f, capt_dh3, SEQ, capt_x2n, SEQ, 1.0f, gr->W3, DIM);
|
||||
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
|
||||
});
|
||||
|
||||
// RMSNorm2 backward
|
||||
memset(dx2, 0, SEQ*DIM*4);
|
||||
rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
|
||||
// Add residual: dx2 += dy (from skip connection)
|
||||
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
|
||||
|
||||
// dWo async
|
||||
memcpy(do_out_buf, dx2, SEQ*DIM*4);
|
||||
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, do_out_buf, SEQ*DIM*4);
|
||||
float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM);
|
||||
free(capt_do); free(capt_attn);
|
||||
});
|
||||
|
||||
// SDPA backward (ANE)
|
||||
io_copy(kern[L].sdpaBwd1->ioIn, 0, kern[L].fwdAttn->ioOut, DIM, 3*DIM, SEQ);
|
||||
io_write_fp16_at(kern[L].sdpaBwd1->ioIn, 3*DIM, dx2, DIM, SEQ);
|
||||
ane_eval(kern[L].sdpaBwd1);
|
||||
io_copy(sdpaBwd2[L]->ioIn, 0, kern[L].sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
|
||||
io_copy(sdpaBwd2[L]->ioIn, 2*SCORE_CH, kern[L].fwdAttn->ioOut, DIM, 2*DIM, SEQ);
|
||||
ane_eval(sdpaBwd2[L]);
|
||||
|
||||
io_read_fp16(sdpaBwd2[L]->ioOut, dq, 0, DIM, SEQ);
|
||||
io_read_fp16(sdpaBwd2[L]->ioOut, dk, DIM, DIM, SEQ);
|
||||
io_read_fp16(kern[L].sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
|
||||
|
||||
// dWq/dWk/dWv async
|
||||
float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
|
||||
float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk, SEQ*DIM*4);
|
||||
float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4);
|
||||
float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4);
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM);
|
||||
free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
|
||||
});
|
||||
|
||||
// QKV backward (ANE)
|
||||
io_copy(kern[L].qkvBwd->ioIn, 0, sdpaBwd2[L]->ioOut, 0, 2*DIM, SEQ);
|
||||
io_copy(kern[L].qkvBwd->ioIn, 2*DIM, kern[L].sdpaBwd1->ioOut, 0, DIM, SEQ);
|
||||
ane_eval(kern[L].qkvBwd);
|
||||
io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
|
||||
|
||||
// RMSNorm1 backward (using saved layer input)
|
||||
float *dx_rms1 = (float*)calloc(SEQ*DIM, 4);
|
||||
rmsnorm_bwd(dx_rms1, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
|
||||
|
||||
// dy for the previous layer (going backward). Gradient flow through this layer:
// dy → dffn (FFN branch) plus dy itself (residual x2 → layer output)
// dffn → ffnBwd → dx_ffn → rmsnorm2_bwd
// dx2 = rmsnorm2_bwd(dx_ffn) + dy, i.e. d(loss)/d(x2)
// dx2 → sdpaBwd1 (through Wo^T) and sdpaBwd2 → dq, dk, dv
// dq, dk, dv → qkvBwd → dx_attn → rmsnorm1_bwd → dx_rms1
// Since x2 = layer_in + o_out, d(x2)/d(layer_in) = 1, so layer_in receives dx2
// through the skip connection plus dx_rms1 through the attention path:
// dy_prev = dx_rms1 + dx2
|
||||
for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms1[i] + dx2[i];
|
||||
free(dx_rms1);
|
||||
}
|
||||
|
||||
// Embedding backward
|
||||
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
|
||||
embed_backward(gembed, dy, input_tokens, DIM, SEQ);
|
||||
|
||||
steps_batch++;
|
||||
if (step % 10 == 0 || step == start_step)
|
||||
printf("step %-4d loss=%.4f\n", step, loss);
|
||||
|
||||
// JSON telemetry to stderr
|
||||
double step_ane = t_ane/steps_batch, step_io = t_io/steps_batch;
|
||||
double step_cls = t_cls/steps_batch, step_elem = t_elem/steps_batch;
|
||||
double step_rms = t_rms/steps_batch, step_cbw = t_cblas_wait/steps_batch;
|
||||
fprintf(stderr, "{\"type\":\"step\",\"step\":%d,\"loss\":%.6f,"
|
||||
"\"t_ane\":%.3f,\"t_io\":%.3f,\"t_cls\":%.3f,"
|
||||
"\"t_elem\":%.3f,\"t_rms\":%.3f,\"t_cblas_wait\":%.3f,"
|
||||
"\"compiles\":%d}\n",
|
||||
step, loss, step_ane, step_io, step_cls, step_elem, step_rms, step_cbw, g_compile_count);
|
||||
}
|
||||
double tms = tb_ms(mach_absolute_time() - tt);
|
||||
total_train_ms += tms;
|
||||
total_steps_done += steps_batch;
|
||||
total_batches++;
|
||||
|
||||
// Ensure all async dW finished
|
||||
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
|
||||
|
||||
// Adam update (scale gradients by 1/steps_batch)
|
||||
float gsc = 1.0f / steps_batch;
|
||||
adam_t++;
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
LayerGrads *g = &grads[L];
|
||||
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
|
||||
for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
|
||||
for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
|
||||
for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
|
||||
for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
|
||||
|
||||
adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
}
|
||||
for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
|
||||
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
// Scale and update embed
|
||||
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
|
||||
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
|
||||
|
||||
// ===== DOUBLE-BUFFER: Start background compile with updated weights =====
|
||||
batches_since_swap++;
|
||||
// Only start bg compile if we have budget
|
||||
if (!atomic_load(&bg_compile_running) &&
|
||||
g_compile_count + TOTAL_WEIGHT_KERNELS <= DB_MAX_COMPILES) {
|
||||
atomic_store(&bg_compile_running, true);
|
||||
// Capture pointers (not stack arrays) for background block
|
||||
LayerKernels *bg_target = kern_pending;
|
||||
LayerWeights *bg_weights = lw; // decays to pointer, safe for block
|
||||
dispatch_async(compile_q, ^{
|
||||
compile_into(bg_target, bg_weights);
|
||||
atomic_store(&pending_ready, true);
|
||||
atomic_store(&bg_compile_running, false);
|
||||
});
|
||||
}
|
||||
|
||||
double cms = 0; // compile was async, no stall
|
||||
printf(" [batch %d: compile_stall=0ms train=%.1fms (%.1fms/step) compiles=%d bg=%s]\n",
|
||||
steps_batch, tms, tms/steps_batch, g_compile_count,
|
||||
atomic_load(&bg_compile_running) ? "compiling" : "idle");
|
||||
printf(" ane=%.1f io=%.1f cls=%.1f elem=%.1f rms=%.1f cblas_wait=%.1f ms/step\n",
|
||||
t_ane/steps_batch, t_io/steps_batch, t_cls/steps_batch, t_elem/steps_batch,
|
||||
t_rms/steps_batch, t_cblas_wait/steps_batch);
|
||||
|
||||
// JSON batch telemetry to stderr
|
||||
{
|
||||
double bf = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
|
||||
double bs = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
|
||||
double ane_f_batch = (bf*2 + bs) * steps_batch;
|
||||
double ane_tflops = ane_f_batch / (tms * 1e9);
|
||||
fprintf(stderr, "{\"type\":\"batch\",\"batch\":%d,\"compile_ms\":%.1f,"
|
||||
"\"train_ms\":%.1f,\"ms_per_step\":%.1f}\n",
|
||||
steps_batch, cms, tms, tms/steps_batch);
|
||||
fprintf(stderr, "{\"type\":\"perf\",\"ane_tflops\":%.3f,\"ane_util_pct\":%.2f}\n",
|
||||
ane_tflops, 100.0*ane_tflops/15.8);
|
||||
}
|
||||
}
|
||||
|
||||
// Efficiency report
|
||||
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
||||
total_compile_ms += cum_compile; total_train_ms += cum_train;
|
||||
wall += cum_wall; total_steps_done += cum_steps; total_batches += cum_batches;
|
||||
double fwd_flops = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
|
||||
double sdpa_flops = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
|
||||
double cls_flops = 2.0*VOCAB*DIM*SEQ;
|
||||
double total_flops = (fwd_flops*3 + sdpa_flops + cls_flops*3) * total_steps_done;
|
||||
double ane_flops = (fwd_flops*2 + sdpa_flops) * total_steps_done;
|
||||
printf("\n=== Efficiency Report ===\n");
|
||||
printf("Total steps: %d\n", total_steps_done);
|
||||
printf("Wall time: %.0f ms (%.1f s)\n", wall, wall/1000);
|
||||
printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100*total_compile_ms/wall);
|
||||
printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100*total_train_ms/wall);
|
||||
printf("Avg train: %.1f ms/step\n", total_train_ms/total_steps_done);
|
||||
printf("ANE TFLOPS: %.2f sustained\n", ane_flops / (total_train_ms * 1e9));
|
||||
printf("Total TFLOPS: %.2f (ANE+CPU)\n", total_flops / (total_train_ms * 1e9));
|
||||
printf("ANE utilization: %.1f%% of 15.8 TFLOPS\n", 100*ane_flops/(total_train_ms*1e9)/15.8);
|
||||
|
||||
// Wait for any in-flight background compile
|
||||
dispatch_sync(compile_q, ^{});
|
||||
|
||||
// Cleanup
|
||||
#undef kern
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
free_layer_kernels(&kern_A[L]);
|
||||
free_layer_kernels(&kern_B[L]);
|
||||
free_kern(sdpaBwd2[L]);
|
||||
layer_weights_free(&lw[L]);
|
||||
layer_adam_free(&la[L]);
|
||||
layer_acts_free(&acts[L]);
|
||||
layer_grads_free(&grads[L]);
|
||||
}
|
||||
if (synthetic_data) { free(token_data); }
|
||||
else { munmap(token_data, data_len); close(data_fd); }
|
||||
free(rms_final); free(embed); free(grms_final); free(gembed);
|
||||
adam_free(&arms_final); adam_free(&aembed);
|
||||
free(dy); free(dffn); free(dh1); free(dh3); free(dx_ffn); free(dx2);
|
||||
free(do_out_buf); free(dq); free(dk); free(dv); free(dx_attn);
|
||||
free(x_cur); free(x_final); free(logits); free(dlogits);
|
||||
}
|
||||
return 0;
|
||||
}
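The kernel swap used by the training loop above can be illustrated in isolation. The following is a minimal single-threaded sketch, not code from this commit: `KernelSet`, `DoubleBuffer`, `db_init`, `db_publish`, and `db_try_swap` are hypothetical names, and an integer version number stands in for a set of compiled ANE kernels. It assumes one producer (the background compile) and one consumer (the training loop), as in the source.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

/* Hypothetical stand-in for one full set of compiled kernels. */
typedef struct { int weights_version; } KernelSet;

typedef struct {
    KernelSet   sets[2];
    KernelSet  *active;        /* read by the training loop */
    KernelSet  *pending;       /* written by the background compile */
    atomic_bool pending_ready; /* set by the producer, cleared by the consumer */
} DoubleBuffer;

static void db_init(DoubleBuffer *db, int initial_version) {
    db->sets[0].weights_version = initial_version;
    db->sets[1].weights_version = initial_version;
    db->active  = &db->sets[0];
    db->pending = &db->sets[1];
    atomic_init(&db->pending_ready, false);
}

/* Producer side: fill the pending set, then publish it. */
static void db_publish(DoubleBuffer *db, int new_version) {
    assert(!atomic_load(&db->pending_ready)); /* one compile in flight at a time */
    db->pending->weights_version = new_version;
    atomic_store(&db->pending_ready, true);
}

/* Consumer side: swap only when the pending set is ready; never blocks. */
static bool db_try_swap(DoubleBuffer *db) {
    if (!atomic_load(&db->pending_ready)) return false;
    KernelSet *old_active = db->active;
    db->active  = db->pending;
    db->pending = old_active;   /* old active set becomes the next compile target */
    atomic_store(&db->pending_ready, false);
    return true;
}
```

In this sketch the consumer polls `db_try_swap` once per batch and keeps training on the old set when nothing is ready, which is why the loop above can report a compile stall of zero.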
|
||||
|
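The efficiency report above divides FLOP counts by `total_train_ms * 1e9` to obtain TFLOPS. As a units check, here is a small sketch of that conversion; `tflops_from_ms` and `util_pct` are hypothetical helper names, and the 15.8 TFLOPS peak is the figure the source itself uses.

```c
#include <assert.h>

/* FLOPs executed in `ms` milliseconds, expressed in TFLOPS.
 * FLOPs per second = flops / (ms * 1e-3); dividing by 1e12 gives flops / (ms * 1e9). */
static double tflops_from_ms(double flops, double ms) {
    assert(ms > 0.0);
    return flops / (ms * 1e9);
}

/* Utilization as a percentage of an assumed peak throughput. */
static double util_pct(double tflops, double peak_tflops) {
    return 100.0 * tflops / peak_tflops;
}
```

For example, 1e12 FLOPs in 1000 ms is 1 TFLOPS, and 1.58 TFLOPS against a 15.8 TFLOPS peak is about 10% utilization.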
|
@ -0,0 +1,971 @@
|
|||
// train_opt.m — Optimized train_large with:
|
||||
// Phase 1: NEON Adam, vectorized embed ops, pre-allocated capture buffers
|
||||
// Phase 2: Concurrent dW dispatch, fp16 activation cache
|
||||
// Phase 3: Metal GPU for weight gradient computation (dW)
|
||||
//
|
||||
// Key perf wins:
|
||||
// - Pre-allocated LayerCaptures: eliminates ~132 malloc/free per step
|
||||
// - Concurrent dW queue: individual sgemms run in parallel (was serial)
|
||||
// - fp16 activation cache: skip fp16→fp32 on main thread for dW-only buffers
|
||||
// - Metal GPU dW: ~12ms for all weight gradients vs ~435ms serial CPU
|
||||
// - NEON Adam: ~3x faster optimizer step
|
||||
// - Vectorized embed: vDSP_mtrans instead of scalar scatter/gather
|
||||
|
||||
#include "stories_io.h"
|
||||
#include "stories_mil.h"
|
||||
#include "stories_cpu_ops_opt.h"
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
|
||||
#define MODEL_PATH_DEFAULT "stories110M.bin"
|
||||
#define DATA_PATH_DEFAULT "tinystories_data00.bin"
|
||||
|
||||
static const char *get_path(const char *env_var, const char *default_val) {
|
||||
const char *v = getenv(env_var);
|
||||
return (v && v[0]) ? v : default_val;
|
||||
}
|
||||
|
||||
// ===== Pre-allocated capture buffers per layer (Phase 1) =====
|
||||
// Eliminates malloc/free in dispatch blocks
|
||||
typedef struct {
|
||||
// FFN dW captures
|
||||
float *dffn; // [DIM, SEQ]
|
||||
float *silu_out; // [HIDDEN, SEQ]
|
||||
float *dh1; // [HIDDEN, SEQ]
|
||||
float *dh3; // [HIDDEN, SEQ]
|
||||
float *x2norm; // [DIM, SEQ]
|
||||
// Attn dW captures
|
||||
float *do_buf; // [DIM, SEQ] (for dWo)
|
||||
float *attn_out; // [DIM, SEQ]
|
||||
// QKV dW captures
|
||||
float *dq; // [DIM, SEQ]
|
||||
float *dk; // [DIM, SEQ]
|
||||
float *dv; // [DIM, SEQ]
|
||||
float *xnorm; // [DIM, SEQ]
|
||||
// fp16 backward gradient cache (read raw from IOSurface, convert in dispatch block)
|
||||
_Float16 *dh1_fp16; // [HIDDEN, SEQ]
|
||||
_Float16 *dh3_fp16; // [HIDDEN, SEQ]
|
||||
_Float16 *dq_fp16; // [DIM, SEQ]
|
||||
_Float16 *dk_fp16; // [DIM, SEQ]
|
||||
_Float16 *dv_fp16; // [DIM, SEQ]
|
||||
} LayerCaptures;
|
||||
|
||||
static LayerCaptures layer_captures_alloc(void) {
|
||||
LayerCaptures c;
|
||||
c.dffn = (float*)malloc(SEQ * DIM * 4);
|
||||
c.silu_out = (float*)malloc(SEQ * HIDDEN * 4);
|
||||
c.dh1 = (float*)malloc(SEQ * HIDDEN * 4);
|
||||
c.dh3 = (float*)malloc(SEQ * HIDDEN * 4);
|
||||
c.x2norm = (float*)malloc(SEQ * DIM * 4);
|
||||
c.do_buf = (float*)malloc(SEQ * DIM * 4);
|
||||
c.attn_out = (float*)malloc(SEQ * DIM * 4);
|
||||
c.dq = (float*)malloc(SEQ * DIM * 4);
|
||||
c.dk = (float*)malloc(SEQ * DIM * 4);
|
||||
c.dv = (float*)malloc(SEQ * DIM * 4);
|
||||
c.xnorm = (float*)malloc(SEQ * DIM * 4);
|
||||
c.dh1_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2);
|
||||
c.dh3_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2);
|
||||
c.dq_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
|
||||
c.dk_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
|
||||
c.dv_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
|
||||
return c;
|
||||
}
|
||||
static void layer_captures_free(LayerCaptures *c) {
|
||||
free(c->dffn); free(c->silu_out); free(c->dh1); free(c->dh3);
|
||||
free(c->x2norm); free(c->do_buf); free(c->attn_out);
|
||||
free(c->dq); free(c->dk); free(c->dv); free(c->xnorm);
|
||||
free(c->dh1_fp16); free(c->dh3_fp16);
|
||||
free(c->dq_fp16); free(c->dk_fp16); free(c->dv_fp16);
|
||||
}
|
||||
|
||||
// ===== fp16 activation cache (Phase 2) =====
|
||||
// Store activations that are only used for dW as fp16 (skip main-thread conversion)
|
||||
typedef struct {
|
||||
_Float16 *xnorm_fp16; // [DIM, SEQ]
|
||||
_Float16 *attn_out_fp16; // [DIM, SEQ]
|
||||
_Float16 *x2norm_fp16; // [DIM, SEQ]
|
||||
_Float16 *silu_out_fp16; // [HIDDEN, SEQ]
|
||||
} LayerFP16Cache;
|
||||
|
||||
static LayerFP16Cache layer_fp16_cache_alloc(void) {
|
||||
LayerFP16Cache c;
|
||||
c.xnorm_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
|
||||
c.attn_out_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
|
||||
c.x2norm_fp16 = (_Float16*)malloc(SEQ * DIM * 2);
|
||||
c.silu_out_fp16 = (_Float16*)malloc(SEQ * HIDDEN * 2);
|
||||
return c;
|
||||
}
|
||||
static void layer_fp16_cache_free(LayerFP16Cache *c) {
|
||||
free(c->xnorm_fp16); free(c->attn_out_fp16);
|
||||
free(c->x2norm_fp16); free(c->silu_out_fp16);
|
||||
}
|
||||
|
||||
// ===== Metal GPU dW context (Phase 3) =====
|
||||
typedef struct {
|
||||
id<MTLDevice> device;
|
||||
id<MTLCommandQueue> queue;
|
||||
// Shared gradient accumulator buffers (one per weight matrix per layer)
|
||||
id<MTLBuffer> dW_bufs[NLAYERS][9]; // Wq,Wk,Wv,Wo,W1,W2,W3,rms_att,rms_ffn
|
||||
id<MTLCommandBuffer> lastCmdBuf; // Track last submitted buffer for sync
|
||||
} MetalDWContext;
|
||||
|
||||
// Weight matrix indices for Metal buffers
|
||||
enum { MW_Q=0, MW_K, MW_V, MW_O, MW_1, MW_2, MW_3, MW_RMSA, MW_RMSF };
|
||||
|
||||
static bool metal_dw_init(MetalDWContext *ctx) {
|
||||
ctx->device = MTLCreateSystemDefaultDevice();
|
||||
if (!ctx->device) { printf("[Metal] No GPU device\n"); return false; }
|
||||
ctx->queue = [ctx->device newCommandQueue];
|
||||
if (!ctx->queue) { printf("[Metal] No command queue\n"); return false; }
|
||||
|
||||
// Allocate shared-mode gradient accumulator buffers
|
||||
size_t sizes[9] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4,
|
||||
W1_SZ*4, W2_SZ*4, W3_SZ*4, DIM*4, DIM*4};
|
||||
for (int L = 0; L < NLAYERS; L++) {
|
||||
for (int w = 0; w < 9; w++) {
|
||||
ctx->dW_bufs[L][w] = [ctx->device newBufferWithLength:sizes[w]
|
||||
options:MTLResourceStorageModeShared];
|
||||
if (!ctx->dW_bufs[L][w]) { printf("[Metal] Buffer alloc failed L=%d w=%d\n", L, w); return false; }
|
||||
}
|
||||
}
|
||||
printf("[Metal] GPU: %s\n", [[ctx->device name] UTF8String]);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void metal_dw_zero(MetalDWContext *ctx) {
|
||||
size_t sizes[9] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4,
|
||||
W1_SZ*4, W2_SZ*4, W3_SZ*4, DIM*4, DIM*4};
|
||||
for (int L = 0; L < NLAYERS; L++) {
|
||||
for (int w = 0; w < 9; w++) {
|
||||
memset([ctx->dW_bufs[L][w] contents], 0, sizes[w]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encode a single dW sgemm to Metal command buffer using MPS
|
||||
// C[M,N] += A[M,K] @ B^T[N,K] (i.e., C += A @ B^T, accumulating into C)
|
||||
static void metal_encode_dw_sgemm(id<MTLCommandBuffer> cmdBuf,
|
||||
id<MTLDevice> device,
|
||||
const float *a_data, int M, int K,
|
||||
const float *b_data, int N,
|
||||
id<MTLBuffer> c_buf) {
|
||||
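// Wrap the caller's memory as temporary shared-mode input buffers without copying.
// Note: newBufferWithBytesNoCopy expects page-aligned memory (e.g. from mmap or vm_allocate).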
|
||||
id<MTLBuffer> aBuf = [device newBufferWithBytesNoCopy:(void*)a_data
|
||||
length:M * K * sizeof(float)
|
||||
options:MTLResourceStorageModeShared
|
||||
deallocator:nil];
|
||||
id<MTLBuffer> bBuf = [device newBufferWithBytesNoCopy:(void*)b_data
|
||||
length:N * K * sizeof(float)
|
||||
options:MTLResourceStorageModeShared
|
||||
deallocator:nil];
|
||||
|
||||
// A is [M, K] row-major, B is [N, K] row-major
|
||||
// We want C += A @ B^T, i.e., C[M, N] += A[M, K] @ B[N, K]^T
|
||||
// MPS uses row-major by default
|
||||
MPSMatrixDescriptor *descA = [MPSMatrixDescriptor matrixDescriptorWithRows:M
|
||||
columns:K rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32];
|
||||
MPSMatrixDescriptor *descB = [MPSMatrixDescriptor matrixDescriptorWithRows:N
|
||||
columns:K rowBytes:K * sizeof(float) dataType:MPSDataTypeFloat32];
|
||||
MPSMatrixDescriptor *descC = [MPSMatrixDescriptor matrixDescriptorWithRows:M
|
||||
columns:N rowBytes:N * sizeof(float) dataType:MPSDataTypeFloat32];
|
||||
|
||||
MPSMatrix *matA = [[MPSMatrix alloc] initWithBuffer:aBuf descriptor:descA];
|
||||
MPSMatrix *matB = [[MPSMatrix alloc] initWithBuffer:bBuf descriptor:descB];
|
||||
MPSMatrix *matC = [[MPSMatrix alloc] initWithBuffer:c_buf descriptor:descC];
|
||||
|
||||
MPSMatrixMultiplication *mm = [[MPSMatrixMultiplication alloc]
|
||||
initWithDevice:device transposeLeft:NO transposeRight:YES
|
||||
resultRows:M resultColumns:N interiorColumns:K alpha:1.0 beta:1.0];
|
||||
|
||||
[mm encodeToCommandBuffer:cmdBuf leftMatrix:matA rightMatrix:matB resultMatrix:matC];
|
||||
}
|
||||
|
||||
// ===== Weight loading from llama2.c format =====
|
||||
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
|
||||
FILE *f = fopen(path, "rb");
|
||||
if (!f) { printf("Cannot open %s\n", path); return false; }
|
||||
Llama2Config cfg;
|
||||
if (fread(&cfg, sizeof(cfg), 1, f) != 1) { printf("Cannot read config from %s\n", path); fclose(f); return false; }
|
||||
printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
|
||||
cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
|
||||
if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
|
||||
printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d\n", DIM, HIDDEN, NLAYERS);
|
||||
fclose(f); return false;
|
||||
}
|
||||
int V = abs(cfg.vocab_size);
|
||||
bool shared = cfg.vocab_size > 0;
|
||||
(void)V; (void)shared;
|
||||
|
||||
fread(embed, 4, VOCAB * DIM, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
|
||||
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
|
||||
fread(rms_final, 4, DIM, f);
|
||||
fclose(f);
|
||||
printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
|
||||
return true;
|
||||
}
|
||||
|
||||
// ===== Compile one layer's kernels =====
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
    lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
        @"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
        @"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
        @"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
        @"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
        @"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
    }), DIM*SEQ*2, 6*DIM*SEQ*2);

    lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
        @"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
        @"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
        @"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
    }), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);

    lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
        @"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
        @"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
    }), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);

    lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
        @"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
    }), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);

    lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
        @"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
        @"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
        @"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
    }), 3*DIM*SEQ*2, DIM*SEQ*2);

    return lk->fwdAttn && lk->fwdFFN && lk->ffnBwd && lk->sdpaBwd1 && lk->qkvBwd;
}

static Kern *compile_sdpa_bwd2(void) {
    return compile_kern_mil_w(gen_sdpa_bwd2(), @{},
        (2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
}

static void free_layer_kernels(LayerKernels *lk) {
    free_kern(lk->fwdAttn); free_kern(lk->fwdFFN); free_kern(lk->ffnBwd);
    free_kern(lk->sdpaBwd1); free_kern(lk->qkvBwd);
    lk->fwdAttn = lk->fwdFFN = lk->ffnBwd = lk->sdpaBwd1 = lk->qkvBwd = NULL;
}
|
||||
// ===== Checkpoint save/load =====
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double cc, double ct, double cw, int cs, int cb, int adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "wb");
    // Without this check a failed fopen would crash in the first fwrite.
    if (!f) { printf("Cannot open %s for writing\n", path); return; }
    CkptHdr h = {0};
    h.magic = 0x424C5A54; h.version = 2;
    h.step = step; h.total_steps = total_steps;
    h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
    h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
    h.lr = lr; h.loss = loss;
    h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
    h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;
    fwrite(&h, sizeof(h), 1, f);
    // Per layer: weights first, then Adam first/second moments in the same order.
    for (int L = 0; L < NLAYERS; L++) {
        fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
        fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
        fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
        fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
        fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
        fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
        fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
        fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
        fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
        fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
        fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
        fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
        fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
    }
    fwrite(rms_final,4,DIM,f);
    fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
    fwrite(embed,4,VOCAB*DIM,f);
    fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);
    fclose(f);
}

static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;
    CkptHdr h;
    // Reject a short header read instead of inspecting an uninitialized struct.
    if (fread(&h, sizeof(h), 1, f) != 1) { fclose(f); return false; }
    if (h.magic != 0x424C5A54 || h.version != 2) { fclose(f); return false; }
    // The payload sizes below are compile-time constants, so the saved shape must match.
    if (h.n_layers != NLAYERS || h.dim != DIM || h.hidden_dim != HIDDEN || h.vocab_size != VOCAB) {
        fclose(f); return false;
    }
    *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
    *cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
    *cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
    for (int L = 0; L < NLAYERS; L++) {
        fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
        fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
        fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
        fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
        fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
        fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
        fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
        fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
        fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
        fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
        fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
        fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
        fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
    }
    fread(rms_final,4,DIM,f);
    fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
    fread(embed,4,VOCAB*DIM,f);
    fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f);
    // A truncated checkpoint sets EOF (or the error flag) during one of the reads above.
    bool read_ok = !ferror(f) && !feof(f);
    fclose(f);
    return read_ok;
}
|
||||
// ===== Main =====
|
||||
int main(int argc, char *argv[]) {
|
||||
@autoreleasepool {
|
||||
setbuf(stdout, NULL);
|
||||
|
||||
// Phase 2: Limit BLAS thread count to prevent oversubscription with concurrent dispatch
|
||||
setenv("VECLIB_MAXIMUM_THREADS", "2", 1);
|
||||
|
||||
ane_init();
|
||||
mach_timebase_info(&g_tb);
|
||||
|
||||
int total_steps = 10000;
|
||||
float lr = 3e-4f;
|
||||
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
|
||||
int adam_t = 0, start_step = 0;
|
||||
|
||||
// Parse args
|
||||
const char *model_path = get_path("ANE_MODEL_PATH", MODEL_PATH_DEFAULT);
|
||||
const char *ckpt_path = get_path("ANE_CKPT_PATH", CKPT_PATH_DEFAULT);
|
||||
const char *data_path = get_path("ANE_DATA_PATH", DATA_PATH_DEFAULT);
|
||||
bool do_resume = false;
|
||||
bool use_metal = false; // default off: Metal dW contends with ANE for memory bandwidth
|
||||
int pos = 0;
|
||||
for (int i=1; i<argc; i++) {
|
||||
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
|
||||
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
|
||||
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
|
||||
else if (strcmp(argv[i], "--no-metal") == 0) use_metal = false;
|
||||
else if (strcmp(argv[i], "--metal") == 0) use_metal = true;
|
||||
else if (argv[i][0] != '-') {
|
||||
if (pos == 0) model_path = argv[i];
|
||||
pos++;
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate per-layer state
|
||||
LayerWeights lw[NLAYERS];
|
||||
LayerAdam la[NLAYERS];
|
||||
LayerActs acts[NLAYERS];
|
||||
LayerGrads grads[NLAYERS];
|
||||
LayerKernels kern[NLAYERS];
|
||||
LayerCaptures caps[NLAYERS]; // Phase 1: pre-allocated captures
|
||||
LayerFP16Cache fp16cache[NLAYERS]; // Phase 2: fp16 activation cache
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
lw[L] = layer_weights_alloc();
|
||||
la[L] = layer_adam_alloc();
|
||||
acts[L] = layer_acts_alloc();
|
||||
grads[L] = layer_grads_alloc();
|
||||
memset(&kern[L], 0, sizeof(LayerKernels));
|
||||
caps[L] = layer_captures_alloc();
|
||||
fp16cache[L] = layer_fp16_cache_alloc();
|
||||
}
|
||||
|
||||
// Final RMSNorm + embedding + classifier
|
||||
float *rms_final = (float*)malloc(DIM*4);
|
||||
float *embed = (float*)malloc(VOCAB*DIM*4);
|
||||
float *grms_final = (float*)calloc(DIM, 4);
|
||||
float *gembed = (float*)calloc(VOCAB*DIM, 4);
|
||||
AdamState arms_final = adam_alloc(DIM);
|
||||
AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
|
||||
|
||||
// Phase 1: Pre-allocate dx_rms scratch (was calloc/free per step)
|
||||
float *dx_rms_scratch = (float*)malloc(SEQ*DIM*4);
|
||||
// Phase 1: Pre-allocate embed temp buffer for vectorized ops
|
||||
float *embed_tmp = (float*)malloc(SEQ*DIM*4);
|
||||
|
||||
// Phase 3: Metal GPU for dW
|
||||
MetalDWContext metal_ctx;
|
||||
bool metal_ok = false;
|
||||
if (use_metal) {
|
||||
metal_ok = metal_dw_init(&metal_ctx);
|
||||
if (!metal_ok) printf("[Metal] GPU init failed, falling back to CPU cblas\n");
|
||||
}
|
||||
|
||||
// Classifier dW capture buffers (pre-allocated, Phase 1)
|
||||
float *capt_dlogits = (float*)malloc(SEQ*VOCAB*4);
|
||||
float *capt_xfinal = (float*)malloc(SEQ*DIM*4);
|
||||
|
||||
double cum_compile=0, cum_train=0, cum_wall=0;
|
||||
int cum_steps=0, cum_batches=0;
|
||||
|
||||
float resume_loss = 0;
|
||||
bool resuming = false;
|
||||
if (do_resume) {
|
||||
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
|
||||
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
|
||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
|
||||
}
|
||||
if (!resuming) {
|
||||
printf("=== ANE Training (OPTIMIZED): Stories110M (12 layers) ===\n");
|
||||
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
|
||||
printf("Optimizations: NEON-Adam, vec-embed, pre-alloc, concurrent-dW, fp16-cache%s\n",
|
||||
metal_ok ? ", Metal-GPU-dW" : "");
|
||||
if (!load_pretrained(lw, rms_final, embed, model_path)) {
|
||||
printf("Pretrained load failed, using random init\n");
|
||||
srand48(42);
|
||||
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
|
||||
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
|
||||
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
|
||||
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
|
||||
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
|
||||
}
|
||||
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
|
||||
float escale = 0.02f;
|
||||
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
|
||||
}
|
||||
size_t tp = (size_t)NLAYERS*LAYER_PARAMS + DIM + (size_t)VOCAB*DIM;
|
||||
double xfmr_params = (double)NLAYERS*LAYER_PARAMS;
|
||||
double embed_params = (double)VOCAB*DIM;
|
||||
printf("Params: %.2fM (transformer %.2fM + embed %.2fM)\n", tp/1e6, xfmr_params/1e6, embed_params/1e6);
|
||||
printf("Kernels: %d (%d weight-bearing + %d static sdpaBwd2)\n",
|
||||
TOTAL_WEIGHT_KERNELS+NLAYERS, TOTAL_WEIGHT_KERNELS, NLAYERS);
|
||||
printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
|
||||
double fwd_f = NLAYERS*(4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
|
||||
double bwd_dx_f = fwd_f, bwd_dw_f = fwd_f;
|
||||
double sdpa_f = NLAYERS*2.0*HEADS*5*SEQ*SEQ*HD;
|
||||
double cls_f = 2.0*VOCAB*DIM*SEQ;
|
||||
double total_f = fwd_f + bwd_dx_f + bwd_dw_f + sdpa_f + cls_f*3;
|
||||
double ane_f = fwd_f + bwd_dx_f + sdpa_f;
|
||||
printf("FLOPs/step: fwd=%.0fM bwd_dx=%.0fM bwd_dW=%.0fM sdpa_bwd=%.0fM total=%.0fM\n",
|
||||
fwd_f/1e6, bwd_dx_f/1e6, bwd_dw_f/1e6, sdpa_f/1e6, total_f/1e6);
|
||||
printf("ANE FLOPs/step: %.0fM (fwd+bwd_dx+sdpa_bwd) | %s: dW+cls\n\n",
|
||||
ane_f/1e6, metal_ok ? "GPU" : "CPU cblas");
|
||||
}
|
||||
|
||||
// mmap token data
|
||||
int data_fd = open(data_path, O_RDONLY);
|
||||
if (data_fd < 0) { printf("Cannot open %s\n", data_path); return 1; }
|
||||
struct stat st;
if (fstat(data_fd, &st) != 0) { perror("fstat"); return 1; }
|
||||
size_t data_len = st.st_size;
|
||||
uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
|
||||
if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
|
||||
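size_t n_tokens = data_len / 2;
// Each step reads SEQ inputs plus a one-token-shifted target window; a smaller file would underflow max_pos below.
if (n_tokens < (size_t)SEQ + 1) { printf("Token data too small: need at least %d tokens\n", SEQ + 1); return 1; }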
|
||||
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
|
||||
|
||||
// Gradient buffers shared across layers (reused each step)
|
||||
float *dy = (float*)malloc(SEQ*DIM*4);
|
||||
float *dffn = (float*)malloc(SEQ*DIM*4);
|
||||
float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
|
||||
float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
|
||||
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
|
||||
float *dx2 = (float*)malloc(SEQ*DIM*4);
|
||||
float *do_out_buf = (float*)malloc(SEQ*DIM*4);
|
||||
float *dq = (float*)malloc(SEQ*DIM*4);
|
||||
float *dk = (float*)malloc(SEQ*DIM*4);
|
||||
float *dv = (float*)malloc(SEQ*DIM*4);
|
||||
float *dx_attn = (float*)malloc(SEQ*DIM*4);
|
||||
|
||||
float *x_cur = (float*)malloc(SEQ*DIM*4);
|
||||
float *x_final = (float*)malloc(SEQ*DIM*4);
|
||||
float *logits = (float*)malloc(SEQ*VOCAB*4);
|
||||
float *dlogits = (float*)malloc(SEQ*VOCAB*4);
|
||||
|
||||
// Compile static sdpaBwd2 kernels
|
||||
Kern *sdpaBwd2[NLAYERS];
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
sdpaBwd2[L] = compile_sdpa_bwd2();
|
||||
if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
|
||||
}
|
||||
|
||||
// Phase 2: Concurrent dW dispatch queue (was DISPATCH_QUEUE_SERIAL)
|
||||
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_CONCURRENT);
|
||||
dispatch_group_t dw_grp = dispatch_group_create();
|
||||
|
||||
float last_loss = 999.0f;
|
||||
double total_compile_ms=0, total_train_ms=0;
|
||||
int total_steps_done=0, total_batches=0;
|
||||
uint64_t t_wall_start = mach_absolute_time();
|
||||
|
||||
srand48(42 + start_step);
|
||||
|
||||
int step = start_step;
|
||||
while (step < total_steps) {
|
||||
// Check compile budget
|
||||
if (g_compile_count + TOTAL_WEIGHT_KERNELS > MAX_COMPILES) {
|
||||
for (int L=0; L<NLAYERS; L++) { free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]); }
|
||||
double wall = tb_ms(mach_absolute_time() - t_wall_start);
|
||||
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
|
||||
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
|
||||
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
|
||||
lw, la, rms_final, &arms_final, embed, &aembed);
|
||||
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
|
||||
fflush(stdout);
|
||||
// Preserve --metal flag across restarts (default is off)
|
||||
if (use_metal) execl(argv[0], argv[0], "--resume", "--metal", (char *)NULL);
else execl(argv[0], argv[0], "--resume", (char *)NULL);
|
||||
perror("execl"); return 1;
|
||||
}
|
||||
|
||||
// Compile all layers' weight-bearing kernels
|
||||
uint64_t tc = mach_absolute_time();
|
||||
for (int L=0; L<NLAYERS; L++) free_layer_kernels(&kern[L]);
|
||||
|
||||
bool compile_ok = true;
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
printf(" Compiling layer %d/%d... (%d compiles)\r", L+1, NLAYERS, g_compile_count);
|
||||
fflush(stdout);
|
||||
if (!compile_layer_kernels(&kern[L], &lw[L])) {
|
||||
printf("\nCompile failed at layer %d, restart\n", L);
|
||||
compile_ok = false; break;
|
||||
}
|
||||
}
|
||||
if (!compile_ok) { g_compile_count = MAX_COMPILES; continue; }
|
||||
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
if (!sdpaBwd2[L]) {
|
||||
sdpaBwd2[L] = compile_sdpa_bwd2();
|
||||
if (!sdpaBwd2[L]) { printf("sdpaBwd2 recompile failed\n"); return 1; }
|
||||
}
|
||||
}
|
||||
|
||||
double cms = tb_ms(mach_absolute_time() - tc);
|
||||
total_compile_ms += cms;
|
||||
printf(" Compiled %d kernels in %.0fms \n", TOTAL_WEIGHT_KERNELS, cms);
|
||||
|
||||
// Zero gradient accumulators
|
||||
for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
|
||||
memset(grms_final, 0, DIM*4);
|
||||
memset(gembed, 0, (size_t)VOCAB*DIM*4);
|
||||
if (metal_ok) metal_dw_zero(&metal_ctx);
|
||||
|
||||
int steps_batch = 0;
|
||||
uint64_t tt = mach_absolute_time();
|
||||
double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0,t_cls=0,t_metal=0,t_bwd=0;
|
||||
|
||||
for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
|
||||
uint64_t t0,t1;
|
||||
size_t max_pos = n_tokens - SEQ - 1;
|
||||
size_t pos = (size_t)(drand48() * max_pos);
|
||||
uint16_t *input_tokens = token_data + pos;
|
||||
uint16_t *target_tokens = token_data + pos + 1;
|
||||
|
||||
// Phase 1: Vectorized embedding lookup
|
||||
t0=mach_absolute_time();
|
||||
embed_lookup_opt(x_cur, embed, input_tokens, DIM, SEQ, embed_tmp);
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
|
||||
|
||||
// ===== FORWARD (12 layers) =====
|
||||
for (int L=0; L<NLAYERS; L++) {
|
||||
LayerActs *ac = &acts[L];
|
||||
LayerFP16Cache *fc = &fp16cache[L];
|
||||
|
||||
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
|
||||
|
||||
// Attention forward
|
||||
t0=mach_absolute_time();
|
||||
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
|
||||
t1=mach_absolute_time(); t_cblas_wait+=tb_ms(t1-t0); t0=t1;
|
||||
io_write_fp16(kern[L].fwdAttn->ioIn, x_cur, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
ane_eval(kern[L].fwdAttn);
|
||||
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// Read o_out (needed on main thread for residual)
|
||||
io_read_fp16(kern[L].fwdAttn->ioOut, ac->o_out, 0, DIM, SEQ);
|
||||
// Phase 2: Read dW-only activations as raw fp16 (skip conversion on main thread)
|
||||
io_read_raw_fp16(kern[L].fwdAttn->ioOut, fc->attn_out_fp16, 4*DIM, DIM, SEQ);
|
||||
io_read_raw_fp16(kern[L].fwdAttn->ioOut, fc->xnorm_fp16, 5*DIM, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// FFN forward
|
||||
io_write_fp16(kern[L].fwdFFN->ioIn, ac->x2, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
ane_eval(kern[L].fwdFFN);
|
||||
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// Read ffn_out (needed on main thread for residual)
|
||||
io_read_fp16(kern[L].fwdFFN->ioOut, ac->ffn_out, 0, DIM, SEQ);
|
||||
// h1, h3 NOT read here — backward uses io_copy from fwdFFN->ioOut directly
|
||||
// silu_out and x2norm are dW-only → read as fp16
|
||||
io_read_raw_fp16(kern[L].fwdFFN->ioOut, fc->silu_out_fp16, DIM+2*HIDDEN, HIDDEN, SEQ);
|
||||
io_read_raw_fp16(kern[L].fwdFFN->ioOut, fc->x2norm_fp16, DIM+3*HIDDEN, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
|
||||
}
|
||||
|
||||
// Final RMSNorm (CPU)
|
||||
t0=mach_absolute_time();
|
||||
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
|
||||
t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// Classifier
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
|
||||
VOCAB, SEQ, DIM, 1.0f,
|
||||
embed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
|
||||
t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
float loss = cross_entropy_loss(dlogits, logits, target_tokens, VOCAB, SEQ);
|
||||
last_loss = loss;
|
||||
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
|
||||
|
||||
// ===== BACKWARD =====
|
||||
uint64_t t_bwd_start = mach_absolute_time();
|
||||
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
|
||||
DIM, SEQ, VOCAB, 1.0f,
|
||||
embed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
|
||||
|
||||
// dW embed (classifier) — async on dW queue
|
||||
memcpy(capt_dlogits, dlogits, SEQ*VOCAB*4);
|
||||
memcpy(capt_xfinal, x_final, SEQ*DIM*4);
|
||||
|
||||
// Classifier dW on CPU (gembed is CPU-side accumulator, not Metal buffer)
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
VOCAB, DIM, SEQ, 1.0f,
|
||||
capt_dlogits, SEQ, capt_xfinal, SEQ, 1.0f, gembed, DIM);
|
||||
});
|
||||
|
||||
// Final RMSNorm backward (using pre-allocated scratch)
|
||||
memset(dx_rms_scratch, 0, SEQ*DIM*4);
|
||||
rmsnorm_bwd(dx_rms_scratch, grms_final, dy, x_cur, rms_final, DIM, SEQ);
|
||||
memcpy(dy, dx_rms_scratch, SEQ*DIM*4);
|
||||
|
||||
// ===== BACKWARD (12 layers, reverse) =====
|
||||
for (int L=NLAYERS-1; L>=0; L--) {
|
||||
LayerActs *ac = &acts[L];
|
||||
LayerGrads *gr = &grads[L];
|
||||
LayerCaptures *cp = &caps[L];
|
||||
LayerFP16Cache *fc = &fp16cache[L];
|
||||
|
||||
memcpy(dffn, dy, SEQ*DIM*4);
|
||||
|
||||
// FFN backward (ANE)
|
||||
io_write_fp16_at(kern[L].ffnBwd->ioIn, 0, dffn, DIM, SEQ);
|
||||
io_copy(kern[L].ffnBwd->ioIn, DIM, kern[L].fwdFFN->ioOut, DIM, 2*HIDDEN, SEQ);
|
||||
ane_eval(kern[L].ffnBwd);
|
||||
io_read_fp16(kern[L].ffnBwd->ioOut, dx_ffn, 0, DIM, SEQ);
|
||||
// dh1, dh3: only used for dW captures → read as raw fp16
|
||||
io_read_raw_fp16(kern[L].ffnBwd->ioOut, cp->dh1_fp16, DIM, HIDDEN, SEQ);
|
||||
io_read_raw_fp16(kern[L].ffnBwd->ioOut, cp->dh3_fp16, DIM+HIDDEN, HIDDEN, SEQ);
|
||||
|
||||
memcpy(cp->dffn, dffn, SEQ*DIM*4);
|
||||
|
||||
if (metal_ok) {
|
||||
// Metal path: convert all on main thread for GPU buffers
|
||||
cvt_f16_f32(cp->dh1, cp->dh1_fp16, SEQ*HIDDEN);
|
||||
cvt_f16_f32(cp->dh3, cp->dh3_fp16, SEQ*HIDDEN);
|
||||
cvt_f16_f32(cp->silu_out, fc->silu_out_fp16, SEQ*HIDDEN);
|
||||
cvt_f16_f32(cp->x2norm, fc->x2norm_fp16, SEQ*DIM);
|
||||
|
||||
@autoreleasepool {
|
||||
id<MTLCommandBuffer> cmdBuf = [metal_ctx.queue commandBuffer];
|
||||
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
|
||||
cp->dffn, DIM, SEQ, cp->silu_out, HIDDEN,
|
||||
metal_ctx.dW_bufs[L][MW_2]);
|
||||
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
|
||||
cp->dh1, HIDDEN, SEQ, cp->x2norm, DIM,
|
||||
metal_ctx.dW_bufs[L][MW_1]);
|
||||
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
|
||||
cp->dh3, HIDDEN, SEQ, cp->x2norm, DIM,
|
||||
metal_ctx.dW_bufs[L][MW_3]);
|
||||
[cmdBuf commit];
|
||||
metal_ctx.lastCmdBuf = cmdBuf;
|
||||
}
|
||||
} else {
|
||||
// CPU: concurrent dispatch, convert fp16→fp32 in each block
|
||||
_Float16 *fc_silu = fc->silu_out_fp16;
|
||||
_Float16 *cp_dh1_f16 = cp->dh1_fp16, *cp_dh3_f16 = cp->dh3_fp16;
|
||||
float *cp_dffn = cp->dffn, *cp_silu = cp->silu_out;
|
||||
float *cp_dh1 = cp->dh1, *cp_dh3 = cp->dh3, *cp_x2n = cp->x2norm;
|
||||
float *gr_W2 = gr->W2, *gr_W1 = gr->W1, *gr_W3 = gr->W3;
|
||||
|
||||
// Convert shared x2norm on main thread (W1+W3 blocks read concurrently)
|
||||
cvt_f16_f32(cp_x2n, fc->x2norm_fp16, SEQ*DIM);
|
||||
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cvt_f16_f32(cp_silu, fc_silu, SEQ*HIDDEN);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
|
||||
1.0f, cp_dffn, SEQ, cp_silu, SEQ, 1.0f, gr_W2, HIDDEN);
|
||||
});
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cvt_f16_f32(cp_dh1, cp_dh1_f16, SEQ*HIDDEN);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
|
||||
1.0f, cp_dh1, SEQ, cp_x2n, SEQ, 1.0f, gr_W1, DIM);
|
||||
});
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cvt_f16_f32(cp_dh3, cp_dh3_f16, SEQ*HIDDEN);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
|
||||
1.0f, cp_dh3, SEQ, cp_x2n, SEQ, 1.0f, gr_W3, DIM);
|
||||
});
|
||||
}
|
||||
|
||||
// RMSNorm2 backward
|
||||
memset(dx2, 0, SEQ*DIM*4);
|
||||
rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
|
||||
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
|
||||
|
||||
// dWo async
|
||||
memcpy(cp->do_buf, dx2, SEQ*DIM*4);
|
||||
|
||||
if (metal_ok) {
|
||||
cvt_f16_f32(cp->attn_out, fc->attn_out_fp16, SEQ*DIM);
|
||||
@autoreleasepool {
|
||||
id<MTLCommandBuffer> cmdBuf = [metal_ctx.queue commandBuffer];
|
||||
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
|
||||
cp->do_buf, DIM, SEQ, cp->attn_out, DIM,
|
||||
metal_ctx.dW_bufs[L][MW_O]);
|
||||
[cmdBuf commit];
|
||||
metal_ctx.lastCmdBuf = cmdBuf;
|
||||
}
|
||||
} else {
|
||||
_Float16 *fc_attn = fc->attn_out_fp16;
|
||||
float *cp_do = cp->do_buf, *cp_attn = cp->attn_out;
|
||||
float *gr_Wo = gr->Wo;
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cvt_f16_f32(cp_attn, fc_attn, SEQ*DIM);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, cp_do, SEQ, cp_attn, SEQ, 1.0f, gr_Wo, DIM);
|
||||
});
|
||||
}
|
||||
|
||||
// SDPA backward (ANE)
|
||||
io_copy(kern[L].sdpaBwd1->ioIn, 0, kern[L].fwdAttn->ioOut, DIM, 3*DIM, SEQ);
|
||||
io_write_fp16_at(kern[L].sdpaBwd1->ioIn, 3*DIM, dx2, DIM, SEQ);
|
||||
ane_eval(kern[L].sdpaBwd1);
|
||||
io_copy(sdpaBwd2[L]->ioIn, 0, kern[L].sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
|
||||
io_copy(sdpaBwd2[L]->ioIn, 2*SCORE_CH, kern[L].fwdAttn->ioOut, DIM, 2*DIM, SEQ);
|
||||
ane_eval(sdpaBwd2[L]);
|
||||
|
||||
// dq, dk, dv: only used for dW captures → read as raw fp16
|
||||
io_read_raw_fp16(sdpaBwd2[L]->ioOut, cp->dq_fp16, 0, DIM, SEQ);
|
||||
io_read_raw_fp16(sdpaBwd2[L]->ioOut, cp->dk_fp16, DIM, DIM, SEQ);
|
||||
io_read_raw_fp16(kern[L].sdpaBwd1->ioOut, cp->dv_fp16, 0, DIM, SEQ);
|
||||
|
||||
if (metal_ok) {
|
||||
// Metal path: convert all on main thread for GPU buffers
|
||||
cvt_f16_f32(cp->dq, cp->dq_fp16, SEQ*DIM);
|
||||
cvt_f16_f32(cp->dk, cp->dk_fp16, SEQ*DIM);
|
||||
cvt_f16_f32(cp->dv, cp->dv_fp16, SEQ*DIM);
|
||||
cvt_f16_f32(cp->xnorm, fc->xnorm_fp16, SEQ*DIM);
|
||||
@autoreleasepool {
|
||||
id<MTLCommandBuffer> cmdBuf = [metal_ctx.queue commandBuffer];
|
||||
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
|
||||
cp->dq, DIM, SEQ, cp->xnorm, DIM,
|
||||
metal_ctx.dW_bufs[L][MW_Q]);
|
||||
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
|
||||
cp->dk, DIM, SEQ, cp->xnorm, DIM,
|
||||
metal_ctx.dW_bufs[L][MW_K]);
|
||||
metal_encode_dw_sgemm(cmdBuf, metal_ctx.device,
|
||||
cp->dv, DIM, SEQ, cp->xnorm, DIM,
|
||||
metal_ctx.dW_bufs[L][MW_V]);
|
||||
[cmdBuf commit];
|
||||
metal_ctx.lastCmdBuf = cmdBuf;
|
||||
}
|
||||
} else {
|
||||
_Float16 *cp_dq_f16 = cp->dq_fp16, *cp_dk_f16 = cp->dk_fp16, *cp_dv_f16 = cp->dv_fp16;
|
||||
float *cp_dq = cp->dq, *cp_dk = cp->dk, *cp_dv = cp->dv, *cp_xn = cp->xnorm;
|
||||
float *gr_Wq = gr->Wq, *gr_Wk = gr->Wk, *gr_Wv = gr->Wv;
|
||||
// Convert shared xnorm on main thread (all 3 blocks read concurrently)
|
||||
cvt_f16_f32(cp_xn, fc->xnorm_fp16, SEQ*DIM);
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cvt_f16_f32(cp_dq, cp_dq_f16, SEQ*DIM);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, cp_dq, SEQ, cp_xn, SEQ, 1.0f, gr_Wq, DIM);
|
||||
});
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cvt_f16_f32(cp_dk, cp_dk_f16, SEQ*DIM);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, cp_dk, SEQ, cp_xn, SEQ, 1.0f, gr_Wk, DIM);
|
||||
});
|
||||
dispatch_group_async(dw_grp, dw_q, ^{
|
||||
cvt_f16_f32(cp_dv, cp_dv_f16, SEQ*DIM);
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
|
||||
1.0f, cp_dv, SEQ, cp_xn, SEQ, 1.0f, gr_Wv, DIM);
|
||||
});
|
||||
}
|
||||
|
||||
// QKV backward (ANE)
|
||||
io_copy(kern[L].qkvBwd->ioIn, 0, sdpaBwd2[L]->ioOut, 0, 2*DIM, SEQ);
|
||||
io_copy(kern[L].qkvBwd->ioIn, 2*DIM, kern[L].sdpaBwd1->ioOut, 0, DIM, SEQ);
|
||||
ane_eval(kern[L].qkvBwd);
|
||||
io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
|
||||
|
||||
// RMSNorm1 backward
|
||||
memset(dx_rms_scratch, 0, SEQ*DIM*4);
|
||||
rmsnorm_bwd(dx_rms_scratch, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
|
||||
for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms_scratch[i] + dx2[i];
|
||||
}
|
||||
|
||||
// Embedding backward
|
||||
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
|
||||
// Phase 1: Vectorized embed backward
|
||||
embed_backward_opt(gembed, dy, input_tokens, DIM, SEQ, embed_tmp);
|
||||
t_bwd += tb_ms(mach_absolute_time() - t_bwd_start);
|
||||
|
||||
steps_batch++;
|
||||
if (step % 10 == 0 || step == start_step)
|
||||
printf("step %-4d loss=%.4f\n", step, loss);
|
||||
|
||||
// JSON telemetry to stderr
|
||||
double step_ane = t_ane/steps_batch, step_io = t_io/steps_batch;
|
||||
double step_cls = t_cls/steps_batch, step_elem = t_elem/steps_batch;
|
||||
double step_rms = t_rms/steps_batch, step_cbw = t_cblas_wait/steps_batch;
|
||||
fprintf(stderr, "{\"type\":\"step\",\"step\":%d,\"loss\":%.6f,"
|
||||
"\"t_ane\":%.3f,\"t_io\":%.3f,\"t_cls\":%.3f,"
|
||||
"\"t_elem\":%.3f,\"t_rms\":%.3f,\"t_cblas_wait\":%.3f,"
|
||||
"\"t_bwd\":%.3f,\"t_metal\":%.3f,\"compiles\":%d}\n",
|
||||
step, loss, step_ane, step_io, step_cls, step_elem, step_rms, step_cbw,
|
||||
t_bwd/steps_batch, t_metal/steps_batch, g_compile_count);
|
||||
}
|
||||
double tms = tb_ms(mach_absolute_time() - tt);
|
||||
total_train_ms += tms;
|
||||
total_steps_done += steps_batch;
|
||||
total_batches++;
|
||||
|
||||
// Ensure all async dW finished (CPU cblas or Metal)
|
||||
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
|
||||
|
||||
// Phase 3: If Metal, wait for GPU then copy gradient accumulators to CPU grads
|
||||
if (metal_ok) {
|
||||
// Must wait for all GPU command buffers to complete before reading
|
||||
if (metal_ctx.lastCmdBuf) {
|
||||
[metal_ctx.lastCmdBuf waitUntilCompleted];
|
||||
metal_ctx.lastCmdBuf = nil;
|
||||
}
|
||||
for (int L = 0; L < NLAYERS; L++) {
|
||||
float *gpu_ptrs[7];
|
||||
float *cpu_ptrs[7];
|
||||
size_t sizes[7] = {WQ_SZ*4, WQ_SZ*4, WQ_SZ*4, WO_SZ*4,
|
||||
W1_SZ*4, W2_SZ*4, W3_SZ*4};
|
||||
int indices[7] = {MW_Q, MW_K, MW_V, MW_O, MW_1, MW_2, MW_3};
|
||||
cpu_ptrs[0]=grads[L].Wq; cpu_ptrs[1]=grads[L].Wk; cpu_ptrs[2]=grads[L].Wv;
|
||||
cpu_ptrs[3]=grads[L].Wo; cpu_ptrs[4]=grads[L].W1; cpu_ptrs[5]=grads[L].W2;
|
||||
cpu_ptrs[6]=grads[L].W3;
|
||||
|
||||
for (int w = 0; w < 7; w++) {
|
||||
gpu_ptrs[w] = (float*)[metal_ctx.dW_bufs[L][indices[w]] contents];
|
||||
// Accumulate GPU gradients into CPU accumulators
|
||||
vDSP_vadd(gpu_ptrs[w], 1, cpu_ptrs[w], 1, cpu_ptrs[w], 1,
|
||||
(vDSP_Length)(sizes[w]/4));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
            // Adam update (scale gradients by 1/steps_batch)
            float gsc = 1.0f / steps_batch;
            adam_t++;
            for (int L=0; L<NLAYERS; L++) {
                LayerGrads *g = &grads[L];
                // Wq/Wk/Wv/Wo are scaled in one loop, which relies on WO_SZ == WQ_SZ
                for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
                for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
                for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
                for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
                for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}

                // Phase 1: NEON Adam
                adam_update_opt(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update_opt(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
            }
            for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
            adam_update_opt(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
            for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
            adam_update_opt(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
            printf(" [batch %d: compile=%.0fms train=%.1fms (%.1fms/step) compiles=%d]\n",
                   steps_batch, cms, tms, tms/steps_batch, g_compile_count);
            printf(" fwd: ane=%.1f io=%.1f cls=%.1f elem=%.1f rms=%.1f | bwd=%.1f | cblas_wait=%.1f ms/step\n",
                   t_ane/steps_batch, t_io/steps_batch, t_cls/steps_batch, t_elem/steps_batch,
                   t_rms/steps_batch, t_bwd/steps_batch, t_cblas_wait/steps_batch);
            // JSON batch telemetry to stderr
            {
                double bf = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
                double bs = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
                double ane_f_batch = (bf*2 + bs) * steps_batch;
                // tms is in milliseconds, so FLOPs / (ms * 1e9) yields TFLOPS
                double ane_tflops = ane_f_batch / (tms * 1e9);
                fprintf(stderr, "{\"type\":\"batch\",\"batch\":%d,\"compile_ms\":%.1f,"
                        "\"train_ms\":%.1f,\"ms_per_step\":%.1f}\n",
                        steps_batch, cms, tms, tms/steps_batch);
                // Utilization is reported against the same 15.8 TFLOPS figure as the efficiency report
                fprintf(stderr, "{\"type\":\"perf\",\"ane_tflops\":%.3f,\"ane_util_pct\":%.2f,"
                        "\"metal_dw\":%s}\n",
                        ane_tflops, 100.0*ane_tflops/15.8, metal_ok ? "true" : "false");
            }
        }
        // Efficiency report
        double wall = tb_ms(mach_absolute_time() - t_wall_start);
        total_compile_ms += cum_compile; total_train_ms += cum_train;
        wall += cum_wall; total_steps_done += cum_steps; total_batches += cum_batches;
        double fwd_flops = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
        double sdpa_flops = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
        double cls_flops = 2.0*VOCAB*DIM*SEQ;
        double total_flops = (fwd_flops*3 + sdpa_flops + cls_flops*3) * total_steps_done;
        double ane_flops = (fwd_flops*2 + sdpa_flops) * total_steps_done;
        printf("\n=== Efficiency Report (OPTIMIZED) ===\n");
        printf("Total steps: %d\n", total_steps_done);
        printf("Wall time: %.0f ms (%.1f s)\n", wall, wall/1000);
        printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100*total_compile_ms/wall);
        printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100*total_train_ms/wall);
        printf("Avg train: %.1f ms/step\n", total_train_ms/total_steps_done);
        printf("ANE TFLOPS: %.2f sustained\n", ane_flops / (total_train_ms * 1e9));
        printf("Total TFLOPS: %.2f (ANE+%s)\n", total_flops / (total_train_ms * 1e9),
               metal_ok ? "GPU" : "CPU");
        printf("ANE utilization: %.1f%% of 15.8 TFLOPS\n", 100*ane_flops/(total_train_ms*1e9)/15.8);
        printf("Metal GPU dW: %s\n", metal_ok ? "ENABLED" : "disabled");
        // Cleanup
        for (int L=0; L<NLAYERS; L++) {
            free_layer_kernels(&kern[L]);
            free_kern(sdpaBwd2[L]);
            layer_weights_free(&lw[L]);
            layer_adam_free(&la[L]);
            layer_acts_free(&acts[L]);
            layer_grads_free(&grads[L]);
            layer_captures_free(&caps[L]);
            layer_fp16_cache_free(&fp16cache[L]);
        }
        munmap(token_data, data_len);
        close(data_fd);
        free(rms_final); free(embed); free(grms_final); free(gembed);
        adam_free(&arms_final); adam_free(&aembed);
        free(dy); free(dffn); free(dh1); free(dh3); free(dx_ffn); free(dx2);
        free(do_out_buf); free(dq); free(dk); free(dv); free(dx_attn);
        free(x_cur); free(x_final); free(logits); free(dlogits);
        free(dx_rms_scratch); free(embed_tmp);
        free(capt_dlogits); free(capt_xfinal);
    }
    return 0;
}