[feat] Merge upstream: dynamic weight training, CLI fixes, dashboard v2

This commit is contained in:
Erik Bray 2026-03-03 14:38:52 +01:00
commit 99b06838bc
12 changed files with 3092 additions and 164 deletions

View File

@ -8,43 +8,68 @@ Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly
- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256
- **109.53M params** (84.95M transformer + 24.58M embedding)
- **72 ANE kernels** per compile (60 weight-bearing, 12 weight-free sdpaBwd2)
- **6 kernel types per layer**: fwdAttn, fwdFFN, ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask — decompose into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
## Performance
## Three Training Pipelines
| Component | Time (ms/step) |
|-----------|---------------|
| ANE eval | 9.6 |
| IO (fp16 conversion) | 4.1 |
| Classifier (cblas) | 9.1 |
| Cross-entropy + residuals | 14.4 |
| RMSNorm | 0.1 |
| **Total** | **107 ms/step** |
### 1. Static Baseline (`train_large`)
Original pipeline. Weights baked as constants in MIL kernels — recompile every 10 steps via `exec()` restart.
- 60 weight-bearing + 12 weight-free kernels = 72 per compile batch
- Classifier + softmax + RMSNorm backward on CPU
- **106.7 ms/step**, 7.6s compile per restart
### 2. Static + ANE Extras (`train_large_ane`) — PR#19
Offloads classifier forward (32K conv), softmax, final RMSNorm, and RMSNorm backward to ANE. Bridge API for C-callable ANE access.
- 86 kernels per compile batch (+24 rmsnorm_bwd, +1 classifier, +1 finalRms)
- **91.8 ms/step** (14% faster), 9.6s compile per restart
- Use `--no-ane-extras` to disable and fall back to CPU (for debugging)
### 3. Dynamic Weight Pipeline (`training_dynamic/`)
Weights passed via IOSurface spatial dimension — compile 9 kernels once at startup, no recompilation needed.
- 9 shared kernels across all 12 layers
- **111 ms/step**, 0.4s one-time compile
- No exec() restart, no compile limit issues
## Performance Comparison (20 Steps)
| | Static Baseline | PR#19 + ANE extras | PR#19 no extras | Dynamic |
|---|---|---|---|---|
| **Wall time** | **10.1s** | **11.7s** | **10.7s** | **~2.6s** |
| Compile | 7.6s (75.7%) | 9.6s (81.6%) | 7.5s (69.7%) | 0.4s (15%) |
| Train | 2.1s (21.2%) | 1.8s (15.6%) | 2.9s (27.4%) | 2.2s (85%) |
| **ms/step** | **106.7** | **91.8** | **147.0** | **111** |
| Kernels/restart | 72 | 86 | 60 | 9 (once) |
| ANE TFLOPS | 0.87 | 1.15 | 0.72 | — |
| Total TFLOPS | 1.63 | 1.90 | 1.19 | — |
**Key insights:**
- Dynamic wins on wall time for any practical run length (3.9x faster at 20 steps)
- PR#19 has the best per-step throughput (92ms) but compile overhead dominates short runs
- Static restarts every 10 steps, so dynamic's zero-recompile advantage compounds
## Files
| File | Description |
|------|-------------|
| `train_large.m` | Main training loop — 12-layer forward/backward, checkpoint, exec() restart |
| `stories_config.h` | Model config, structs, alloc helpers |
| `train_large.m` | Static baseline — 72 kernels, classifier/softmax on CPU |
| `train_large_ane.m` | PR#19 — 86 kernels, classifier/softmax/rmsnorm_bwd on ANE |
| `training_dynamic/train.m` | Dynamic pipeline — 9 kernels, weights via IOSurface |
| `training_dynamic/mil_dynamic.h` | MIL generators for dynamic weight kernels |
| `training_dynamic/config.h` | Model config (DIM=768, HIDDEN=2048, etc.) |
| `training_dynamic/io.h` | IOSurface I/O + MIL compilation helpers |
| `training_dynamic/cpu_ops.h` | CPU ops (SiLU backward, cross-entropy, Adam) |
| `stories_config.h` | Static pipeline config, structs, alloc helpers |
| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval |
| `stories_mil.h` | MIL program generators for all 6 ANE kernel types |
| `stories_cpu_ops.h` | vDSP-vectorized RMSNorm, cross-entropy, Adam, embedding ops |
| `dashboard.py` | TUI dashboard — loss curve, power/CPU/memory graphs, text generation |
| `tokenize.py` | Extract pretokenized TinyStories data |
| `stories_mil.h` | MIL generators for static pipeline (6 kernel types) |
| `stories_cpu_ops.h` | vDSP-vectorized RMSNorm, cross-entropy, Adam |
| `ane_classifier.h` | ANE classifier fwd (32K conv), softmax kernels |
| `ane_rmsnorm_bwd.h` | ANE rmsnorm backward kernel |
| `dashboard.py` | TUI dashboard — loss curve, power/CPU/memory graphs |
| `Makefile` | Build targets |
## How it works
1. **Forward pass**: Each layer runs fwdAttn (QKV + SDPA + Wo) and fwdFFN (W1 + SiLU(W3) + W2) on ANE via MIL-compiled kernels. Final RMSNorm + classifier matmul on CPU (cblas).
2. **Backward pass**: Reverse layer order. ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd on ANE. Weight gradients (dW) via async cblas_sgemm on CPU. RMSNorm backward via vDSP.
3. **Compile budget**: ANE has a ~119 compile limit per process. With 72 kernels per batch, we run 10 accumulation steps then `exec()` restart with checkpoint resume.
4. **Data**: Real TinyStories text (20M tokens), mmap'd uint16 token IDs, random position sampling per step.
## Usage
### 1. Download Training Data
@ -53,69 +78,63 @@ Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly
bash download_data.sh
```
Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from [enio/TinyStories](https://huggingface.co/datasets/enio/TinyStories) on HuggingFace. Produces `tinystories_data00.bin` (~41 MB, ~20M tokens).
Downloads pretokenized TinyStories (Llama 2 BPE, 32K vocab) from HuggingFace. Produces `tinystories_data00.bin` (~41 MB, ~20M tokens).
### 2. Build & Train
```bash
# Baseline: classifier + softmax on CPU
# Static baseline (classifier + softmax on CPU)
make train_large
./train_large --steps 100 # quick test
./train_large # full 10k steps
./train_large --resume # resume from checkpoint
./train_large stories110M.bin 256 100 1e-4
./train_large --model stories110M.bin --steps 100 --lr 1e-4
# ANE-offloaded: classifier + softmax on ANE (faster)
# PR#19: ANE-offloaded classifier + softmax + rmsnorm_bwd
make train_large_ane
./train_large_ane --steps 100
./train_large_ane stories110M.bin 256 100 1e-4
./train_large_ane --no-ane-extras --steps 100 # disable ANE extras
# Dynamic pipeline (no recompilation)
cd training_dynamic && make train
./train --scratch # train from random init
./train # resume from checkpoint
./train --steps 200 --lr 1e-4 # custom steps/lr
```
**CLI flags:** `--steps N` (default 10000), `--lr F` (default 3e-4), `--resume`.
**CLI flags (all pipelines):**
- `--steps N` (default 10000)
- `--lr F` (default 3e-4)
- `--model PATH` — pretrained weights file
- `--ckpt PATH` — checkpoint file (preserved across exec() restarts)
- `--resume` — resume from checkpoint
- `--no-ane-extras` — (train_large_ane only) disable ANE classifier/softmax/rmsnorm_bwd
### 3. Monitor with Dashboard
```bash
pip install blessed psutil numpy
sudo python3 dashboard.py # live mode (needs powermetrics)
sudo python3 dashboard.py --resume # attach to resumed training
sudo python3 dashboard.py # static pipeline
sudo python3 dashboard.py --dynamic # dynamic pipeline
```
### 4. Benchmarking
Both programs print an **Efficiency Report** at completion:
All programs print an **Efficiency Report** at completion:
```
=== Efficiency Report ===
Total steps: 100
Avg train: 107.0 ms/step
ANE TFLOPS: 2.45 sustained
ANE utilization: 15.5% of 15.8 TFLOPS
Total steps: 20
Wall time: 11738 ms (11.7 s)
Compile time: 9583 ms (81.6%)
Train time: 1835 ms (15.6%)
Avg train: 91.8 ms/step
ANE TFLOPS: 1.15 sustained
```
Per-batch timing breakdown during training:
## Key Techniques
```
ane=9.6 io=4.1 cls=9.1 elem=14.4 rms=0.1 cblas_wait=2.3 ms/step
```
| Metric | What it measures |
|--------|-----------------|
| `ane` | ANE kernel evaluation |
| `io` | fp16↔fp32 IOSurface transfer |
| `cls` | Classifier matmul (CPU cblas) |
| `elem` | Embedding, residual adds, cross-entropy |
| `rms` | RMSNorm forward/backward |
| `cblas_wait` | Waiting for async dW gradient sgemms |
Compare baseline vs ANE-offloaded:
```bash
make train_large && ./train_large --steps 100
make train_large_ane && ./train_large_ane --steps 100
```
## Key techniques
- **NEON vectorized fp16<->fp32**: ARM NEON intrinsics for fast IOSurface data transfer
- **NEON vectorized fp16↔fp32**: ARM NEON intrinsics for fast IOSurface data transfer
- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar
- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask, so we decompose attention into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
- **Vocab compaction** (dynamic): 32K → 9.2K active tokens, 3.5x reduction in classifier work
- **Dynamic weight packing**: Activations + weights concatenated in IOSurface spatial dimension — one kernel serves all 12 layers
- **exec() restart**: Workaround for ANE ~119 compile limit per process

View File

@ -1,6 +1,6 @@
"""TUI dashboard for ANE training (train_large). Uses blessed for terminal UI."""
import argparse, fcntl, math, os, re, select, signal, struct, subprocess, sys, time, threading
import argparse, fcntl, json, math, os, re, select, signal, struct, subprocess, sys, time, threading
from collections import deque
from pathlib import Path
@ -20,7 +20,9 @@ except ImportError:
DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12
HD = DIM // HEADS
CKPT_PATH = 'ane_stories110M_ckpt.bin'
CKPT_PATH_STATIC = 'ane_stories110M_ckpt.bin'
CKPT_PATH_DYNAMIC = 'training_dynamic/ane_stories110M_dyn_ckpt.bin'
CKPT_PATH = CKPT_PATH_STATIC # set in main() based on --dynamic
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent.parent / 'assets' / 'models' / 'tokenizer.bin')
@ -56,6 +58,9 @@ class State:
self.mem_mb_history = deque(maxlen=300)
self.proc_mem_mb_history = deque(maxlen=300)
self.train_pid = None
self.step_timestamps = [] # (step, time.monotonic()) for running ms/step
self.train_start = None # wall clock when first step seen
self.compile_ms = 0.0 # total compile time
S = State()
@ -278,23 +283,69 @@ def sysmetrics_thread():
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
RE_KERNELS_DYN = re.compile(r'Kernels: (\d+) compiled, (\d+) weight-bearing')
RE_ACCUM = re.compile(r'Accum (\d+).*LR=([\d.e+-]+)')
RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)')
RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)(?:\s+lr=([\d.e+-]+))?(?:\s+([\d.]+)ms/step)?')
RE_BATCH = re.compile(r'\[batch (\d+): compile=([\d.]+)ms train=([\d.]+)ms \(([\d.]+)ms/step\) compiles=(\d+)\]')
RE_TIMING = re.compile(r'ane=([\d.]+) io=([\d.]+) cls=([\d.]+) elem=([\d.]+) rms=([\d.]+) cblas_wait=([\d.]+)')
RE_TIMING_DYN = re.compile(r'ane_fwd=([\d.]+) io_fwd=([\d.]+) rms=([\d.]+) ane_bwd=([\d.]+) io_bwd=([\d.]+) silu=([\d.]+) rms_bwd=([\d.]+) cls=([\d.]+) cblas_wait=([\d.]+) dw_copy=([\d.]+)')
RE_RESTART = re.compile(r'\[exec\(\) restart step (\d+)')
RE_RESUME = re.compile(r'\[RESUMED step (\d+), loss=([\d.]+)\]')
RE_FLOPS = re.compile(r'FLOPs/step: fwd=([\d.]+)M bwd_dx=([\d.]+)M bwd_dW=([\d.]+)M sdpa_bwd=([\d.]+)M total=([\d.]+)M')
RE_ANE_FLOPS = re.compile(r'ANE FLOPs/step: ([\d.]+)M')
RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Compile|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
RE_COMPILED = re.compile(r'Compiled (\d+) kernels in (\d+)ms')
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
def parse_line(line):
S.logs.append(line)
# Parse JSON lines from static pipeline ({"type":"step",...} or {"type":"batch",...})
stripped = line.strip()
if stripped.startswith('{'):
try:
j = json.loads(stripped)
jt = j.get('type')
if jt == 'step':
S.step, S.loss = j['step'], j['loss']
S.loss_history.append((S.step, S.loss))
S.best_loss = min(S.best_loss, S.loss)
S.compiles = j.get('compiles', S.compiles)
now = time.monotonic()
if S.train_start is None:
S.train_start = now
S.step_timestamps.append((S.step, now))
if len(S.step_timestamps) >= 2:
dt = S.step_timestamps[-1][1] - S.step_timestamps[-2][1]
if dt > 0:
S.ms_per_step = dt * 1000
# Extract component timing from JSON
ct = {}
for k in ('t_ane', 't_io', 't_cls', 't_elem', 't_rms', 't_cblas_wait'):
if k in j:
ct[k[2:]] = j[k] # strip 't_' prefix
if ct:
S.component_timing = ct
return
elif jt == 'batch':
S.batch_num = j.get('batch', S.batch_num)
compile_ms = j.get('compile_ms', 0)
train_ms = j.get('train_ms', 0)
S.ms_per_step = j.get('ms_per_step', S.ms_per_step)
S.compile_ms += compile_ms
S.compile_pct = 100 * S.compile_ms / (S.compile_ms + train_ms) if S.compile_ms + train_ms > 0 else 0
return
elif jt == 'perf':
if 'ane_tflops' in j:
S.flops['ane_tflops'] = j['ane_tflops']
if 'ane_util_pct' in j:
S.flops['ane_util'] = j['ane_util_pct']
return
except (json.JSONDecodeError, KeyError):
pass
m = RE_CONFIG.search(line)
if m:
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
@ -303,7 +354,7 @@ def parse_line(line):
if m:
S.params = {'total': float(m[1]), 'transformer': float(m[2]), 'embed': float(m[3])}
return
m = RE_KERNELS.search(line)
m = RE_KERNELS_DYN.search(line) or RE_KERNELS.search(line)
if m:
S.kernels = {'total': int(m[1]), 'weight_bearing': int(m[2])}
return
@ -323,6 +374,18 @@ def parse_line(line):
m = RE_STEP.search(line)
if m:
S.step, S.loss = int(m[1]), float(m[2])
if m[3]:
S.training['lr'] = m[3]
if m[4]:
S.ms_per_step = float(m[4])
now = time.monotonic()
if S.train_start is None:
S.train_start = now
S.step_timestamps.append((S.step, now))
if not m[4] and len(S.step_timestamps) >= 2:
dt = S.step_timestamps[-1][1] - S.step_timestamps[-2][1]
if dt > 0:
S.ms_per_step = dt * 1000
S.loss_history.append((S.step, S.loss))
S.best_loss = min(S.best_loss, S.loss)
return
@ -334,6 +397,16 @@ def parse_line(line):
S.compiles = int(m[5])
S.compile_pct = 100 * compile_ms / (compile_ms + train_ms) if compile_ms + train_ms > 0 else 0
return
m = RE_TIMING_DYN.search(line)
if m:
vals = list(map(float, m.groups()))
S.component_timing = {
'ane_fwd': vals[0], 'io_fwd': vals[1], 'rms': vals[2],
'ane_bwd': vals[3], 'io_bwd': vals[4], 'silu': vals[5],
'rms_bwd': vals[6], 'cls': vals[7], 'cblas_wait': vals[8], 'dw_copy': vals[9],
'_dynamic': True
}
return
m = RE_TIMING.search(line)
if m:
S.component_timing = dict(zip(['ane', 'io', 'cls', 'elem', 'rms', 'cblas_wait'], map(float, m.groups())))
@ -346,6 +419,11 @@ def parse_line(line):
if m:
S.flops['ane_util'] = float(m[1])
return
m = RE_COMPILED.search(line)
if m:
S.compiles = int(m[1])
S.compile_ms += float(m[2])
return
m = RE_EFFICIENCY.search(line)
if m:
S.efficiency[m[1].strip()] = m[2].strip()
@ -514,23 +592,49 @@ def draw(term):
# Training stats (right panel)
sr = row
step_str = f'{S.step}' + (f'/{S.total_steps}' if S.total_steps and S.total_steps < 999999 else '')
put(sr, mid_x + 1, f' Step: {step_str} Loss: {S.loss:.4f}' if S.loss else ' Step: --', term.yellow)
# Elapsed time
elapsed = 0.0
if S.train_start:
elapsed = time.monotonic() - S.train_start
elapsed_str = f'{elapsed:.1f}s' if elapsed < 60 else f'{elapsed/60:.1f}m'
put(sr, mid_x + 1, f' Step: {step_str} Loss: {S.loss:.4f} [{elapsed_str}]' if S.loss else ' Step: --', term.yellow)
sr += 1
put(sr, mid_x + 1, f' Best: {S.best_loss:.4f} ms/step: {S.ms_per_step:.1f}' if S.best_loss < float('inf') else ' Best: --')
# ms/step + steps/sec
sps = 1000.0 / S.ms_per_step if S.ms_per_step > 0 else 0
put(sr, mid_x + 1, f' Best: {S.best_loss:.4f} {S.ms_per_step:.1f}ms/step ({sps:.1f} steps/s)' if S.best_loss < float('inf') else ' Best: --')
sr += 1
# TFLOPS
ane_tflops = S.flops.get('ane_tflops', 0)
ane_util = S.flops.get('ane_util', 0)
total_tflops = 0
if S.ms_per_step > 0 and S.flops.get('ane', 0) > 0:
if not ane_tflops:
ane_tflops = (S.flops['ane'] * 1e6) / (S.ms_per_step * 1e-3) / 1e12
total_tflops = (S.flops.get('total', 0) * 1e6) / (S.ms_per_step * 1e-3) / 1e12
if not ane_util and ane_tflops:
ane_util = 100.0 * ane_tflops / 15.8
compile_str = f' Compile: {S.compile_ms/1000:.1f}s' if S.compile_ms > 0 else ''
if ane_tflops:
put(sr, mid_x + 1, f' ANE: {ane_tflops:.2f}T Compile: {S.compile_pct:.0f}% Util: {ane_util:.1f}%')
else:
put(sr, mid_x + 1, f' Compile: {S.compile_pct:.0f}%')
tflops_str = f' ANE: {ane_tflops:.2f}T'
if total_tflops:
tflops_str += f' Total: {total_tflops:.2f}T'
tflops_str += f' Util: {ane_util:.1f}%{compile_str}'
put(sr, mid_x + 1, tflops_str)
elif compile_str:
put(sr, mid_x + 1, f'{compile_str}')
sr += 1
ct = S.component_timing
if ct:
put(sr, mid_x + 1, f' ane={ct.get("ane", 0):.1f} io={ct.get("io", 0):.1f} cls={ct.get("cls", 0):.1f} elem={ct.get("elem", 0):.1f}')
sr += 1
put(sr, mid_x + 1, f' rms={ct.get("rms", 0):.1f} cblas_wait={ct.get("cblas_wait", 0):.1f} ms/step')
sr += 1
if ct.get('_dynamic'):
put(sr, mid_x + 1, f' fwd={ct.get("ane_fwd",0):.1f} bwd={ct.get("ane_bwd",0):.1f} io={ct.get("io_fwd",0)+ct.get("io_bwd",0):.1f} silu={ct.get("silu",0):.1f}')
sr += 1
put(sr, mid_x + 1, f' cls={ct.get("cls",0):.1f} rms={ct.get("rms",0)+ct.get("rms_bwd",0):.1f} dw={ct.get("dw_copy",0):.1f} ms/step')
sr += 1
else:
put(sr, mid_x + 1, f' ane={ct.get("ane", 0):.1f} io={ct.get("io", 0):.1f} cls={ct.get("cls", 0):.1f} elem={ct.get("elem", 0):.1f}')
sr += 1
put(sr, mid_x + 1, f' rms={ct.get("rms", 0):.1f} cblas_wait={ct.get("cblas_wait", 0):.1f} ms/step')
sr += 1
pw = S.power
if any(pw.values()):
put(sr, mid_x + 1, '\u2500 Power ' + '\u2500' * max(0, right_w - 9), term.cyan)
@ -659,10 +763,24 @@ def set_nonblock(fd):
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
def spawn_training(resume=False, steps=10000):
cmd = 'make train_large 2>&1 && ./train_large'
def spawn_training(resume=False, steps=10000, dynamic=False, ane=False, scratch=False,
lr=None, accum=None, no_ane_extras=False):
if dynamic:
cmd = 'cd training_dynamic && make 2>&1 && ./train'
elif ane:
cmd = 'make train_large_ane 2>&1 && ./train_large_ane'
else:
cmd = 'make train_large 2>&1 && ./train_large'
if resume:
cmd += ' --resume'
if scratch and dynamic:
cmd += ' --scratch'
if lr is not None:
cmd += f' --lr {lr}'
if accum is not None and dynamic:
cmd += f' --accum {accum}'
if no_ane_extras and ane:
cmd += ' --no-ane-extras'
cmd += f' --steps {steps}'
proc = subprocess.Popen(
['bash', '-c', cmd],
@ -686,6 +804,12 @@ def spawn_powermetrics():
def main():
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
parser.add_argument('--dynamic', action='store_true', help='Dynamic weight pipeline (training_dynamic/)')
parser.add_argument('--ane', action='store_true', help='PR#19: ANE-offloaded classifier/softmax/rmsnorm_bwd')
parser.add_argument('--no-ane-extras', action='store_true', help='Disable ANE extras (use with --ane)')
parser.add_argument('--scratch', action='store_true', help='Train from scratch (random init)')
parser.add_argument('--lr', type=float, default=None, help='Learning rate')
parser.add_argument('--accum', type=int, default=None, help='Gradient accumulation steps')
parser.add_argument('--infinite', action='store_true', help='Train indefinitely')
parser.add_argument('--no-powermetrics', action='store_true')
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
@ -696,10 +820,15 @@ def main():
args.steps = 999999999
S.total_steps = args.steps
global CKPT_PATH
CKPT_PATH = CKPT_PATH_DYNAMIC if args.dynamic else CKPT_PATH_STATIC
term = Terminal()
procs = []
train_proc = spawn_training(resume=args.resume, steps=args.steps)
train_proc = spawn_training(resume=args.resume, steps=args.steps, dynamic=args.dynamic,
scratch=args.scratch, lr=args.lr, accum=args.accum,
ane=args.ane, no_ane_extras=args.no_ane_extras)
S.train_pid = train_proc.pid
procs.append(train_proc)
@ -839,7 +968,9 @@ def main():
if train_proc:
train_proc.terminate()
train_proc.wait()
train_proc = spawn_training(resume=True, steps=args.steps)
train_proc = spawn_training(resume=True, steps=args.steps, dynamic=args.dynamic,
lr=args.lr, accum=args.accum,
ane=args.ane, no_ane_extras=args.no_ane_extras)
S.train_pid = train_proc.pid
procs = [p for p in procs if p.poll() is None]
procs.append(train_proc)

View File

@ -0,0 +1,333 @@
// test_dynamic_matmul.m Benchmark dynamic matmul on ANE (no recompile)
// Layout: input [1, D, 1, S+D] activations in sp[0:S], weight rows in sp[S:S+D]
// MIL: slice reshape matmul reshape output
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach_time.h>
#include <arm_neon.h>
#include <Accelerate/Accelerate.h>
#include "stories_io.h"
// Generate MIL for y = x @ W where both come from input IOSurface
// Input: [1, IC, 1, SEQ+OC] fp32
// sp[0:SEQ] = activations x[IC, SEQ]
// sp[SEQ:SEQ+OC] = weight W[IC, OC] (each channel d holds W[d, :])
// Output: [1, OC, 1, SEQ] fp32
static NSString *gen_dynamic_matmul_mil(int ic, int oc, int seq) {
NSMutableString *m = [NSMutableString string];
[m appendString:@"program(1.3)\n"
"[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
"{\"coremltools-version\", \"9.0\"}})]\n{\n"];
int sp_total = seq + oc;
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, sp_total];
// Cast to fp16
[m appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", ic, sp_total];
// Slice activations [1, IC, 1, SEQ]
[m appendString:@" tensor<int32, [4]> ba = const()[name = string(\"ba\"), val = tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sa = const()[name = string(\"sa\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", ic, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", ic, seq];
// Slice weight [1, IC, 1, OC]
[m appendFormat:@" tensor<int32, [4]> bw = const()[name = string(\"bw\"), val = tensor<int32, [4]>([0,0,0,%d])];\n", seq];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name = string(\"sw\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", ic, oc];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> wt = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"wt\")];\n", ic, oc];
// Reshape act: [1,IC,1,SEQ] [1,1,IC,SEQ] transpose [1,1,SEQ,IC]
[m appendFormat:@" tensor<int32, [4]> ra = const()[name = string(\"ra\"), val = tensor<int32, [4]>([1,1,%d,%d])];\n", ic, seq];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n", ic, seq];
[m appendString:@" tensor<int32, [4]> pm = const()[name = string(\"pm\"), val = tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> a3 = transpose(perm=pm,x=a2)[name=string(\"a3\")];\n", seq, ic];
// Reshape weight: [1,IC,1,OC] [1,1,IC,OC]
[m appendFormat:@" tensor<int32, [4]> rw = const()[name = string(\"rw\"), val = tensor<int32, [4]>([1,1,%d,%d])];\n", ic, oc];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W = reshape(shape=rw,x=wt)[name=string(\"W\")];\n", ic, oc];
// matmul: [1,1,SEQ,IC] @ [1,1,IC,OC] [1,1,SEQ,OC]
[m appendString:@" bool bF = const()[name = string(\"bF\"), val = bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> yh = matmul(transpose_x=bF,transpose_y=bF,x=a3,y=W)[name=string(\"mm\")];\n", seq, oc];
// Reshape+transpose back: [1,1,SEQ,OC] transpose [1,1,OC,SEQ] reshape [1,OC,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=yh)[name=string(\"yt\")];\n", oc, seq];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name = string(\"ro\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", oc, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", oc, seq];
// Cast back to fp32
[m appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype = to32, x = yr)[name = string(\"cout\")];\n", oc, seq];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// Tiled version: splits OC into tiles, each tile is a separate kernel
// For W[IC, OC], tile along OC: each tile handles W[:, t*T:(t+1)*T]
// Input per tile: [1, IC, 1, SEQ+T]
// Output per tile: [1, T, 1, SEQ]
typedef struct {
Kern **tiles;
int n_tiles, tile_oc, ic, oc, seq;
} TiledMatmul;
static TiledMatmul *compile_tiled_matmul(int ic, int oc, int tile_oc, int seq) {
TiledMatmul *tm = (TiledMatmul*)calloc(1, sizeof(TiledMatmul));
tm->ic = ic; tm->oc = oc; tm->seq = seq; tm->tile_oc = tile_oc;
tm->n_tiles = (oc + tile_oc - 1) / tile_oc;
tm->tiles = (Kern**)calloc(tm->n_tiles, sizeof(Kern*));
for (int t = 0; t < tm->n_tiles; t++) {
int this_oc = (t == tm->n_tiles-1 && oc % tile_oc) ? (oc % tile_oc) : tile_oc;
NSString *mil = gen_dynamic_matmul_mil(ic, this_oc, seq);
int in_bytes = ic * (seq + this_oc) * 4;
int out_bytes = this_oc * seq * 4;
tm->tiles[t] = compile_kern_mil_w(mil, @{}, in_bytes, out_bytes);
if (!tm->tiles[t]) { printf("Tile %d compile FAIL\n", t); return NULL; }
}
return tm;
}
// Write activations + weight tile into IOSurface
// act: [IC, SEQ] column-major (channel-first)
// W: [IC, OC] full weight matrix, we extract the tile
static void write_tile_input(TiledMatmul *tm, int tile_idx, const float *act, const float *W) {
Kern *k = tm->tiles[tile_idx];
int ic = tm->ic, seq = tm->seq, toc = tm->tile_oc;
int oc_off = tile_idx * toc;
int this_oc = (tile_idx == tm->n_tiles-1 && tm->oc % toc) ? (tm->oc % toc) : toc;
IOSurfaceLock(k->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(k->ioIn);
// Activations: buf[d * (seq+this_oc) + t] = act[d * seq + t]
for (int d = 0; d < ic; d++) {
memcpy(buf + d*(seq+this_oc), act + d*seq, seq*sizeof(float));
// Weight: buf[d * (seq+this_oc) + seq + c] = W[d * oc + oc_off + c]
for (int c = 0; c < this_oc; c++)
buf[d*(seq+this_oc) + seq + c] = W[d*tm->oc + oc_off + c];
}
IOSurfaceUnlock(k->ioIn, 0, NULL);
}
// Read tile output into full output buffer
static void read_tile_output(TiledMatmul *tm, int tile_idx, float *out) {
Kern *k = tm->tiles[tile_idx];
int seq = tm->seq, toc = tm->tile_oc;
int oc_off = tile_idx * toc;
int this_oc = (tile_idx == tm->n_tiles-1 && tm->oc % toc) ? (tm->oc % toc) : toc;
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
float *obuf = (float*)IOSurfaceGetBaseAddress(k->ioOut);
for (int c = 0; c < this_oc; c++)
memcpy(out + (oc_off+c)*seq, obuf + c*seq, seq*sizeof(float));
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
}
int main(int argc, char **argv) {
@autoreleasepool {
mach_timebase_info(&g_tb);
ane_init();
// === Test 1: Single 64×64 dynamic matmul (correctness) ===
printf("=== Test 1: 64×64 dynamic matmul correctness ===\n");
{
int D = 64, S = 64;
NSString *mil = gen_dynamic_matmul_mil(D, D, S);
int in_b = D * (S+D) * 4, out_b = D * S * 4;
Kern *k = compile_kern_mil_w(mil, @{}, in_b, out_b);
if (!k) { printf("FAIL\n"); return 1; }
// Identity test
IOSurfaceLock(k->ioIn, 0, NULL);
float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn);
memset(inp, 0, in_b);
for (int d = 0; d < D; d++)
for (int s = 0; s < S; s++)
inp[d*(S+D) + s] = (float)(d*S + s) * 0.001f;
for (int d = 0; d < D; d++)
for (int c = 0; c < D; c++)
inp[d*(S+D) + S + c] = (d == c) ? 1.0f : 0.0f;
IOSurfaceUnlock(k->ioIn, 0, NULL);
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
float *out = (float*)IOSurfaceGetBaseAddress(k->ioOut);
float me = 0;
for (int d = 0; d < D; d++)
for (int s = 0; s < S; s++) {
float e = fabsf(out[d*S+s] - inp[d*(S+D)+s]);
if (e > me) me = e;
}
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Identity: max_err=%.6f %s\n", me, me < 0.01 ? "PASS" : "FAIL");
// 2× test
IOSurfaceLock(k->ioIn, 0, NULL);
for (int d = 0; d < D; d++)
for (int c = 0; c < D; c++)
inp[d*(S+D) + S + c] = (d == c) ? 2.0f : 0.0f;
IOSurfaceUnlock(k->ioIn, 0, NULL);
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
float sr = 0; int cnt = 0;
for (int i = 0; i < D*S; i++)
if (fabsf(inp[i/(S)*((S)+D) + i%S]) > 0.001f) { sr += out[i]/inp[i/S*(S+D)+i%S]; cnt++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("2× W: ratio=%.3f %s\n\n", cnt?sr/cnt:0, fabsf(sr/cnt-2.0f)<0.1?"PASS":"FAIL");
free_kern(k);
}
// === Test 2: 768×768 single kernel (if it compiles) ===
printf("=== Test 2: 768×768 single dynamic matmul ===\n");
{
int D = 768, S = 256;
int sp_total = S + D; // 256 + 768 = 1024
int in_b = D * sp_total * 4; // 768 * 1024 * 4 = 3.1MB
int out_b = D * S * 4; // 768 * 256 * 4 = 786KB
printf("IOSurface: in=%.1fMB out=%.1fKB\n", in_b/1e6, out_b/1e3);
NSString *mil = gen_dynamic_matmul_mil(D, D, S);
uint64_t t0 = mach_absolute_time();
Kern *k = compile_kern_mil_w(mil, @{}, in_b, out_b);
double compile_ms = tb_ms(mach_absolute_time() - t0);
if (!k) { printf("768×768 compile FAIL\n"); }
else {
printf("Compile: %.1fms\n", compile_ms);
// Random weights
float *act = (float*)calloc(D*S, sizeof(float));
float *W = (float*)calloc(D*D, sizeof(float));
for (int i = 0; i < D*S; i++) act[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.1f;
for (int i = 0; i < D*D; i++) W[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.01f;
// Write to IOSurface
IOSurfaceLock(k->ioIn, 0, NULL);
float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn);
for (int d = 0; d < D; d++) {
memcpy(inp + d*(S+D), act + d*S, S*4);
memcpy(inp + d*(S+D) + S, W + d*D, D*4);
}
IOSurfaceUnlock(k->ioIn, 0, NULL);
// Warmup
for (int i = 0; i < 3; i++) ane_eval(k);
// Benchmark
int iters = 50;
t0 = mach_absolute_time();
for (int i = 0; i < iters; i++) ane_eval(k);
double total_ms = tb_ms(mach_absolute_time() - t0);
double per_eval = total_ms / iters;
double flops = 2.0 * D * D * S; // matmul FLOPs
double gflops = flops / (per_eval * 1e6);
printf("768×768×256 matmul: %.3fms/eval %.1f GFLOP/s\n", per_eval, gflops);
// Benchmark with IO write (simulating weight update)
t0 = mach_absolute_time();
for (int i = 0; i < iters; i++) {
IOSurfaceLock(k->ioIn, 0, NULL);
float *p = (float*)IOSurfaceGetBaseAddress(k->ioIn);
for (int d = 0; d < D; d++)
memcpy(p + d*(S+D) + S, W + d*D, D*4);
IOSurfaceUnlock(k->ioIn, 0, NULL);
ane_eval(k);
}
total_ms = tb_ms(mach_absolute_time() - t0);
per_eval = total_ms / iters;
gflops = flops / (per_eval * 1e6);
printf("With weight IO: %.3fms/eval %.1f GFLOP/s\n", per_eval, gflops);
free(act); free(W); free_kern(k);
}
}
// === Test 3: Tiled matmul benchmark ===
int tile_sizes[] = {64, 128, 256, 384, 768};
int n_tiles_test = sizeof(tile_sizes)/sizeof(tile_sizes[0]);
printf("\n=== Test 3: Tiled 768×768 matmul (varying tile_oc) ===\n");
printf("%-10s %-8s %-10s %-12s %-10s\n", "tile_oc", "tiles", "compile", "eval/ms", "GFLOP/s");
{
int D = 768, S = 256;
float *act = (float*)calloc(D*S, sizeof(float));
float *W = (float*)calloc(D*D, sizeof(float));
float *out_full = (float*)calloc(D*S, sizeof(float));
for (int i = 0; i < D*S; i++) act[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.1f;
for (int i = 0; i < D*D; i++) W[i] = ((float)arc4random() / UINT32_MAX - 0.5f) * 0.01f;
for (int ti = 0; ti < n_tiles_test; ti++) {
int T = tile_sizes[ti];
if (T > D) continue;
uint64_t t0 = mach_absolute_time();
TiledMatmul *tm = compile_tiled_matmul(D, D, T, S);
double compile_ms = tb_ms(mach_absolute_time() - t0);
if (!tm) { printf("%-10d FAIL\n", T); continue; }
// Warmup
for (int w = 0; w < 2; w++) {
for (int t = 0; t < tm->n_tiles; t++) {
write_tile_input(tm, t, act, W);
ane_eval(tm->tiles[t]);
}
}
// Benchmark (with IO)
int iters = 20;
t0 = mach_absolute_time();
for (int i = 0; i < iters; i++) {
for (int t = 0; t < tm->n_tiles; t++) {
write_tile_input(tm, t, act, W);
ane_eval(tm->tiles[t]);
read_tile_output(tm, t, out_full);
}
}
double total_ms = tb_ms(mach_absolute_time() - t0);
double per_matmul = total_ms / iters;
double flops = 2.0 * D * D * S;
double gflops = flops / (per_matmul * 1e6);
printf("%-10d %-8d %-10.0fms %-12.3fms %-10.1f\n",
T, tm->n_tiles, compile_ms, per_matmul, gflops);
for (int t = 0; t < tm->n_tiles; t++) free_kern(tm->tiles[t]);
free(tm->tiles); free(tm);
}
// === Correctness check: compare with cblas ===
printf("\n=== Correctness: dynamic matmul vs cblas_sgemm ===\n");
{
int T = 768; // full, no tiling
TiledMatmul *tm = compile_tiled_matmul(D, D, T, S);
if (tm) {
write_tile_input(tm, 0, act, W);
ane_eval(tm->tiles[0]);
read_tile_output(tm, 0, out_full);
// Reference: cblas y = act^T @ W y[s,oc] = sum_d act[d,s]*W[d,oc]
// act is [D,S] col-major, W is [D,D] row-major
// We want out[oc,s] = sum_d act[d,s] * W[d,oc]
// = W^T @ act where W^T is [D,D] and act is [D,S] out is [D,S]
float *ref = (float*)calloc(D*S, sizeof(float));
// out[oc*S+s] = sum_d W[d*D+oc] * act[d*S+s]
// This is: (W^T) @ act in column-major: M=D,N=S,K=D
// cblas: C = alpha*A*B + beta*C
// A=W^T [D×D], B=act [D×S], C=ref [D×S]
cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
D, S, D, 1.0f, W, D, act, D, 0.0f, ref, D);
float me = 0;
for (int i = 0; i < D*S; i++) {
float e = fabsf(out_full[i] - ref[i]);
if (e > me) me = e;
}
printf("vs cblas: max_err=%.6f %s\n", me, me < 1.0 ? "PASS" : "FAIL");
free(ref);
for (int t = 0; t < tm->n_tiles; t++) free_kern(tm->tiles[t]);
free(tm->tiles); free(tm);
}
}
free(act); free(W); free(out_full);
}
// === Summary for training ===
printf("\n=== Summary ===\n");
printf("Stories110M: 12 layers × 10 matmuls/layer = 120 matmuls/step\n");
printf("Sizes: Wq/Wk/Wv/Wo [768,768], W1/W3 [2048,768], W2 [768,2048]\n");
printf("With dynamic weights: compile once, update IOSurface every step\n");
printf("\nDone.\n");
}
return 0;
}

View File

@ -0,0 +1,450 @@
// test_weight_patch.m Test whether ANE weights can be patched after compile
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach.h>
#import <mach/mach_time.h>
#import <mach/vm_map.h>
#include <arm_neon.h>
#include <Accelerate/Accelerate.h>
#include "stories_io.h"
// MIL: fp32 in cast fp16 conv cast fp32 out (matches inmem_peak.m pattern)
static NSString *gen_conv_mil(int ic, int oc, int sp) {
NSMutableString *m = [NSMutableString string];
[m appendString:@"program(1.3)\n"
"[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
"{\"coremltools-version\", \"9.0\"}})]\n{\n"];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, sp];
[m appendString:
@" string pt = const()[name = string(\"pt\"), val = string(\"valid\")];\n"
" tensor<int32, [2]> st = const()[name = string(\"st\"), val = tensor<int32, [2]>([1, 1])];\n"
" tensor<int32, [4]> pd = const()[name = string(\"pd\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
" tensor<int32, [2]> dl = const()[name = string(\"dl\"), val = tensor<int32, [2]>([1, 1])];\n"
" int32 gr = const()[name = string(\"gr\"), val = int32(1)];\n"
" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cast_in\")];\n", ic, sp];
[m appendFormat:@" tensor<fp16, [%d, %d, 1, 1]> W = const()[name = string(\"W\"), "
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = string(\"@model_path/weights/w.bin\"), offset = uint64(64)))];\n",
oc, ic, oc, ic];
[m appendFormat:@" tensor<fp16, [1, %d, 1, %d]> yh = conv(dilations = dl, groups = gr, pad = pd, pad_type = pt, strides = st, weight = W, x = xh)"
"[name = string(\"conv\")];\n", oc, sp];
[m appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1, %d, 1, %d]> y = cast(dtype = to32, x = yh)[name = string(\"cast_out\")];\n", oc, sp];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
int main(int argc, char **argv) {
@autoreleasepool {
mach_timebase_info(&g_tb);
ane_init();
int IC = 256, OC = 256, SP = 64;
int io_bytes = IC * SP * 4; // fp32
// Identity weight
float *W_id = (float*)calloc(OC*IC, sizeof(float));
for (int i = 0; i < IC; i++) W_id[i*IC+i] = 1.0f;
NSString *mil = gen_conv_mil(IC, OC, SP);
NSDictionary *wd = @{@"@model_path/weights/w.bin": @{@"offset":@0, @"data":build_blob(W_id, OC, IC)}};
printf("=== Compiling conv %dx%d sp=%d ===\n", OC, IC, SP);
Kern *k = compile_kern_mil_w(mil, wd, io_bytes, io_bytes);
if (!k) { printf("COMPILE FAILED\n"); free(W_id); return 1; }
printf("Compile OK!\n");
// Write fp32 input
IOSurfaceLock(k->ioIn, 0, NULL);
float *inp = (float*)IOSurfaceGetBaseAddress(k->ioIn);
for (int i = 0; i < IC*SP; i++) inp[i] = (i % 100) * 0.01f;
IOSurfaceUnlock(k->ioIn, 0, NULL);
// Eval with identity
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
float *out = (float*)IOSurfaceGetBaseAddress(k->ioOut);
printf("In: [%.3f, %.3f, %.3f, %.3f]\n", inp[0], inp[1], inp[2], inp[3]);
printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float max_err = 0;
for (int i = 0; i < OC*SP; i++) {
float err = fabsf(out[i] - inp[i]);
if (err > max_err) max_err = err;
}
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Identity max_err=%.6f %s\n\n", max_err, max_err < 0.1 ? "PASS" : "FAIL");
// === Approach 1: Patch weight on disk, unload+reload ===
printf("=== Approach 1: Disk patch + unload/reload ===\n");
float *W_2x = (float*)calloc(OC*IC, sizeof(float));
for (int i = 0; i < IC; i++) W_2x[i*IC+i] = 2.0f;
[build_blob(W_2x, OC, IC) writeToFile:
[(__bridge NSString*)k->tmpDir stringByAppendingPathComponent:@"weights/w.bin"] atomically:YES];
id mdl = (__bridge id)k->model;
NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
e = nil;
BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e);
printf("Reload: %s\n", ok?"OK":"FAIL");
if (ok) {
// Re-create request after reload
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
CFRelease(k->request);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI], @[@0], @[wO], @[@0], nil, nil, @0));
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr = 0; int cnt = 0;
for (int i = 0; i < OC*SP; i++)
if (fabsf(inp[i]) > 0.01f) { sr += out[i]/inp[i]; cnt++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Ratio: %.3f (2.0=patched, 1.0=cached)\n\n", cnt>0?sr/cnt:0);
}
// === Approach 2: Memory scan ===
printf("=== Approach 2: Memory scan ===\n");
uint16_t pat1[8] = {0x3C00, 0, 0, 0, 0, 0, 0, 0};
uint16_t pat2[8] = {0x4000, 0, 0, 0, 0, 0, 0, 0};
mach_port_t task = mach_task_self();
vm_address_t addr = 0; vm_size_t sz; natural_t depth = 1;
int f1 = 0, f2 = 0;
while (1) {
struct vm_region_submap_info_64 info;
mach_msg_type_number_t count = VM_REGION_SUBMAP_INFO_COUNT_64;
if (vm_region_recurse_64(task, &addr, &sz, &depth, (vm_region_recurse_info_t)&info, &count) != KERN_SUCCESS) break;
if (info.is_submap) { depth++; continue; }
if (!(info.protection & VM_PROT_READ) || sz < (size_t)(OC*IC*2)) { addr += sz; continue; }
uint8_t *base = (uint8_t*)addr;
for (size_t off = 0; off + OC*IC*2 <= sz; off += 2) {
int w = 0;
if (memcmp(base+off, pat1, 16) == 0) w = 1;
else if (memcmp(base+off, pat2, 16) == 0) w = 2;
if (!w) continue;
uint16_t *p = (uint16_t*)(base+off), diag = (w==1)?0x3C00:0x4000;
int ok2 = 1;
for (int r = 0; r < OC && ok2; r++)
for (int c = 0; c < IC && ok2; c++)
if (p[r*IC+c] != ((r==c)?diag:0)) ok2 = 0;
if (!ok2) continue;
if (w==1) f1++; else f2++;
printf(" FOUND %dx @%p prot=%d/%d %s\n", w, (void*)(addr+off),
info.protection, info.max_protection, (info.protection&VM_PROT_WRITE)?"WR":"RO");
}
addr += sz;
}
printf("Found: 1x=%d 2x=%d\n", f1, f2);
// Now patch ALL found weight patterns to 3× and re-eval
if (f1 > 0 || f2 > 0) {
printf("Patching all found patterns to 3x identity...\n");
addr = 0; depth = 1;
while (1) {
struct vm_region_submap_info_64 info2;
mach_msg_type_number_t count2 = VM_REGION_SUBMAP_INFO_COUNT_64;
if (vm_region_recurse_64(task, &addr, &sz, &depth, (vm_region_recurse_info_t)&info2, &count2) != KERN_SUCCESS) break;
if (info2.is_submap) { depth++; continue; }
if (!(info2.protection & VM_PROT_READ) || sz < (size_t)(OC*IC*2)) { addr += sz; continue; }
uint8_t *base2 = (uint8_t*)addr;
for (size_t off = 0; off + OC*IC*2 <= sz; off += 2) {
int w2 = 0;
if (memcmp(base2+off, pat1, 16) == 0) w2 = 1;
else if (memcmp(base2+off, pat2, 16) == 0) w2 = 2;
if (!w2) continue;
uint16_t *p2 = (uint16_t*)(base2+off), diag2 = (w2==1)?0x3C00:0x4000;
int ok3 = 1;
for (int r = 0; r < OC && ok3; r++)
for (int c = 0; c < IC && ok3; c++)
if (p2[r*IC+c] != ((r==c)?diag2:0)) ok3 = 0;
if (!ok3) continue;
if (info2.protection & VM_PROT_WRITE) {
printf(" Patching %dx @%p to 3x\n", w2, (void*)(addr+off));
for (int r = 0; r < OC; r++)
for (int c = 0; c < IC; c++)
p2[r*IC+c] = (r==c) ? 0x4200 : 0; // fp16(3.0)
}
}
addr += sz;
}
printf("\n=== Eval after memory patch (expect 3x) ===\n");
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr2 = 0; int cnt2 = 0;
for (int i = 0; i < OC*SP; i++)
if (fabsf(inp[i]) > 0.01f) { sr2 += out[i]/inp[i]; cnt2++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("Ratio: %.3f (3.0=mem patch works!, 1.0=ANE uses SRAM copy)\n", cnt2>0?sr2/cnt2:0);
}
printf("\n");
// === Approach 3: Explore classes ===
printf("=== ANE classes ===\n");
const char *cn[] = {"_ANEWeight", "_ANEProgramForEvaluation", "_ANEChainingRequest", NULL};
for (int i = 0; cn[i]; i++) {
Class cls = NSClassFromString([NSString stringWithUTF8String:cn[i]]);
if (!cls) { printf("%s: NOT FOUND\n", cn[i]); continue; }
printf("%s:\n", cn[i]);
unsigned int mc = 0; Method *ms = class_copyMethodList(cls, &mc);
for (unsigned j = 0; j < mc; j++) printf(" - %s\n", sel_getName(method_getName(ms[j])));
free(ms);
mc = 0; ms = class_copyMethodList(object_getClass(cls), &mc);
for (unsigned j = 0; j < mc; j++) printf(" + %s\n", sel_getName(method_getName(ms[j])));
free(ms); printf("\n");
}
@try { printf("programHandle: %s\n", [[[mdl valueForKey:@"programHandle"] description] UTF8String]); } @catch(id x) {}
@try { printf("intermediateBufferHandle: %s\n", [[[mdl valueForKey:@"intermediateBufferHandle"] description] UTF8String]); } @catch(id x) {}
// === Approach 4: _ANEWeight + updateWeightURL ===
printf("\n=== Approach 4: _ANEWeight API ===\n");
Class AW = NSClassFromString(@"_ANEWeight");
if (AW) {
// Write 5× identity weights to a new file
float *W_5x = (float*)calloc(OC*IC, sizeof(float));
for (int i = 0; i < IC; i++) W_5x[i*IC+i] = 5.0f;
NSString *wpath = [NSTemporaryDirectory() stringByAppendingPathComponent:@"patched_w.bin"];
[build_blob(W_5x, OC, IC) writeToFile:wpath atomically:YES];
free(W_5x);
NSURL *wurl = [NSURL fileURLWithPath:wpath];
id wobj = ((id(*)(Class,SEL,id,id))objc_msgSend)(AW,
@selector(weightWithSymbolAndURL:weightURL:), @"W", wurl);
printf(" _ANEWeight: %s\n", wobj ? [[wobj description] UTF8String] : "nil");
if (wobj) {
printf(" weightSymbol: %s\n", [((id(*)(id,SEL))objc_msgSend)(wobj, @selector(weightSymbol)) UTF8String]);
printf(" weightURL: %s\n", [[((id(*)(id,SEL))objc_msgSend)(wobj, @selector(weightURL)) description] UTF8String]);
}
// Try to pass as weightsBuffer in request
printf("\n Trying weightsBuffer in request...\n");
id wI2 = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO2 = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
// Try passing weight array as weightsBuffer
if (wobj) {
CFRelease(k->request);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI2], @[@0], @[wO2], @[@0], @[wobj], nil, @0));
printf(" Request with weightsBuffer created\n");
@try {
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr3 = 0; int cnt3 = 0;
for (int i2 = 0; i2 < OC*SP; i2++)
if (fabsf(inp[i2]) > 0.01f) { sr3 += out[i2]/inp[i2]; cnt3++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Ratio: %.3f (5.0=weightsBuffer works!)\n", cnt3>0?sr3/cnt3:0);
} @catch(NSException *ex) {
printf(" Eval exception: %s\n", [[ex description] UTF8String]);
}
}
// Also try IOSurface as weightsBuffer
printf("\n Trying IOSurface as weightsBuffer...\n");
IOSurfaceRef wSurf = make_surface(OC*IC*2); // fp16 weights
IOSurfaceLock(wSurf, 0, NULL);
_Float16 *wfp16 = (_Float16*)IOSurfaceGetBaseAddress(wSurf);
for (int r = 0; r < OC; r++)
for (int c2 = 0; c2 < IC; c2++)
wfp16[r*IC+c2] = (r==c2) ? (_Float16)7.0f : (_Float16)0.0f; // 7× identity
IOSurfaceUnlock(wSurf, 0, NULL);
id wSurfObj = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), wSurf);
CFRelease(k->request);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI2], @[@0], @[wO2], @[@0], wSurfObj, nil, @0));
@try {
ane_eval(k);
IOSurfaceLock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Out: [%.3f, %.3f, %.3f, %.3f]\n", out[0], out[1], out[2], out[3]);
float sr4 = 0; int cnt4 = 0;
for (int i3 = 0; i3 < OC*SP; i3++)
if (fabsf(inp[i3]) > 0.01f) { sr4 += out[i3]/inp[i3]; cnt4++; }
IOSurfaceUnlock(k->ioOut, kIOSurfaceLockReadOnly, NULL);
printf(" Ratio: %.3f (7.0=IOSurface weights work!)\n", cnt4>0?sr4/cnt4:0);
} @catch(NSException *ex) {
printf(" Eval exception: %s\n", [[ex description] UTF8String]);
}
CFRelease(wSurf);
}
// === Approach 5: Weights packed into input IOSurface (fp16 with cast) ===
printf("\n=== Approach 5: Dynamic weights via input IOSurface ===\n");
// Element-wise mul: x * w where both come from input
// Input [1, IC*2, 1, SP] fp32 cast fp16 slice mul cast fp32
{
int C5 = IC;
NSMutableString *m5 = [NSMutableString string];
[m5 appendString:@"program(1.3)\n"
"[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
"{\"coremltools-version\", \"9.0\"}})]\n{\n"];
[m5 appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", C5*2, SP];
[m5 appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];
[m5 appendFormat:@" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", C5*2, SP];
[m5 appendFormat:@" tensor<int32, [4]> b0 = const()[name = string(\"b0\"), val = tensor<int32, [4]>([0,0,0,0])];\n"];
[m5 appendFormat:@" tensor<int32, [4]> s0 = const()[name = string(\"s0\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", C5, SP];
[m5 appendFormat:@" tensor<fp16, [1,%d,1,%d]> data = slice_by_size(x=xh,begin=b0,size=s0)[name=string(\"data\")];\n", C5, SP];
[m5 appendFormat:@" tensor<int32, [4]> b1 = const()[name = string(\"b1\"), val = tensor<int32, [4]>([0,%d,0,0])];\n", C5];
[m5 appendFormat:@" tensor<fp16, [1,%d,1,%d]> wt = slice_by_size(x=xh,begin=b1,size=s0)[name=string(\"wt\")];\n", C5, SP];
[m5 appendFormat:@" tensor<fp16, [1,%d,1,%d]> yh = mul(x=data,y=wt)[name=string(\"mul\")];\n", C5, SP];
[m5 appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"];
[m5 appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype = to32, x = yh)[name = string(\"cout\")];\n", C5, SP];
[m5 appendString:@" } -> (y);\n}\n"];
int io5_in = C5*2*SP*4;
int io5_out = C5*SP*4;
Kern *k5 = compile_kern_mil_w(m5, @{}, io5_in, io5_out);
if (k5) {
printf("Compile OK!\n");
IOSurfaceLock(k5->ioIn, 0, NULL);
float *in5 = (float*)IOSurfaceGetBaseAddress(k5->ioIn);
for (int i = 0; i < C5*SP; i++) in5[i] = (i%100)*0.01f;
for (int i = 0; i < C5*SP; i++) in5[C5*SP+i] = 2.0f;
IOSurfaceUnlock(k5->ioIn, 0, NULL);
ane_eval(k5);
IOSurfaceLock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
float *out5 = (float*)IOSurfaceGetBaseAddress(k5->ioOut);
printf("data=[%.3f,%.3f,%.3f], w=2.0 → out=[%.3f,%.3f,%.3f]\n",
in5[0],in5[1],in5[2], out5[0],out5[1],out5[2]);
IOSurfaceUnlock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
// Change weight dynamically NO recompile!
IOSurfaceLock(k5->ioIn, 0, NULL);
for (int i = 0; i < C5*SP; i++) in5[C5*SP+i] = 5.0f;
IOSurfaceUnlock(k5->ioIn, 0, NULL);
ane_eval(k5);
IOSurfaceLock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("w=5.0 → out=[%.3f,%.3f,%.3f] (expect 5×)\n", out5[0],out5[1],out5[2]);
IOSurfaceUnlock(k5->ioOut, kIOSurfaceLockReadOnly, NULL);
free_kern(k5);
} else printf("Compile FAILED\n");
}
// === Approach 6: matmul with dynamic weights from input ===
printf("\n=== Approach 6: matmul with dynamic W from input ===\n");
// Pack x[1,D,S,1] and W[1,D,1,D] into input, then reshape+matmul
// Input shape: [1, D+D*D, 1, S] first D channels=activations, rest=weight matrix flattened
// Actually, matmul needs [1,H,S,D] shapes. Let's try:
// Input: [1, D*(S+D), 1, 1] reshaped as needed
// Simpler: just test matmul with two sliced inputs
{
int D6 = 64, S6 = 64; // small for test
// Input: [1, D6+D6, S6, D6] but that's 4D...
// Actually ANE matmul works on [1,H,M,K] @ [1,H,K,N] [1,H,M,N]
// Let's pack x[1,1,S6,D6] and W[1,1,D6,D6] into [1,2,S6,D6]
// Then slice matmul
NSMutableString *m6 = [NSMutableString string];
[m6 appendString:@"program(1.3)\n"
"[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, "
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, "
"{\"coremltools-version\", \"9.0\"}})]\n{\n"];
// Input: [1, D6+D6, 1, S6*D6] flatten everything, then reshape
// Actually simplest: two separate regions in channel dim
// x_data: [1, D6, 1, S6] and W: [1, D6*D6, 1, 1]
// Total input channels: D6 + D6*D6
int total_ch = D6 + D6*D6;
[m6 appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", total_ch, S6];
[m6 appendString:@" string to16 = const()[name = string(\"to16\"), val = string(\"fp16\")];\n"];
[m6 appendFormat:@" tensor<fp16, [1, %d, 1, %d]> xh = cast(dtype = to16, x = x)[name = string(\"cin\")];\n", total_ch, S6];
// Slice activations: [1, D6, 1, S6]
[m6 appendFormat:@" tensor<int32, [4]> b0 = const()[name = string(\"b0\"), val = tensor<int32, [4]>([0,0,0,0])];\n"];
[m6 appendFormat:@" tensor<int32, [4]> sa = const()[name = string(\"sa\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=b0,size=sa)[name=string(\"act\")];\n", D6, S6];
// Slice weight: [1, D6*D6, 1, S6] but we only need [D6, D6] reshape
[m6 appendFormat:@" tensor<int32, [4]> bw = const()[name = string(\"bw\"), val = tensor<int32, [4]>([0,%d,0,0])];\n", D6];
[m6 appendFormat:@" tensor<int32, [4]> sw = const()[name = string(\"sw\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", D6*D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,%d]> wf = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"wf\")];\n", D6*D6, S6];
// Reshape weight to [1, D6, D6, S6] for matmul-like operation
// Actually for conv: weight needs to be [OC, IC, 1, 1] const. Can't use dynamic weight with conv.
// For matmul: need [1, 1, D6, D6] or similar
// Let's try: reshape wf to [1, D6, D6, S6], take first slice [:,:,:,0] no, that's hard
// Simpler: reshape to [D6, D6] and use matmul
// But matmul expects specific ranks... let me try:
[m6 appendFormat:@" tensor<int32, [4]> ws = const()[name = string(\"ws\"), val = tensor<int32, [4]>([1, 1, %d, %d])];\n", D6, D6];
// Only take first column of wf to get [1, D6*D6, 1, 1]
[m6 appendFormat:@" tensor<int32, [4]> sw1 = const()[name = string(\"sw1\"), val = tensor<int32, [4]>([1,%d,1,1])];\n", D6*D6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,1]> wf1 = slice_by_size(x=wf,begin=b0,size=sw1)[name=string(\"wf1\")];\n", D6*D6];
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> W = reshape(shape=ws,x=wf1)[name=string(\"W\")];\n", D6, D6];
// Reshape act to [1, 1, S6, D6] for matmul
[m6 appendFormat:@" tensor<int32, [4]> as2 = const()[name = string(\"as2\"), val = tensor<int32, [4]>([1, 1, %d, %d])];\n", D6, S6];
[m6 appendFormat:@" tensor<int32, [4]> pm = const()[name = string(\"pm\"), val = tensor<int32, [4]>([0, 1, 3, 2])];\n"];
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=as2,x=act)[name=string(\"a2\")];\n", D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> a3 = transpose(perm=pm,x=a2)[name=string(\"a3\")];\n", S6, D6];
// matmul: [1,1,S6,D6] @ [1,1,D6,D6] [1,1,S6,D6]
[m6 appendString:@" bool bF = const()[name = string(\"bF\"), val = bool(false)];\n"];
[m6 appendFormat:@" tensor<fp16, [1, 1, %d, %d]> yh = matmul(transpose_x = bF, transpose_y = bF, x = a3, y = W)[name = string(\"mm\")];\n", S6, D6];
// Reshape back to [1, D6, 1, S6]
[m6 appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=yh)[name=string(\"yt\")];\n", D6, S6];
[m6 appendFormat:@" tensor<int32, [4]> os = const()[name = string(\"os\"), val = tensor<int32, [4]>([1,%d,1,%d])];\n", D6, S6];
[m6 appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=os,x=yt)[name=string(\"yr\")];\n", D6, S6];
[m6 appendString:@" string to32 = const()[name = string(\"to32\"), val = string(\"fp32\")];\n"];
[m6 appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype = to32, x = yr)[name = string(\"cout\")];\n", D6, S6];
[m6 appendString:@" } -> (y);\n}\n"];
int io6_in = total_ch * S6 * 4;
int io6_out = D6 * S6 * 4;
Kern *k6 = compile_kern_mil_w(m6, @{}, io6_in, io6_out);
if (k6) {
printf("Dynamic matmul compile OK!\n");
// Set up: identity W, ramp input
IOSurfaceLock(k6->ioIn, 0, NULL);
float *in6 = (float*)IOSurfaceGetBaseAddress(k6->ioIn);
memset(in6, 0, io6_in);
// Activations: [D6, S6] in channel-first layout
for (int d = 0; d < D6; d++)
for (int s = 0; s < S6; s++)
in6[d*S6+s] = (d*S6+s) * 0.001f;
// Weight: identity matrix [D6, D6] packed in channels D6..D6+D6*D6, only col 0
float *wbase = in6 + D6*S6;
for (int r = 0; r < D6; r++)
for (int c = 0; c < D6; c++)
wbase[(r*D6+c)*S6] = (r==c) ? 1.0f : 0.0f; // only sp=0 matters
IOSurfaceUnlock(k6->ioIn, 0, NULL);
ane_eval(k6);
IOSurfaceLock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
float *out6 = (float*)IOSurfaceGetBaseAddress(k6->ioOut);
printf("Identity W: in=[%.4f,%.4f,%.4f] out=[%.4f,%.4f,%.4f]\n",
in6[0],in6[1],in6[2], out6[0],out6[1],out6[2]);
// Check
float me6 = 0;
for (int i = 0; i < D6*S6; i++) {
float e6 = fabsf(out6[i] - in6[i]);
if (e6 > me6) me6 = e6;
}
IOSurfaceUnlock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("max_err=%.6f %s\n", me6, me6 < 0.1 ? "PASS" : "FAIL");
// Now: 2× identity just change the IOSurface weight, no recompile!
IOSurfaceLock(k6->ioIn, 0, NULL);
for (int r = 0; r < D6; r++)
for (int c = 0; c < D6; c++)
wbase[(r*D6+c)*S6] = (r==c) ? 2.0f : 0.0f;
IOSurfaceUnlock(k6->ioIn, 0, NULL);
ane_eval(k6);
IOSurfaceLock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
printf("2× W: in=[%.4f,%.4f] out=[%.4f,%.4f] (expect 2×)\n",
in6[0],in6[1], out6[0],out6[1]);
IOSurfaceUnlock(k6->ioOut, kIOSurfaceLockReadOnly, NULL);
free_kern(k6);
} else printf("Dynamic matmul compile FAILED\n");
}
free_kern(k); free(W_id); free(W_2x);
printf("\nDone.\n");
}
return 0;
}

View File

@ -5,19 +5,15 @@
#include "stories_mil.h"
#include "stories_cpu_ops.h"
#define DEFAULT_CKPT_PATH "ane_stories110M_ckpt.bin"
#define DEFAULT_MODEL_PATH "../../assets/models/stories110M.bin"
#define DEFAULT_DATA_PATH "tinystories_data00.bin"
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
#define MODEL_PATH_DEFAULT "../../assets/models/stories110M.bin"
#define DATA_PATH "tinystories_data00.bin"
static const char *get_path(const char *env_var, const char *default_val) {
const char *v = getenv(env_var);
return (v && v[0]) ? v : default_val;
}
#define CKPT_PATH get_path("ANE_CKPT_PATH", DEFAULT_CKPT_PATH)
#define MODEL_PATH get_path("ANE_MODEL_PATH", DEFAULT_MODEL_PATH)
#define DATA_PATH get_path("ANE_DATA_PATH", DEFAULT_DATA_PATH)
// ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
FILE *f = fopen(path, "rb");
@ -211,12 +207,24 @@ int main(int argc, char *argv[]) {
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
// Parse args
// Parse args (env vars set defaults, CLI flags override)
const char *ckpt_path = get_path("ANE_CKPT_PATH", CKPT_PATH_DEFAULT);
const char *model_path = get_path("ANE_MODEL_PATH", MODEL_PATH_DEFAULT);
bool do_resume = false;
int pos = 0;
for (int i=1; i<argc; i++) {
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i];
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
else if (argv[i][0] != '-') {
if (pos == 0) model_path = argv[i];
else if (pos == 1) { /* seq - compile-time constant */ }
else if (pos == 2) total_steps = atoi(argv[i]);
else if (pos == 3) lr = atof(argv[i]);
pos++;
}
}
// Allocate per-layer state
@ -247,7 +255,7 @@ int main(int argc, char *argv[]) {
float resume_loss = 0;
bool resuming = false;
if (do_resume) {
resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
@ -255,8 +263,8 @@ int main(int argc, char *argv[]) {
if (!resuming) {
printf("=== ANE Training: Stories110M (12 layers) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
printf("model=%s data=%s ckpt=%s\n", MODEL_PATH, DATA_PATH, CKPT_PATH);
if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
printf("model=%s data=%s ckpt=%s\n", model_path, DATA_PATH, ckpt_path);
if (!load_pretrained(lw, rms_final, embed, model_path)) {
printf("Pretrained load failed, using random init\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
@ -348,13 +356,13 @@ int main(int argc, char *argv[]) {
if (g_compile_count + TOTAL_WEIGHT_KERNELS > MAX_COMPILES) {
for (int L=0; L<NLAYERS; L++) { free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]); }
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(CKPT_PATH, step, total_steps, lr, last_loss,
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
fflush(stdout);
execl(argv[0], argv[0], "--resume", NULL);
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, NULL);
perror("execl"); return 1;
}

View File

@ -16,19 +16,15 @@
#include "ane_rmsnorm_bwd.h"
#include "ane_classifier.h"
#define DEFAULT_CKPT_PATH "ane_stories110M_ckpt.bin"
#define DEFAULT_MODEL_PATH "../../assets/models/stories110M.bin"
#define DEFAULT_DATA_PATH "tinystories_data00.bin"
#define CKPT_PATH_DEFAULT "ane_stories110M_ckpt.bin"
#define MODEL_PATH_DEFAULT "../../assets/models/stories110M.bin"
#define DATA_PATH "tinystories_data00.bin"
static const char *get_path(const char *env_var, const char *default_val) {
const char *v = getenv(env_var);
return (v && v[0]) ? v : default_val;
}
#define CKPT_PATH get_path("ANE_CKPT_PATH", DEFAULT_CKPT_PATH)
#define MODEL_PATH get_path("ANE_MODEL_PATH", DEFAULT_MODEL_PATH)
#define DATA_PATH get_path("ANE_DATA_PATH", DEFAULT_DATA_PATH)
// ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
FILE *f = fopen(path, "rb");
@ -212,11 +208,25 @@ int main(int argc, char *argv[]) {
float lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
const char *ckpt_path = get_path("ANE_CKPT_PATH", CKPT_PATH_DEFAULT);
const char *model_path = get_path("ANE_MODEL_PATH", MODEL_PATH_DEFAULT);
bool do_resume = false;
bool ane_extras = true;
int pos = 0;
for (int i=1; i<argc; i++) {
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--no-ane-extras") == 0) ane_extras = false;
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
else if (strcmp(argv[i], "--ckpt") == 0 && i+1<argc) ckpt_path = argv[++i];
else if (strcmp(argv[i], "--model") == 0 && i+1<argc) model_path = argv[++i];
else if (argv[i][0] != '-') {
if (pos == 0) model_path = argv[i];
else if (pos == 1) { /* seq - compile-time constant */ }
else if (pos == 2) total_steps = atoi(argv[i]);
else if (pos == 3) lr = atof(argv[i]);
pos++;
}
}
LayerWeights lw[NLAYERS]; LayerAdam la[NLAYERS];
@ -238,7 +248,7 @@ int main(int argc, char *argv[]) {
float resume_loss = 0;
bool resuming = false;
if (do_resume) {
resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
resuming = load_checkpoint(ckpt_path, &start_step, &total_steps, &lr, &resume_loss,
&cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
@ -246,9 +256,10 @@ int main(int argc, char *argv[]) {
if (!resuming) {
printf("=== ANE Training: Stories110M (ANE-offloaded) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
printf("model=%s data=%s ckpt=%s\n", MODEL_PATH, DATA_PATH, CKPT_PATH);
printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n");
if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
printf("model=%s data=%s ckpt=%s\n", model_path, DATA_PATH, ckpt_path);
if (ane_extras) printf("NEW: final_rmsnorm, classifier_fwd, softmax, rmsnorm_bwd on ANE\n");
else printf("ANE extras DISABLED (classifier/softmax/rmsnorm_bwd on CPU)\n");
if (!load_pretrained(lw, rms_final, embed, model_path)) {
printf("Pretrained load failed, using random init\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
@ -318,9 +329,12 @@ int main(int argc, char *argv[]) {
memset(rmsFFNBwd, 0, sizeof(rmsFFNBwd));
// Softmax kernel (no weights compile once)
Kern *softmaxKern = compile_softmax_kern();
if (!softmaxKern) { printf("softmax compile failed\n"); return 1; }
printf("Softmax kernel compiled (no weights)\n");
Kern *softmaxKern = NULL;
if (ane_extras) {
softmaxKern = compile_softmax_kern();
if (!softmaxKern) { printf("softmax compile failed\n"); return 1; }
printf("Softmax kernel compiled (no weights)\n");
}
// Final RMSNorm and classifier are recompiled per batch since they have baked weights
Kern *finalRmsKern = NULL, *classifierKern = NULL;
@ -337,8 +351,8 @@ int main(int argc, char *argv[]) {
int step = start_step;
while (step < total_steps) {
// Check compile budget account for new kernels
// Per batch: 60 layer kernels + 24 rmsnorm_bwd + 1 classifier + 1 final_rms = 86
int kernels_needed = TOTAL_WEIGHT_KERNELS + 2*NLAYERS + 2;
// Per batch: 60 layer kernels [+ 24 rmsnorm_bwd + 1 classifier + 1 final_rms = 86 with extras]
int kernels_needed = TOTAL_WEIGHT_KERNELS + (ane_extras ? 2*NLAYERS + 2 : 0);
if (g_compile_count + kernels_needed > MAX_COMPILES) {
for (int L=0; L<NLAYERS; L++) {
free_layer_kernels(&kern[L]); free_kern(sdpaBwd2[L]);
@ -346,13 +360,16 @@ int main(int argc, char *argv[]) {
}
free_kern(softmaxKern); free_kern(finalRmsKern); free_kern(classifierKern);
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(CKPT_PATH, step, total_steps, lr, last_loss,
save_checkpoint(ckpt_path, step, total_steps, lr, last_loss,
total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
fflush(stdout);
execl(argv[0], argv[0], "--resume", NULL);
if (ane_extras)
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, NULL);
else
execl(argv[0], argv[0], "--resume", "--ckpt", ckpt_path, "--no-ane-extras", NULL);
perror("execl"); return 1;
}
@ -367,13 +384,15 @@ int main(int argc, char *argv[]) {
printf("\nCompile failed at layer %d\n", L);
compile_ok = false; break;
}
// NEW: Compile RMSNorm backward kernels for this layer
free_kern(rmsAttBwd[L]); free_kern(rmsFFNBwd[L]);
rmsAttBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_att);
rmsFFNBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_ffn);
if (!rmsAttBwd[L] || !rmsFFNBwd[L]) {
printf("\nrmsnorm_bwd compile failed at layer %d\n", L);
compile_ok = false; break;
// Compile RMSNorm backward kernels for this layer (if ane_extras)
if (ane_extras) {
free_kern(rmsAttBwd[L]); free_kern(rmsFFNBwd[L]);
rmsAttBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_att);
rmsFFNBwd[L] = compile_rmsnorm_bwd_kern(lw[L].rms_ffn);
if (!rmsAttBwd[L] || !rmsFFNBwd[L]) {
printf("\nrmsnorm_bwd compile failed at layer %d\n", L);
compile_ok = false; break;
}
}
}
if (!compile_ok) { g_compile_count = MAX_COMPILES; continue; }
@ -386,18 +405,19 @@ int main(int argc, char *argv[]) {
}
}
// NEW: Compile final RMSNorm and classifier with current weights
free_kern(finalRmsKern); free_kern(classifierKern);
finalRmsKern = compile_final_rmsnorm_kern(rms_final);
classifierKern = compile_classifier_fwd(embed);
if (!finalRmsKern || !classifierKern) {
printf("finalRms or classifier compile failed\n");
g_compile_count = MAX_COMPILES; continue;
}
// Re-compile softmax if needed
if (!softmaxKern) {
softmaxKern = compile_softmax_kern();
if (!softmaxKern) { printf("softmax recompile failed\n"); return 1; }
// Compile final RMSNorm and classifier with current weights (if ane_extras)
if (ane_extras) {
free_kern(finalRmsKern); free_kern(classifierKern);
finalRmsKern = compile_final_rmsnorm_kern(rms_final);
classifierKern = compile_classifier_fwd(embed);
if (!finalRmsKern || !classifierKern) {
printf("finalRms or classifier compile failed\n");
g_compile_count = MAX_COMPILES; continue;
}
if (!softmaxKern) {
softmaxKern = compile_softmax_kern();
if (!softmaxKern) { printf("softmax recompile failed\n"); return 1; }
}
}
double cms = tb_ms(mach_absolute_time() - tc);
@ -461,26 +481,46 @@ int main(int argc, char *argv[]) {
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
}
// CHANGED: Final RMSNorm on ANE (was CPU)
t0=mach_absolute_time();
io_write_fp16(finalRmsKern->ioIn, x_cur, DIM, SEQ);
ane_eval(finalRmsKern);
io_read_fp16(finalRmsKern->ioOut, x_final, 0, DIM, SEQ);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
if (ane_extras) {
// Final RMSNorm on ANE
io_write_fp16(finalRmsKern->ioIn, x_cur, DIM, SEQ);
ane_eval(finalRmsKern);
io_read_fp16(finalRmsKern->ioOut, x_final, 0, DIM, SEQ);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
// CHANGED: Classifier on ANE (was CPU cblas)
io_write_fp16(classifierKern->ioIn, x_final, DIM, SEQ);
ane_eval(classifierKern);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
// Classifier on ANE
io_write_fp16(classifierKern->ioIn, x_final, DIM, SEQ);
ane_eval(classifierKern);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
// CHANGED: Softmax on ANE, then read probs back for NLL on CPU
io_copy(softmaxKern->ioIn, 0, classifierKern->ioOut, 0, VOCAB, SEQ);
ane_eval(softmaxKern);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
// Softmax on ANE
io_copy(softmaxKern->ioIn, 0, classifierKern->ioOut, 0, VOCAB, SEQ);
ane_eval(softmaxKern);
t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
// Read probs back for NLL loss + gradient (needs target indexing CPU)
io_read_fp16(softmaxKern->ioOut, probs, 0, VOCAB, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
io_read_fp16(softmaxKern->ioOut, probs, 0, VOCAB, SEQ);
t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
} else {
// CPU fallback: rmsnorm + classifier + softmax
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
VOCAB, SEQ, DIM, 1.0f,
embed, DIM, x_final, SEQ, 0.0f, probs, SEQ);
t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1;
// CPU softmax
for (int t = 0; t < SEQ; t++) {
float maxv = -1e30f;
for (int v = 0; v < VOCAB; v++) { float val = probs[v*SEQ+t]; if (val > maxv) maxv = val; }
float sum = 0;
for (int v = 0; v < VOCAB; v++) { probs[v*SEQ+t] = expf(probs[v*SEQ+t] - maxv); sum += probs[v*SEQ+t]; }
for (int v = 0; v < VOCAB; v++) probs[v*SEQ+t] /= sum;
}
t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
}
// NLL loss + gradient on CPU: dlogits = probs - one_hot(targets)
float total_loss = 0;
@ -548,17 +588,19 @@ int main(int argc, char *argv[]) {
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
});
// CHANGED: RMSNorm2 backward on ANE
// Write concat(dx_ffn, x2) into rmsnorm_bwd kernel
io_write_fp16_at(rmsFFNBwd[L]->ioIn, 0, dx_ffn, DIM, SEQ);
io_write_fp16_at(rmsFFNBwd[L]->ioIn, DIM, ac->x2, DIM, SEQ);
ane_eval(rmsFFNBwd[L]);
io_read_fp16(rmsFFNBwd[L]->ioOut, dx2, 0, DIM, SEQ);
// dw for rmsnorm_ffn still on CPU (accumulate per step)
// RMSNorm2 backward
if (ane_extras) {
io_write_fp16_at(rmsFFNBwd[L]->ioIn, 0, dx_ffn, DIM, SEQ);
io_write_fp16_at(rmsFFNBwd[L]->ioIn, DIM, ac->x2, DIM, SEQ);
ane_eval(rmsFFNBwd[L]);
io_read_fp16(rmsFFNBwd[L]->ioOut, dx2, 0, DIM, SEQ);
}
// dw for rmsnorm_ffn on CPU (accumulate per step)
{
float *dw_tmp = (float*)calloc(DIM, 4);
float *dx_scratch = (float*)malloc(SEQ*DIM*4);
rmsnorm_bwd(dx_scratch, dw_tmp, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
if (!ane_extras) memcpy(dx2, dx_scratch, SEQ*DIM*4);
for(int i=0;i<DIM;i++) gr->rms_ffn[i] += dw_tmp[i];
free(dx_scratch); free(dw_tmp);
}
@ -608,17 +650,20 @@ int main(int argc, char *argv[]) {
ane_eval(kern[L].qkvBwd);
io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
// CHANGED: RMSNorm1 backward on ANE
io_write_fp16_at(rmsAttBwd[L]->ioIn, 0, dx_attn, DIM, SEQ);
io_write_fp16_at(rmsAttBwd[L]->ioIn, DIM, ac->layer_in, DIM, SEQ);
ane_eval(rmsAttBwd[L]);
// RMSNorm1 backward
float *dx_rms1 = (float*)malloc(SEQ*DIM*4);
io_read_fp16(rmsAttBwd[L]->ioOut, dx_rms1, 0, DIM, SEQ);
// dw for rmsnorm_att still on CPU
if (ane_extras) {
io_write_fp16_at(rmsAttBwd[L]->ioIn, 0, dx_attn, DIM, SEQ);
io_write_fp16_at(rmsAttBwd[L]->ioIn, DIM, ac->layer_in, DIM, SEQ);
ane_eval(rmsAttBwd[L]);
io_read_fp16(rmsAttBwd[L]->ioOut, dx_rms1, 0, DIM, SEQ);
}
// dw for rmsnorm_att on CPU
{
float *dw_tmp = (float*)calloc(DIM, 4);
float *dx_scratch = (float*)malloc(SEQ*DIM*4);
rmsnorm_bwd(dx_scratch, dw_tmp, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
if (!ane_extras) memcpy(dx_rms1, dx_scratch, SEQ*DIM*4);
for(int i=0;i<DIM;i++) gr->rms_att[i] += dw_tmp[i];
free(dx_scratch); free(dw_tmp);
}

View File

@ -0,0 +1,9 @@
CC = xcrun clang
CFLAGS = -O2 -framework Foundation -framework IOSurface -framework Accelerate \
-isysroot $(shell xcrun --show-sdk-path) -fobjc-arc
train: train.m config.h io.h cpu_ops.h mil_dynamic.h
$(CC) $(CFLAGS) -o train train.m
clean:
rm -f train

View File

@ -0,0 +1,156 @@
// config.h — Stories110M model config, structs, ANE init
#pragma once
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach_time.h>
#import <Accelerate/Accelerate.h>
#include <math.h>
#include <unistd.h>
#include <dispatch/dispatch.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <arm_neon.h>
// Stories110M config
#define DIM 768
#define HIDDEN 2048
#define HEADS 12
#define HD (DIM/HEADS)
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
// Weight sizes per layer
#define WQ_SZ (DIM*DIM)
#define WO_SZ (DIM*DIM)
#define W1_SZ (HIDDEN*DIM)
#define W2_SZ (DIM*HIDDEN)
#define W3_SZ (HIDDEN*DIM)
#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
// Attention score channels for SDPA backward
#define SCORE_CH (HEADS*SEQ)
// Per-layer weights
typedef struct {
float *Wq, *Wk, *Wv, *Wo;
float *W1, *W2, *W3;
float *rms_att, *rms_ffn;
} LayerWeights;
// Adam optimizer state
typedef struct { float *m, *v; size_t n; } AdamState;
typedef struct {
AdamState Wq, Wk, Wv, Wo, W1, W2, W3, rms_att, rms_ffn;
} LayerAdam;
// Per-layer activations (saved for backward)
typedef struct {
float *layer_in, *xnorm, *Q, *K, *V, *attn_out, *o_out;
float *x2, *x2norm, *h1, *h3, *silu_out, *ffn_out;
} LayerActs;
// Per-layer gradients
typedef struct {
float *Wq, *Wk, *Wv, *Wo, *W1, *W2, *W3, *rms_att, *rms_ffn;
} LayerGrads;
// ANE kernel handle
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
// Checkpoint header
typedef struct {
int magic, version, step, total_steps;
int n_layers, vocab_size, dim, hidden_dim, n_heads, seq_len;
float lr, loss;
double cum_compile, cum_train, cum_wall;
int cum_steps, cum_batches, adam_t;
int pad[3];
} CkptHdr;
// llama2.c model file header
typedef struct {
int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len;
} Llama2Config;
// Globals
static Class g_D, g_I, g_AR, g_AIO;
static mach_timebase_info_data_t g_tb;
static int g_compile_count = 0;
static void ane_init(void) {
dlopen("/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", RTLD_NOW);
g_D = NSClassFromString(@"_ANEInMemoryModelDescriptor");
g_I = NSClassFromString(@"_ANEInMemoryModel");
g_AR = NSClassFromString(@"_ANERequest");
g_AIO= NSClassFromString(@"_ANEIOSurfaceObject");
}
static double tb_ms(uint64_t t) { return (double)t * g_tb.numer / g_tb.denom / 1e6; }
// Alloc helpers
static AdamState adam_alloc(size_t n) { AdamState s; s.m=(float*)calloc(n,4); s.v=(float*)calloc(n,4); s.n=n; return s; }
static void adam_free(AdamState *s) { free(s->m); free(s->v); }
static LayerWeights layer_weights_alloc(void) {
LayerWeights w;
w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4);
w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4);
w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4);
return w;
}
static void layer_weights_free(LayerWeights *w) {
free(w->Wq);free(w->Wk);free(w->Wv);free(w->Wo);
free(w->W1);free(w->W2);free(w->W3);free(w->rms_att);free(w->rms_ffn);
}
static LayerAdam layer_adam_alloc(void) {
LayerAdam a;
a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WQ_SZ); a.Wv=adam_alloc(WQ_SZ); a.Wo=adam_alloc(WO_SZ);
a.W1=adam_alloc(W1_SZ); a.W2=adam_alloc(W2_SZ); a.W3=adam_alloc(W3_SZ);
a.rms_att=adam_alloc(DIM); a.rms_ffn=adam_alloc(DIM);
return a;
}
static void layer_adam_free(LayerAdam *a) {
adam_free(&a->Wq);adam_free(&a->Wk);adam_free(&a->Wv);adam_free(&a->Wo);
adam_free(&a->W1);adam_free(&a->W2);adam_free(&a->W3);
adam_free(&a->rms_att);adam_free(&a->rms_ffn);
}
static LayerActs layer_acts_alloc(void) {
LayerActs a;
a.layer_in=(float*)malloc(SEQ*DIM*4);
a.xnorm=(float*)malloc(SEQ*DIM*4);
a.Q=(float*)malloc(SEQ*DIM*4); a.K=(float*)malloc(SEQ*DIM*4); a.V=(float*)malloc(SEQ*DIM*4);
a.attn_out=(float*)malloc(SEQ*DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4);
a.x2=(float*)malloc(SEQ*DIM*4); a.x2norm=(float*)malloc(SEQ*DIM*4);
a.h1=(float*)malloc(SEQ*HIDDEN*4); a.h3=(float*)malloc(SEQ*HIDDEN*4);
a.silu_out=(float*)malloc(SEQ*HIDDEN*4); a.ffn_out=(float*)malloc(SEQ*DIM*4);
return a;
}
static void layer_acts_free(LayerActs *a) {
free(a->layer_in);free(a->xnorm);
free(a->Q);free(a->K);free(a->V);
free(a->attn_out);free(a->o_out);free(a->x2);free(a->x2norm);
free(a->h1);free(a->h3);free(a->silu_out);free(a->ffn_out);
}
static LayerGrads layer_grads_alloc(void) {
LayerGrads g;
g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WQ_SZ,4);
g.Wv=(float*)calloc(WQ_SZ,4); g.Wo=(float*)calloc(WO_SZ,4);
g.W1=(float*)calloc(W1_SZ,4); g.W2=(float*)calloc(W2_SZ,4); g.W3=(float*)calloc(W3_SZ,4);
g.rms_att=(float*)calloc(DIM,4); g.rms_ffn=(float*)calloc(DIM,4);
return g;
}
static void layer_grads_zero(LayerGrads *g) {
memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WQ_SZ*4);
memset(g->Wv,0,WQ_SZ*4);memset(g->Wo,0,WO_SZ*4);
memset(g->W1,0,W1_SZ*4);memset(g->W2,0,W2_SZ*4);memset(g->W3,0,W3_SZ*4);
memset(g->rms_att,0,DIM*4);memset(g->rms_ffn,0,DIM*4);
}
static void layer_grads_free(LayerGrads *g) {
free(g->Wq);free(g->Wk);free(g->Wv);free(g->Wo);
free(g->W1);free(g->W2);free(g->W3);free(g->rms_att);free(g->rms_ffn);
}

View File

@ -0,0 +1,164 @@
// cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, embedding
#pragma once
#include "config.h"
static float *g_rms_tmp = NULL;
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
float *ss = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
}
float invd = 1.0f/d, eps=1e-5f;
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
int n = S; vvrsqrtf(ss, ss, &n);
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, (vDSP_Length)S);
vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, (vDSP_Length)S);
}
free(ss);
}
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
float *ss = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
}
float invd = 1.0f/d, eps=1e-5f;
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
float *rrms = (float*)malloc(S*4);
int n = S; vvrsqrtf(rrms, ss, &n);
float *dot = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsma(g_rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
}
vDSP_vmul(rrms, 1, rrms, 1, ss, 1, (vDSP_Length)S);
vDSP_vsmul(ss, 1, &invd, ss, 1, (vDSP_Length)S);
vDSP_vmul(dot, 1, ss, 1, dot, 1, (vDSP_Length)S);
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, dot, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsub(g_rms_tmp, 1, dy+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsmul(g_rms_tmp, 1, &w[i], dx+i*S, 1, (vDSP_Length)S);
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
float s; vDSP_sve(g_rms_tmp, 1, &s, (vDSP_Length)S);
dw[i] += s;
}
free(ss); free(rrms); free(dot);
}
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
for (size_t i=0; i<s->n; i++) {
s->m[i] = b1*s->m[i] + (1-b1)*g[i];
s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i];
float mh = s->m[i]/bc1, vh = s->v[i]/bc2;
w[i] -= lr * mh / (sqrtf(vh) + eps);
}
}
// Cross-entropy loss: operates on logits[V, S] column-major (each column = one token)
// Avoids transposing by using a per-token temp buffer
static float cross_entropy_loss(float *dlogits, const float *logits, const uint16_t *targets, int V, int S) {
float *col = (float*)malloc(V * 4); // single column buffer
float total_loss = 0;
float invS = 1.0f / S;
for (int t = 0; t < S; t++) {
// Gather column t: logits[v, t] = logits[v*S + t], stride=S
cblas_scopy(V, logits + t, S, col, 1);
// Softmax
float maxv; vDSP_maxv(col, 1, &maxv, (vDSP_Length)V);
float neg_max = -maxv;
vDSP_vsadd(col, 1, &neg_max, col, 1, (vDSP_Length)V);
int n = V; vvexpf(col, col, &n);
float sum; vDSP_sve(col, 1, &sum, (vDSP_Length)V);
float inv_sum = 1.0f / sum;
vDSP_vsmul(col, 1, &inv_sum, col, 1, (vDSP_Length)V);
// Loss + gradient
int tgt = targets[t];
total_loss -= logf(col[tgt] + 1e-10f);
col[tgt] -= 1.0f;
vDSP_vsmul(col, 1, &invS, col, 1, (vDSP_Length)V);
// Scatter back: dlogits[v*S + t] = col[v]
cblas_scopy(V, col, 1, dlogits + t, S);
}
free(col);
return total_loss / S;
}
// Vocab compaction: build mapping from full 32K vocab to compact vocab
typedef struct {
int compact_vocab; // number of active tokens
int *full_to_compact; // [VOCAB] → compact id (-1 if unused)
int *compact_to_full; // [compact_vocab] → full vocab id
} VocabMap;
static VocabMap vocab_map_build(const uint16_t *data, size_t n_tokens, int full_vocab) {
VocabMap vm;
vm.full_to_compact = (int*)malloc(full_vocab * sizeof(int));
memset(vm.full_to_compact, -1, full_vocab * sizeof(int));
// Scan for used tokens
for (size_t i = 0; i < n_tokens; i++) {
vm.full_to_compact[data[i]] = 0; // mark as used
}
// Assign compact IDs
int cid = 0;
for (int v = 0; v < full_vocab; v++) {
if (vm.full_to_compact[v] == 0)
vm.full_to_compact[v] = cid++;
else
vm.full_to_compact[v] = -1;
}
vm.compact_vocab = cid;
vm.compact_to_full = (int*)malloc(cid * sizeof(int));
for (int v = 0; v < full_vocab; v++) {
if (vm.full_to_compact[v] >= 0)
vm.compact_to_full[vm.full_to_compact[v]] = v;
}
return vm;
}
// Create compact embedding from full embedding
static float *vocab_compact_embed(const float *full_embed, const VocabMap *vm, int dim) {
float *ce = (float*)malloc((size_t)vm->compact_vocab * dim * 4);
for (int c = 0; c < vm->compact_vocab; c++)
memcpy(ce + c*dim, full_embed + vm->compact_to_full[c]*dim, dim*4);
return ce;
}
// Scatter compact embed gradients back to full embed
static void vocab_scatter_grads(float *full_gembed, const float *compact_gembed, const VocabMap *vm, int dim) {
for (int c = 0; c < vm->compact_vocab; c++) {
int fv = vm->compact_to_full[c];
for (int d = 0; d < dim; d++)
full_gembed[fv*dim + d] += compact_gembed[c*dim + d];
}
}
// Update full embed from compact embed (after adam)
static void vocab_update_full(float *full_embed, const float *compact_embed, const VocabMap *vm, int dim) {
for (int c = 0; c < vm->compact_vocab; c++)
memcpy(full_embed + vm->compact_to_full[c]*dim, compact_embed + c*dim, dim*4);
}
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
for (int d = 0; d < dim; d++)
x[d*seq + t] = embed[tok*dim + d];
}
}
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
for (int d = 0; d < dim; d++)
d_embed[tok*dim + d] += dx[d*seq + t];
}
}

View File

@ -0,0 +1,147 @@
// io.h — IOSurface helpers, NEON conversion, kernel compile/eval
#pragma once
#include "config.h"
static IOSurfaceRef make_surface(size_t bytes) {
return IOSurfaceCreate((__bridge CFDictionaryRef)@{
(id)kIOSurfaceWidth:@(bytes), (id)kIOSurfaceHeight:@1,
(id)kIOSurfaceBytesPerElement:@1, (id)kIOSurfaceBytesPerRow:@(bytes),
(id)kIOSurfaceAllocSize:@(bytes), (id)kIOSurfacePixelFormat:@0});
}
// Blob builders for const weights (mask, rms)
static NSData *build_blob(const float *w, int rows, int cols) {
int ws=rows*cols*2, tot=128+ws;
uint8_t *b=(uint8_t*)calloc(tot,1);
b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1;
*(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128;
_Float16 *fp16=(_Float16*)(b+128);
for(int i=0;i<rows*cols;i++) fp16[i]=(_Float16)w[i];
return [NSData dataWithBytesNoCopy:b length:tot freeWhenDone:YES];
}
static NSData *build_blob_fp16(_Float16 *d, int cnt) {
int ws=cnt*2, tot=128+ws;
uint8_t *b=(uint8_t*)calloc(tot,1);
b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1;
*(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128;
memcpy(b+128,d,ws);
return [NSData dataWithBytesNoCopy:b length:tot freeWhenDone:YES];
}
// NEON vectorized conversion
static void cvt_f16_f32(float *dst, const _Float16 *src, int n) {
int i = 0;
for (; i+7 < n; i += 8) {
float16x8_t h = vld1q_f16((const __fp16*)(src+i));
vst1q_f32(dst+i, vcvt_f32_f16(vget_low_f16(h)));
vst1q_f32(dst+i+4, vcvt_f32_f16(vget_high_f16(h)));
}
for (; i < n; i++) dst[i] = (float)src[i];
}
static void cvt_f32_f16(_Float16 *dst, const float *src, int n) {
int i = 0;
for (; i+7 < n; i += 8) {
float16x8_t h = vcombine_f16(vcvt_f16_f32(vld1q_f32(src+i)),
vcvt_f16_f32(vld1q_f32(src+i+4)));
vst1q_f16((__fp16*)(dst+i), h);
}
for (; i < n; i++) dst[i] = (_Float16)src[i];
}
// IOSurface I/O (channel-first [C,S] layout, fp16 on surface)
static void io_write_fp16(IOSurfaceRef s, const float *data, int channels, int sp) {
IOSurfaceLock(s, 0, NULL);
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s), data, channels * sp);
IOSurfaceUnlock(s, 0, NULL);
}
static void io_read_fp16(IOSurfaceRef s, float *data, int ch_off, int channels, int sp) {
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
cvt_f16_f32(data, (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, channels * sp);
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
static void io_copy(IOSurfaceRef dst, int dst_ch, IOSurfaceRef src, int src_ch, int channels, int sp) {
IOSurfaceLock(dst, 0, NULL);
IOSurfaceLock(src, kIOSurfaceLockReadOnly, NULL);
memcpy((_Float16*)IOSurfaceGetBaseAddress(dst) + dst_ch*sp,
(_Float16*)IOSurfaceGetBaseAddress(src) + src_ch*sp,
channels * sp * sizeof(_Float16));
IOSurfaceUnlock(src, kIOSurfaceLockReadOnly, NULL);
IOSurfaceUnlock(dst, 0, NULL);
}
static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int channels, int sp) {
IOSurfaceLock(s, 0, NULL);
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, data, channels * sp);
IOSurfaceUnlock(s, 0, NULL);
}
// fp32 IOSurface I/O (for dynamic matmul kernels that use fp32 input/output)
// Layout: [1, IC, 1, SP] where SP = SEQ + OC
// Write activations at sp[0:SEQ] and weights at sp[SEQ:SEQ+OC]
static void io_write_dyn(IOSurfaceRef s, const float *act, int ic, int seq,
const float *W, int oc) {
int sp = seq + oc;
IOSurfaceLock(s, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(s);
for (int d = 0; d < ic; d++) {
memcpy(buf + d*sp, act + d*seq, seq*4);
memcpy(buf + d*sp + seq, W + d*oc, oc*4);
}
IOSurfaceUnlock(s, 0, NULL);
}
// Read output from dynamic matmul kernel: [1, OC, 1, SEQ]
static void io_read_dyn(IOSurfaceRef s, float *out, int oc, int seq) {
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
memcpy(out, (float*)IOSurfaceGetBaseAddress(s), oc * seq * 4);
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
// Compile MIL to ANE kernel
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
@autoreleasepool {
NSData *md = [mil dataUsingEncoding:NSUTF8StringEncoding];
id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(g_D, @selector(modelWithMILText:weights:optionsPlist:), md, weights, nil);
if (!desc) { printf(" [compile] desc=NULL\n"); return NULL; }
id mdl = ((id(*)(Class,SEL,id))objc_msgSend)(g_I, @selector(inMemoryModelWithDescriptor:), desc);
id hx = ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier));
NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx];
[[NSFileManager defaultManager] createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"] withIntermediateDirectories:YES attributes:nil error:nil];
[md writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES];
for (NSString *path in weights) {
NSString *rel = [path stringByReplacingOccurrencesOfString:@"@model_path/" withString:@""];
[weights[path][@"data"] writeToFile:[td stringByAppendingPathComponent:rel] atomically:YES];
}
NSError *e = nil;
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(compileWithQoS:options:error:), 21, @{}, &e)) {
printf(" [compile] FAIL: %s\n", e ? [[e description] UTF8String] : "no error"); return NULL;
}
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e)) {
printf(" [compile] load FAIL\n"); return NULL;
}
__sync_fetch_and_add(&g_compile_count, 1);
Kern *k = (Kern*)calloc(1, sizeof(Kern));
k->model = (void*)CFBridgingRetain(mdl);
k->ioIn = make_surface(ic_bytes);
k->ioOut = make_surface(oc_bytes);
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI], @[@0], @[wO], @[@0], nil, nil, @0));
k->tmpDir = (void*)CFBridgingRetain(td);
return k;
}
}
static void free_kern(Kern *k) {
if (!k) return;
id mdl = (__bridge id)k->model; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
CFRelease(k->ioIn); CFRelease(k->ioOut);
[[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
CFRelease(k->model); CFRelease(k->request); CFRelease(k->tmpDir);
free(k);
}
static void ane_eval(Kern *k) {
id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
}

View File

@ -0,0 +1,590 @@
// mil_dynamic.h — MIL generators using dynamic matmul (weights via IOSurface)
// Instead of conv(const_weight, x), we use matmul(x, W) where both come from input.
// Input layout: [1, IC, 1, SP] fp32, SP = SEQ + total_weight_cols
// Activations in sp[0:SEQ], weight matrices packed sequentially in sp[SEQ:]
#pragma once
#include "io.h"
#define MIL_HDR \
@"program(1.3)\n[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
// Helper: generate a dynamic matmul within a MIL function
// Slices activation [1,ic,1,seq] and weight [1,ic,1,oc] from input, does matmul
// act_sp_off: spatial offset for activations (usually 0)
// w_sp_off: spatial offset for weight block
// Returns variable name of result [1,oc,1,seq] in fp16
static void gen_dyn_matmul(NSMutableString *m, const char *prefix,
int ic, int oc, int seq,
int act_sp_off, int w_sp_off,
const char *input_var) {
// Slice activations
[m appendFormat:@" tensor<int32, [4]> %s_ba = const()[name=string(\"%s_ba\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, act_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sa = const()[name=string(\"%s_sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_act = slice_by_size(x=%s,begin=%s_ba,size=%s_sa)[name=string(\"%s_act\")];\n", ic, seq, prefix, input_var, prefix, prefix, prefix];
// Slice weight
[m appendFormat:@" tensor<int32, [4]> %s_bw = const()[name=string(\"%s_bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", prefix, prefix, w_sp_off];
[m appendFormat:@" tensor<int32, [4]> %s_sw = const()[name=string(\"%s_sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_wt = slice_by_size(x=%s,begin=%s_bw,size=%s_sw)[name=string(\"%s_wt\")];\n", ic, oc, prefix, input_var, prefix, prefix, prefix];
// Reshape act: [1,ic,1,seq] → [1,1,ic,seq] → transpose → [1,1,seq,ic]
[m appendFormat:@" tensor<int32, [4]> %s_ra = const()[name=string(\"%s_ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, seq];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a2 = reshape(shape=%s_ra,x=%s_act)[name=string(\"%s_a2\")];\n", ic, seq, prefix, prefix, prefix, prefix];
[m appendFormat:@" tensor<int32, [4]> %s_pm = const()[name=string(\"%s_pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n", prefix, prefix];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_a3 = transpose(perm=%s_pm,x=%s_a2)[name=string(\"%s_a3\")];\n", seq, ic, prefix, prefix, prefix, prefix];
// Reshape weight: [1,ic,1,oc] → [1,1,ic,oc]
[m appendFormat:@" tensor<int32, [4]> %s_rw = const()[name=string(\"%s_rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", prefix, prefix, ic, oc];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_W = reshape(shape=%s_rw,x=%s_wt)[name=string(\"%s_W\")];\n", ic, oc, prefix, prefix, prefix, prefix];
// matmul: [1,1,seq,ic] @ [1,1,ic,oc] → [1,1,seq,oc]
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yh = matmul(transpose_x=bF,transpose_y=bF,x=%s_a3,y=%s_W)[name=string(\"%s_yh\")];\n", seq, oc, prefix, prefix, prefix, prefix];
// Transpose back + reshape: [1,1,seq,oc] → [1,1,oc,seq] → [1,oc,1,seq]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> %s_yt = transpose(perm=%s_pm,x=%s_yh)[name=string(\"%s_yt\")];\n", oc, seq, prefix, prefix, prefix, prefix];
[m appendFormat:@" tensor<int32, [4]> %s_ro = const()[name=string(\"%s_ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", prefix, prefix, oc, seq];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> %s_y = reshape(shape=%s_ro,x=%s_yt)[name=string(\"%s_y\")];\n", oc, seq, prefix, prefix, prefix, prefix];
}
// ===== Dynamic matmul kernel: y = x @ W =====
// Input: [1, IC, 1, SEQ+OC] fp32 — act[0:SEQ] + W[SEQ:SEQ+OC]
// Output: [1, OC, 1, SEQ] fp32
static NSString *gen_dyn_matmul_mil(int ic, int oc, int seq) {
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
int sp = seq + oc;
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", ic, sp];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", ic, sp];
gen_dyn_matmul(m, "mm", ic, oc, seq, 0, seq, "xh");
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=mm_y)[name=string(\"cout\")];\n", oc, seq];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// ===== SDPA forward (dynamic weights) =====
// Replaces gen_sdpa_fwd_taps: RMSNorm done on CPU, this kernel does QKV matmul + SDPA + Wo matmul
// Input: [1, DIM, 1, SEQ + 4*DIM] fp32
// sp[0:SEQ] = xnorm (rmsnorm output, DIM channels)
// sp[SEQ:SEQ+DIM] = Wq[DIM,DIM]
// sp[SEQ+DIM:SEQ+2D] = Wk[DIM,DIM]
// sp[SEQ+2D:SEQ+3D] = Wv[DIM,DIM]
// sp[SEQ+3D:SEQ+4D] = Wo[DIM,DIM]
// Output: [1, 6*DIM, 1, SEQ] fp16 = concat(o_out, Q, K, V, attn_out, xnorm_pass)
// NOTE: mask is still a const weight (it doesn't change)
static NSString *gen_sdpa_fwd_dynamic(void) {
float sc = 1.0f/sqrtf((float)HD);
int w_total = 4*DIM; // Wq+Wk+Wv+Wo
int sp_in = SEQ + w_total;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
// Cast to fp16
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice xnorm [1,DIM,1,SEQ]
[m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
// Slice Wq [1,DIM,1,DIM]
[m appendFormat:@" tensor<int32, [4]> bq = const()[name=string(\"bq\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wq = slice_by_size(x=xh,begin=bq,size=sw)[name=string(\"Wq\")];\n", DIM, DIM];
// Slice Wk
[m appendFormat:@" tensor<int32, [4]> bk = const()[name=string(\"bk\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wk = slice_by_size(x=xh,begin=bk,size=sw)[name=string(\"Wk\")];\n", DIM, DIM];
// Slice Wv
[m appendFormat:@" tensor<int32, [4]> bv = const()[name=string(\"bv\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wv = slice_by_size(x=xh,begin=bv,size=sw)[name=string(\"Wv\")];\n", DIM, DIM];
// Slice Wo
[m appendFormat:@" tensor<int32, [4]> bo = const()[name=string(\"bo\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wo = slice_by_size(x=xh,begin=bo,size=sw)[name=string(\"Wo\")];\n", DIM, DIM];
// Reshape for matmul: [1,D,1,S] → [1,1,D,S] → [1,1,S,D]
[m appendFormat:@" tensor<int32, [4]> r2 = const()[name=string(\"r2\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=r2,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
// Reshape weights: [1,D,1,D] → [1,1,D,D]
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wq2 = reshape(shape=rw,x=Wq)[name=string(\"Wq2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wk2 = reshape(shape=rw,x=Wk)[name=string(\"Wk2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wv2 = reshape(shape=rw,x=Wv)[name=string(\"Wv2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wo2 = reshape(shape=rw,x=Wo)[name=string(\"Wo2\")];\n", DIM, DIM];
// QKV matmul: [1,1,S,D] @ [1,1,D,D] → [1,1,S,D]
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wq2)[name=string(\"qm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> km = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wk2)[name=string(\"km\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vm = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=Wv2)[name=string(\"vm\")];\n", SEQ, DIM];
// Transpose back: [1,1,S,D] → [1,1,D,S] → reshape [1,D,1,S]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> qt = transpose(perm=pm,x=qm)[name=string(\"qt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> kt = transpose(perm=pm,x=km)[name=string(\"kt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> vt = transpose(perm=pm,x=vm)[name=string(\"vt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> os = const()[name=string(\"os\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = reshape(shape=os,x=qt)[name=string(\"qf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = reshape(shape=os,x=kt)[name=string(\"kf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = reshape(shape=os,x=vt)[name=string(\"vf\")];\n", DIM, SEQ];
// SDPA: reshape to heads, matmul, mask, softmax, matmul
[m appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS, SEQ, HD];
// Q @ K^T
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];
// Causal mask (still const — doesn't change)
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ];
// Softmax
[m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ];
// scores @ V
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=bF,transpose_y=bF,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS, SEQ, HD];
// Reshape back to [1,DIM,1,SEQ]
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM, SEQ];
// Wo matmul: af → [1,1,S,D] @ Wo[1,1,D,D] → [1,1,S,D] → [1,D,1,S]
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> af2 = reshape(shape=r2,x=af)[name=string(\"af2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> aft = transpose(perm=pm,x=af2)[name=string(\"aft\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> om = matmul(transpose_x=bF,transpose_y=bF,x=aft,y=Wo2)[name=string(\"om\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ot = transpose(perm=pm,x=om)[name=string(\"ot\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = reshape(shape=os,x=ot)[name=string(\"oo\")];\n", DIM, SEQ];
// Output: concat(o_out, qf, kf, vf, af, xn) — same as original for backward compatibility
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM, SEQ];
// Cast to fp32
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 6*DIM, SEQ];
[m appendString:@" } -> (out32);\n}\n"];
return m;
}
// ===== FFN forward (dynamic weights) =====
// RMSNorm on CPU. This kernel: xnorm @ W1 → SiLU, xnorm @ W3 → gate, gate*silu @ W2 → out
// Input: [1, DIM, 1, SEQ + HIDDEN + HIDDEN + DIM] fp32
// sp[0:SEQ] = xnorm [DIM,SEQ]
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN] = W3[DIM,HIDDEN]
// sp[SEQ+2*HIDDEN:SEQ+2*HIDDEN+DIM]= W2[HIDDEN→DIM] — but W2 is [DIM,HIDDEN], we need HIDDEN input channels
// PROBLEM: W2 has shape [DIM,HIDDEN] = HIDDEN input channels, but our kernel has DIM input channels.
// Solution: separate kernels for W1/W3 (DIM→HIDDEN) and W2 (HIDDEN→DIM)
// OR: do W1,W3 in one kernel, SiLU on CPU/ANE, W2 in another kernel.
// Simpler: 3 separate matmul kernels per FFN direction. But that's too many dispatches.
// Better: one kernel for W1+W3 (same input dim), CPU SiLU, one kernel for W2.
// FFN part 1: xnorm @ W1, xnorm @ W3 (both DIM→HIDDEN)
// Input: [1, DIM, 1, SEQ + 2*HIDDEN] fp32
// sp[0:SEQ] = xnorm
// sp[SEQ:SEQ+HIDDEN] = W1[DIM,HIDDEN]
// sp[SEQ+HIDDEN:SEQ+2*HIDDEN]= W3[DIM,HIDDEN]
// Output: [1, 2*HIDDEN, 1, SEQ] fp32 = concat(h1, h3)
static NSString *gen_ffn_w13_dynamic(void) {
int sp_in = SEQ + 2*HIDDEN;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice xnorm
[m appendString:@" tensor<int32, [4]> bx = const()[name=string(\"bx\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sx = const()[name=string(\"sx\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = slice_by_size(x=xh,begin=bx,size=sx)[name=string(\"xn\")];\n", DIM, SEQ];
// Slice W1
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1 = slice_by_size(x=xh,begin=b1,size=s1)[name=string(\"W1\")];\n", DIM, HIDDEN];
// Slice W3
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3 = slice_by_size(x=xh,begin=b3,size=s1)[name=string(\"W3\")];\n", DIM, HIDDEN];
// Reshape for matmul
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xn2 = reshape(shape=rd,x=xn)[name=string(\"xn2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> xnt = transpose(perm=pm,x=xn2)[name=string(\"xnt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W12 = reshape(shape=rw,x=W1)[name=string(\"W12\")];\n", DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W32 = reshape(shape=rw,x=W3)[name=string(\"W32\")];\n", DIM, HIDDEN];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W12)[name=string(\"h1m\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3m = matmul(transpose_x=bF,transpose_y=bF,x=xnt,y=W32)[name=string(\"h3m\")];\n", SEQ, HIDDEN];
// Transpose back
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h1t = transpose(perm=pm,x=h1m)[name=string(\"h1t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> h3t = transpose(perm=pm,x=h3m)[name=string(\"h3t\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> rh = const()[name=string(\"rh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = reshape(shape=rh,x=h1t)[name=string(\"h1\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = reshape(shape=rh,x=h3t)[name=string(\"h3\")];\n", HIDDEN, SEQ];
// SiLU + gate
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];
// Concat output: (h1, h3, gate)
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(h1,h3,gate))[name=string(\"cat\")];\n", 2*HIDDEN+HIDDEN, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> out32 = cast(dtype=to32,x=out)[name=string(\"cout\")];\n", 3*HIDDEN, SEQ];
[m appendString:@" } -> (out32);\n}\n"];
return m;
}
// FFN part 2: gate @ W2 (HIDDEN→DIM)
// Input: [1, HIDDEN, 1, SEQ + DIM] fp32
// sp[0:SEQ] = gate [HIDDEN,SEQ]
// sp[SEQ:SEQ+DIM] = W2[HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32
static NSString *gen_ffn_w2_dynamic(void) {
int sp_in = SEQ + DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in];
[m appendString:@" tensor<int32, [4]> ba = const()[name=string(\"ba\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sa = const()[name=string(\"sa\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> act = slice_by_size(x=xh,begin=ba,size=sa)[name=string(\"act\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> bw = const()[name=string(\"bw\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W2 = slice_by_size(x=xh,begin=bw,size=sw)[name=string(\"W2\")];\n", HIDDEN, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> a2 = reshape(shape=ra,x=act)[name=string(\"a2\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> at = transpose(perm=pm,x=a2)[name=string(\"at\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W22 = reshape(shape=rw,x=W2)[name=string(\"W22\")];\n", HIDDEN, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> ym = matmul(transpose_x=bF,transpose_y=bF,x=at,y=W22)[name=string(\"ym\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> yt = transpose(perm=pm,x=ym)[name=string(\"yt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> yr = reshape(shape=ro,x=yt)[name=string(\"yr\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=yr)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// ===== FFN backward (dynamic weights) =====
// Input: [1, DIM+2*HIDDEN, 1, SEQ + HIDDEN + DIM + DIM] fp32
// Actually simpler to split into separate backward kernels like forward.
// FFN backward part 1: dffn @ W2^T → dsilu (HIDDEN), then SiLU derivative
// Input: [1, DIM, 1, SEQ + HIDDEN] fp32
// sp[0:SEQ] = dffn [DIM, SEQ]
// sp[SEQ:SEQ+HIDDEN]= W2^T [DIM, HIDDEN]
// Output: [1, HIDDEN, 1, SEQ] fp32 = dsilu_raw
static NSString *gen_ffn_bwd_w2t_dynamic(void) {
return gen_dyn_matmul_mil(DIM, HIDDEN, SEQ);
}
// FFN backward part 2: dh1 @ W1^T + dh3 @ W3^T → dx
// We need h1,h3 for SiLU derivative, but those are on CPU.
// Actually the SiLU derivative + gating is element-wise, do on CPU.
// Then: dh1 @ W1^T and dh3 @ W3^T are two separate matmuls (HIDDEN→DIM).
// Combine into one kernel:
// Input: [1, HIDDEN, 1, SEQ + SEQ + DIM + DIM] fp32
// sp[0:SEQ] = dh1 [HIDDEN,SEQ]
// sp[SEQ:2*SEQ] = dh3 [HIDDEN,SEQ]
// sp[2*SEQ:2*SEQ+DIM] = W1^T [HIDDEN,DIM]
// sp[2*SEQ+DIM:2*SEQ+2D] = W3^T [HIDDEN,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dx1 + dx3
static NSString *gen_ffn_bwd_w13t_dynamic(void) {
int sp_in = 2*SEQ + 2*DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", HIDDEN, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", HIDDEN, sp_in];
// Slice dh1 [HIDDEN, SEQ]
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sh = const()[name=string(\"sh\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = slice_by_size(x=xh,begin=b0,size=sh)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
// Slice dh3
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = slice_by_size(x=xh,begin=b1,size=sh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];
// Slice W1^T [HIDDEN, DIM]
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W1t = slice_by_size(x=xh,begin=b2,size=sw)[name=string(\"W1t\")];\n", HIDDEN, DIM];
// Slice W3^T
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> W3t = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"W3t\")];\n", HIDDEN, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
// dh1 matmul: [S,H] @ [H,D] → [S,D]
[m appendFormat:@" tensor<int32, [4]> ra = const()[name=string(\"ra\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh12 = reshape(shape=ra,x=dh1)[name=string(\"dh12\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh1t = transpose(perm=pm,x=dh12)[name=string(\"dh1t\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh32 = reshape(shape=ra,x=dh3)[name=string(\"dh32\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dh3t = transpose(perm=pm,x=dh32)[name=string(\"dh3t\")];\n", SEQ, HIDDEN];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W1t2 = reshape(shape=rw,x=W1t)[name=string(\"W1t2\")];\n", HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> W3t2 = reshape(shape=rw,x=W3t)[name=string(\"W3t2\")];\n", HIDDEN, DIM];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx1m = matmul(transpose_x=bF,transpose_y=bF,x=dh1t,y=W1t2)[name=string(\"dx1m\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dx3m = matmul(transpose_x=bF,transpose_y=bF,x=dh3t,y=W3t2)[name=string(\"dx3m\")];\n", SEQ, DIM];
// Add
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxm = add(x=dx1m,y=dx3m)[name=string(\"dxm\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxm)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// ===== SDPA backward part 1 (dynamic Wo^T) =====
// Same as original gen_sdpa_bwd1 but Wo^T comes from input instead of const
// Input: [1, 4*DIM, 1, SEQ + DIM] fp32 — Q,K,V,dx2 in channels, Wo^T in spatial
// Wait — channels must match for all data. Q,K,V are [DIM,SEQ], dx2 is [DIM,SEQ].
// Total input channels = 4*DIM. But Wo^T is [DIM,DIM] = DIM channels of DIM spatial.
// Problem: can't mix 4*DIM channels for data with DIM channels for Wo^T.
// Solution: Wo^T matmul as separate kernel, then SDPA part purely element-wise on ANE.
// Wo^T matmul: dx2 @ Wo^T → da (DIM→DIM)
static NSString *gen_wot_dynamic(void) {
return gen_dyn_matmul_mil(DIM, DIM, SEQ);
}
// SDPA backward part 1 (no weights, all data): Q,K,V,da → dV,probs,dp
// Same as original but without Wo^T conv (already done)
// Input: [1, 4*DIM, 1, SEQ] fp16
static NSString *gen_sdpa_bwd1_noweight(void) {
float sc = 1.0f/sqrtf((float)HD);
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 4*DIM, SEQ];
// Slice Q,K,V,da
[m appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> da = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM, SEQ];
// Reshape to heads
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dr = reshape(shape=rsh,x=da)[name=string(\"rd\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dat = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS, SEQ, HD];
// Forward attention scores (recompute)
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ];
[m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ];
// dV = probs^T @ da, dp = da @ V^T
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=dat)[name=string(\"dv\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp4 = matmul(transpose_x=bF,transpose_y=bT,x=dat,y=v)[name=string(\"dp\")];\n", HEADS, SEQ, SEQ];
// Reshape dV back
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM, SEQ];
// Flatten probs and dp for output
[m appendFormat:@" tensor<int32, [4]> scs = const()[name=string(\"scs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH, SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// SDPA backward part 2: same as original (no weights, pure computation)
static NSString *gen_sdpa_bwd2(void) {
float sc = 1.0f/sqrtf((float)HD);
int bwd2_in = 2*SCORE_CH + 2*DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", bwd2_in, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_sc = const()[name=string(\"szsc\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_d = const()[name=string(\"szd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ssh = const()[name=string(\"ssh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS, SEQ, SEQ];
[m appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([-1])];\n"];
[m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS, SEQ, SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS, SEQ, SEQ];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS, SEQ, HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS, HD, SEQ];
[m appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM, SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// QKV backward (dynamic): dq @ Wq^T + dk @ Wk^T + dv @ Wv^T → dx
// Input: [1, DIM, 1, 3*SEQ + 3*DIM] fp32
// sp[0:SEQ] = dq [DIM,SEQ]
// sp[SEQ:2*SEQ] = dk [DIM,SEQ]
// sp[2*SEQ:3*SEQ] = dv [DIM,SEQ]
// sp[3*SEQ:3*SEQ+DIM] = Wq^T [DIM,DIM]
// sp[3*SEQ+DIM:3*SEQ+2D] = Wk^T [DIM,DIM]
// sp[3*SEQ+2D:3*SEQ+3D] = Wv^T [DIM,DIM]
// Output: [1, DIM, 1, SEQ] fp32 = dxq + dxk + dxv
static NSString *gen_qkvb_dynamic(void) {
int sp_in = 3*SEQ + 3*DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp32, [1, %d, 1, %d]> x) {\n", DIM, sp_in];
[m appendString:@" string to16 = const()[name=string(\"to16\"), val=string(\"fp16\")];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xh = cast(dtype=to16,x=x)[name=string(\"cin\")];\n", DIM, sp_in];
// Slice dq, dk, dv
[m appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=xh,begin=b0,size=sd)[name=string(\"dq\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=xh,begin=b1,size=sd)[name=string(\"dk\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 2*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=xh,begin=b2,size=sd)[name=string(\"dv\")];\n", DIM, SEQ];
// Slice Wq^T, Wk^T, Wv^T
[m appendFormat:@" tensor<int32, [4]> sw = const()[name=string(\"sw\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wqt = slice_by_size(x=xh,begin=b3,size=sw)[name=string(\"Wqt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b4 = const()[name=string(\"b4\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wkt = slice_by_size(x=xh,begin=b4,size=sw)[name=string(\"Wkt\")];\n", DIM, DIM];
[m appendFormat:@" tensor<int32, [4]> b5 = const()[name=string(\"b5\"), val=tensor<int32, [4]>([0,0,0,%d])];\n", 3*SEQ+2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> Wvt = slice_by_size(x=xh,begin=b5,size=sw)[name=string(\"Wvt\")];\n", DIM, DIM];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
// Reshape and matmul for each
[m appendFormat:@" tensor<int32, [4]> rd = const()[name=string(\"rd\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> rw = const()[name=string(\"rw\"), val=tensor<int32, [4]>([1,1,%d,%d])];\n", DIM, DIM];
// dq @ Wq^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dq2 = reshape(shape=rd,x=dq)[name=string(\"dq2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dqt = transpose(perm=pm,x=dq2)[name=string(\"dqt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wqt2 = reshape(shape=rw,x=Wqt)[name=string(\"Wqt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxq = matmul(transpose_x=bF,transpose_y=bF,x=dqt,y=Wqt2)[name=string(\"dxq\")];\n", SEQ, DIM];
// dk @ Wk^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dk2 = reshape(shape=rd,x=dk)[name=string(\"dk2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dkt = transpose(perm=pm,x=dk2)[name=string(\"dkt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wkt2 = reshape(shape=rw,x=Wkt)[name=string(\"Wkt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxk = matmul(transpose_x=bF,transpose_y=bF,x=dkt,y=Wkt2)[name=string(\"dxk\")];\n", SEQ, DIM];
// dv @ Wv^T
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dv2 = reshape(shape=rd,x=dv)[name=string(\"dv2\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dvt = transpose(perm=pm,x=dv2)[name=string(\"dvt\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> Wvt2 = reshape(shape=rw,x=Wvt)[name=string(\"Wvt2\")];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxv = matmul(transpose_x=bF,transpose_y=bF,x=dvt,y=Wvt2)[name=string(\"dxv\")];\n", SEQ, DIM];
// Sum: dxq + dxk + dxv
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxall = add(x=dxqk,y=dxv)[name=string(\"aall\")];\n", SEQ, DIM];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> dxt = transpose(perm=pm,x=dxall)[name=string(\"dxt\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> ro = const()[name=string(\"ro\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = reshape(shape=ro,x=dxt)[name=string(\"dx\")];\n", DIM, SEQ];
[m appendString:@" string to32 = const()[name=string(\"to32\"), val=string(\"fp32\")];\n"];
[m appendFormat:@" tensor<fp32, [1,%d,1,%d]> y = cast(dtype=to32,x=dx)[name=string(\"cout\")];\n", DIM, SEQ];
[m appendString:@" } -> (y);\n}\n"];
return m;
}
// Causal mask blob (used by sdpa_fwd and sdpa_bwd1)
static NSData *g_mask_blob = nil;
static NSData *get_mask_blob(void) {
if (!g_mask_blob) {
_Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16));
for(int t=0;t<SEQ;t++) for(int t2=0;t2<SEQ;t2++)
mask[t*SEQ+t2] = (t2<=t) ? (_Float16)0.0f : (_Float16)(-65504.0f);
g_mask_blob = build_blob_fp16(mask, SEQ*SEQ);
free(mask);
}
return g_mask_blob;
}

View File

@ -0,0 +1,876 @@
// train.m Dynamic weight ANE training for Stories110M
// Compile kernels ONCE at startup, update weights via IOSurface every step.
// No exec() restart needed eliminates 76% compile overhead.
#include "mil_dynamic.h"
#include "cpu_ops.h"
#define CKPT_PATH "ane_stories110M_dyn_ckpt.bin"
#define MODEL_PATH "../../../assets/models/stories110M.bin"
#define DATA_PATH "../tinystories_data00.bin"
// Dynamic kernel set per layer
typedef struct {
Kern *sdpaFwd; // QKV matmul + SDPA + Wo matmul (dynamic weights via IOSurface)
Kern *ffnW13; // W1,W3 matmul (dynamic)
Kern *ffnW2; // W2 matmul (dynamic)
Kern *ffnBwdW2t; // dffn @ W2^T (dynamic)
Kern *ffnBwdW13t; // dh1@W1^T + dh3@W3^T (dynamic)
Kern *wotBwd; // dx2 @ Wo^T (dynamic)
Kern *sdpaBwd1; // Q,K,V,da dV,probs,dp (weight-free, has mask const)
Kern *sdpaBwd2; // probs,dp,Q,K dQ,dK (weight-free)
Kern *qkvBwd; // dq@Wq^T + dk@Wk^T + dv@Wv^T (dynamic)
} DynLayerKernels;
// ===== Weight loading from llama2.c format =====
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { printf("Cannot open %s\n", path); return false; }
Llama2Config cfg;
fread(&cfg, sizeof(cfg), 1, f);
printf(" Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
printf(" ERROR: Config mismatch!\n"); fclose(f); return false;
}
int V = abs(cfg.vocab_size);
fread(embed, 4, V * DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_att, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wq, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wk, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wv, 4, WQ_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].Wo, 4, WO_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].rms_ffn, 4, DIM, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W1, 4, W1_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W2, 4, W2_SZ, f);
for (int L = 0; L < NLAYERS; L++) fread(lw[L].W3, 4, W3_SZ, f);
fread(rms_final, 4, DIM, f);
fclose(f);
printf(" Loaded pretrained weights\n");
return true;
}
// Transpose W[rows,cols] W^T[cols,rows] stored as [cols channels, rows spatial]
static void transpose_weight(float *dst, const float *src, int rows, int cols) {
for (int r = 0; r < rows; r++)
for (int c = 0; c < cols; c++)
dst[c * rows + r] = src[r * cols + c];
}
// ===== Compile all dynamic kernels (ONCE) =====
static bool compile_dynamic_kernels(DynLayerKernels *dk) {
NSDictionary *mask_w = @{@"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()}};
// SDPA forward: [1, DIM, 1, SEQ+4*DIM] fp32 [1, 6*DIM, 1, SEQ] fp32
printf(" Compiling sdpaFwd...\n");
dk->sdpaFwd = compile_kern_mil_w(gen_sdpa_fwd_dynamic(), mask_w,
DIM*(SEQ+4*DIM)*4, 6*DIM*SEQ*4);
if (!dk->sdpaFwd) return false;
// FFN W1+W3: [1, DIM, 1, SEQ+2*HIDDEN] fp32 [1, 3*HIDDEN, 1, SEQ] fp32
printf(" Compiling ffnW13...\n");
dk->ffnW13 = compile_kern_mil_w(gen_ffn_w13_dynamic(), @{},
DIM*(SEQ+2*HIDDEN)*4, 3*HIDDEN*SEQ*4);
if (!dk->ffnW13) return false;
// FFN W2: [1, HIDDEN, 1, SEQ+DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling ffnW2...\n");
dk->ffnW2 = compile_kern_mil_w(gen_ffn_w2_dynamic(), @{},
HIDDEN*(SEQ+DIM)*4, DIM*SEQ*4);
if (!dk->ffnW2) return false;
// FFN backward W2^T: [1, DIM, 1, SEQ+HIDDEN] fp32 [1, HIDDEN, 1, SEQ] fp32
printf(" Compiling ffnBwdW2t...\n");
dk->ffnBwdW2t = compile_kern_mil_w(gen_ffn_bwd_w2t_dynamic(), @{},
DIM*(SEQ+HIDDEN)*4, HIDDEN*SEQ*4);
if (!dk->ffnBwdW2t) return false;
// FFN backward W1^T+W3^T: [1, HIDDEN, 1, 2*SEQ+2*DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling ffnBwdW13t...\n");
dk->ffnBwdW13t = compile_kern_mil_w(gen_ffn_bwd_w13t_dynamic(), @{},
HIDDEN*(2*SEQ+2*DIM)*4, DIM*SEQ*4);
if (!dk->ffnBwdW13t) return false;
// Wo^T backward: [1, DIM, 1, SEQ+DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling wotBwd...\n");
dk->wotBwd = compile_kern_mil_w(gen_wot_dynamic(), @{},
DIM*(SEQ+DIM)*4, DIM*SEQ*4);
if (!dk->wotBwd) return false;
// SDPA bwd1 (no dynamic weights, has mask): [1, 4*DIM, 1, SEQ] fp16 [1, DIM+2*SCORE_CH, 1, SEQ] fp16
printf(" Compiling sdpaBwd1...\n");
dk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1_noweight(), mask_w,
4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
if (!dk->sdpaBwd1) return false;
// SDPA bwd2 (no weights): [1, 2*SCORE_CH+2*DIM, 1, SEQ] fp16 [1, 2*DIM, 1, SEQ] fp16
printf(" Compiling sdpaBwd2...\n");
dk->sdpaBwd2 = compile_kern_mil_w(gen_sdpa_bwd2(), @{},
(2*SCORE_CH+2*DIM)*SEQ*2, 2*DIM*SEQ*2);
if (!dk->sdpaBwd2) return false;
// QKV backward: [1, DIM, 1, 3*SEQ+3*DIM] fp32 [1, DIM, 1, SEQ] fp32
printf(" Compiling qkvBwd...\n");
dk->qkvBwd = compile_kern_mil_w(gen_qkvb_dynamic(), @{},
DIM*(3*SEQ+3*DIM)*4, DIM*SEQ*4);
if (!dk->qkvBwd) return false;
return true;
}
// ===== Write dynamic weights into IOSurface =====
// sdpaFwd: [1, DIM, 1, SEQ+4*DIM] xnorm at sp[0:S], Wq/Wk/Wv/Wo at sp[S:]
static void write_sdpa_fwd_input(DynLayerKernels *dk, const float *xnorm,
const float *Wq, const float *Wk, const float *Wv, const float *Wo) {
IOSurfaceLock(dk->sdpaFwd->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk->sdpaFwd->ioIn);
int sp = SEQ + 4*DIM;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, Wq + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+DIM, Wk + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+2*DIM, Wv + d*DIM, DIM*4);
memcpy(buf + d*sp + SEQ+3*DIM, Wo + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk->sdpaFwd->ioIn, 0, NULL);
}
// ffnW13: [1, DIM, 1, SEQ+2*HIDDEN] xnorm at sp[0:S], W1,W3 at sp[S:]
static void write_ffn_w13_input(DynLayerKernels *dk, const float *xnorm,
const float *W1, const float *W3) {
IOSurfaceLock(dk->ffnW13->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW13->ioIn);
int sp = SEQ + 2*HIDDEN;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, xnorm + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, W1 + d*HIDDEN, HIDDEN*4);
memcpy(buf + d*sp + SEQ+HIDDEN, W3 + d*HIDDEN, HIDDEN*4);
}
IOSurfaceUnlock(dk->ffnW13->ioIn, 0, NULL);
}
// ffnW2: [1, HIDDEN, 1, SEQ+DIM] gate at sp[0:S], W2 at sp[S:]
static void write_ffn_w2_input(DynLayerKernels *dk, const float *gate, const float *W2) {
IOSurfaceLock(dk->ffnW2->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk->ffnW2->ioIn);
int sp = SEQ + DIM;
for (int d = 0; d < HIDDEN; d++) {
memcpy(buf + d*sp, gate + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, W2 + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk->ffnW2->ioIn, 0, NULL);
}
// ===== Checkpoint =====
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
double ct, double cw, int cs, int adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "wb");
CkptHdr h = {0};
h.magic = 0x424C5A54; h.version = 3;
h.step = step; h.total_steps = total_steps;
h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
h.lr = lr; h.loss = loss;
h.cum_train = ct; h.cum_wall = cw; h.cum_steps = cs; h.adam_t = adam_t;
fwrite(&h, sizeof(h), 1, f);
for (int L = 0; L < NLAYERS; L++) {
fwrite(lw[L].Wq,4,WQ_SZ,f); fwrite(lw[L].Wk,4,WQ_SZ,f);
fwrite(lw[L].Wv,4,WQ_SZ,f); fwrite(lw[L].Wo,4,WO_SZ,f);
fwrite(lw[L].W1,4,W1_SZ,f); fwrite(lw[L].W2,4,W2_SZ,f); fwrite(lw[L].W3,4,W3_SZ,f);
fwrite(lw[L].rms_att,4,DIM,f); fwrite(lw[L].rms_ffn,4,DIM,f);
fwrite(la[L].Wq.m,4,WQ_SZ,f); fwrite(la[L].Wq.v,4,WQ_SZ,f);
fwrite(la[L].Wk.m,4,WQ_SZ,f); fwrite(la[L].Wk.v,4,WQ_SZ,f);
fwrite(la[L].Wv.m,4,WQ_SZ,f); fwrite(la[L].Wv.v,4,WQ_SZ,f);
fwrite(la[L].Wo.m,4,WO_SZ,f); fwrite(la[L].Wo.v,4,WO_SZ,f);
fwrite(la[L].W1.m,4,W1_SZ,f); fwrite(la[L].W1.v,4,W1_SZ,f);
fwrite(la[L].W2.m,4,W2_SZ,f); fwrite(la[L].W2.v,4,W2_SZ,f);
fwrite(la[L].W3.m,4,W3_SZ,f); fwrite(la[L].W3.v,4,W3_SZ,f);
fwrite(la[L].rms_att.m,4,DIM,f); fwrite(la[L].rms_att.v,4,DIM,f);
fwrite(la[L].rms_ffn.m,4,DIM,f); fwrite(la[L].rms_ffn.v,4,DIM,f);
}
fwrite(rms_final,4,DIM,f);
fwrite(arms_final->m,4,DIM,f); fwrite(arms_final->v,4,DIM,f);
fwrite(embed,4,VOCAB*DIM,f);
fwrite(aembed->m,4,VOCAB*DIM,f); fwrite(aembed->v,4,VOCAB*DIM,f);
fclose(f);
}
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
double *ct, double *cw, int *cs, int *adam_t,
LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
float *embed, AdamState *aembed) {
FILE *f = fopen(path, "rb");
if (!f) return false;
CkptHdr h;
fread(&h, sizeof(h), 1, f);
if (h.magic != 0x424C5A54 || h.version != 3) { fclose(f); return false; }
*step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
*ct = h.cum_train; *cw = h.cum_wall; *cs = h.cum_steps; *adam_t = h.adam_t;
for (int L = 0; L < NLAYERS; L++) {
fread(lw[L].Wq,4,WQ_SZ,f); fread(lw[L].Wk,4,WQ_SZ,f);
fread(lw[L].Wv,4,WQ_SZ,f); fread(lw[L].Wo,4,WO_SZ,f);
fread(lw[L].W1,4,W1_SZ,f); fread(lw[L].W2,4,W2_SZ,f); fread(lw[L].W3,4,W3_SZ,f);
fread(lw[L].rms_att,4,DIM,f); fread(lw[L].rms_ffn,4,DIM,f);
fread(la[L].Wq.m,4,WQ_SZ,f); fread(la[L].Wq.v,4,WQ_SZ,f);
fread(la[L].Wk.m,4,WQ_SZ,f); fread(la[L].Wk.v,4,WQ_SZ,f);
fread(la[L].Wv.m,4,WQ_SZ,f); fread(la[L].Wv.v,4,WQ_SZ,f);
fread(la[L].Wo.m,4,WO_SZ,f); fread(la[L].Wo.v,4,WO_SZ,f);
fread(la[L].W1.m,4,W1_SZ,f); fread(la[L].W1.v,4,W1_SZ,f);
fread(la[L].W2.m,4,W2_SZ,f); fread(la[L].W2.v,4,W2_SZ,f);
fread(la[L].W3.m,4,W3_SZ,f); fread(la[L].W3.v,4,W3_SZ,f);
fread(la[L].rms_att.m,4,DIM,f); fread(la[L].rms_att.v,4,DIM,f);
fread(la[L].rms_ffn.m,4,DIM,f); fread(la[L].rms_ffn.v,4,DIM,f);
}
fread(rms_final,4,DIM,f);
fread(arms_final->m,4,DIM,f); fread(arms_final->v,4,DIM,f);
fread(embed,4,VOCAB*DIM,f);
fread(aembed->m,4,VOCAB*DIM,f); fread(aembed->v,4,VOCAB*DIM,f);
fclose(f);
return true;
}
int main(int argc, char *argv[]) {
@autoreleasepool {
setbuf(stdout, NULL);
ane_init();
mach_timebase_info(&g_tb);
int total_steps = 10000;
float max_lr = 3e-4f;
float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
int adam_t = 0, start_step = 0;
int accum_steps = 10;
int warmup_steps = 100;
float grad_clip = 1.0f;
float min_lr_frac = 0.1f; // min_lr = max_lr * 0.1
bool do_resume = false, from_scratch = false;
for (int i=1; i<argc; i++) {
if (strcmp(argv[i], "--resume") == 0) do_resume = true;
else if (strcmp(argv[i], "--scratch") == 0) from_scratch = true;
else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) max_lr = atof(argv[++i]);
else if (strcmp(argv[i], "--accum") == 0 && i+1<argc) accum_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--warmup") == 0 && i+1<argc) warmup_steps = atoi(argv[++i]);
else if (strcmp(argv[i], "--clip") == 0 && i+1<argc) grad_clip = atof(argv[++i]);
}
float lr = max_lr;
// Allocate per-layer state
LayerWeights lw[NLAYERS]; LayerAdam la[NLAYERS];
LayerActs acts[NLAYERS]; LayerGrads grads[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
lw[L] = layer_weights_alloc(); la[L] = layer_adam_alloc();
acts[L] = layer_acts_alloc(); grads[L] = layer_grads_alloc();
}
float *rms_final = (float*)malloc(DIM*4);
float *embed = (float*)malloc(VOCAB*DIM*4);
float *grms_final = (float*)calloc(DIM, 4);
float *gembed = (float*)calloc(VOCAB*DIM, 4);
AdamState arms_final = adam_alloc(DIM);
AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
double cum_train=0, cum_wall=0; int cum_steps=0;
float resume_loss = 0;
bool resuming = false;
if (do_resume) {
resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
&cum_train, &cum_wall, &cum_steps, &adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
}
if (!resuming) {
printf("=== ANE Dynamic Training: Stories110M (12 layers) ===\n");
printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
// Param counts for dashboard
double xformer_m = (double)NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ + 2.0*DIM) / 1e6;
double embed_m = (double)VOCAB*DIM / 1e6;
printf("Params: %.1fM (transformer %.1fM + embed %.1fM)\n", xformer_m+embed_m, xformer_m, embed_m);
printf("Kernels: 9 compiled, 9 weight-bearing\n");
printf("Accum %d steps, LR=%g\n", accum_steps, max_lr);
// FLOPs estimate: 6*N*B*T for transformer (forward+backward 3x forward)
double fwd_flops = 2.0*NLAYERS*(4.0*WQ_SZ + 2.0*W1_SZ + W2_SZ + W3_SZ) * SEQ;
double total_flops = 3.0 * fwd_flops; // fwd + bwd 3x fwd
printf("FLOPs/step: fwd=%.1fM bwd_dx=%.1fM bwd_dW=%.1fM sdpa_bwd=0.0M total=%.1fM\n",
fwd_flops/1e6, fwd_flops/1e6, fwd_flops/1e6, total_flops/1e6);
printf("ANE FLOPs/step: %.1fM\n", total_flops/1e6);
if (from_scratch || !load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
if (from_scratch) printf(" Training from scratch (random init)\n");
else printf(" Pretrained load failed, using random init\n");
srand48(42);
float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
for (int L=0; L<NLAYERS; L++) {
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
}
for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
float escale = 0.02f;
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
}
}
// Precompute transposed weights (for backward pass kernels)
// These get updated after each Adam step
float *Wqt_buf[NLAYERS], *Wkt_buf[NLAYERS], *Wvt_buf[NLAYERS], *Wot_buf[NLAYERS];
float *W1t_buf[NLAYERS], *W2t_buf[NLAYERS], *W3t_buf[NLAYERS];
for (int L=0; L<NLAYERS; L++) {
Wqt_buf[L]=(float*)malloc(WQ_SZ*4); Wkt_buf[L]=(float*)malloc(WQ_SZ*4);
Wvt_buf[L]=(float*)malloc(WQ_SZ*4); Wot_buf[L]=(float*)malloc(WO_SZ*4);
W1t_buf[L]=(float*)malloc(W1_SZ*4); W2t_buf[L]=(float*)malloc(W2_SZ*4);
W3t_buf[L]=(float*)malloc(W3_SZ*4);
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
}
// mmap token data
int data_fd = open(DATA_PATH, O_RDONLY);
if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; }
struct stat st; fstat(data_fd, &st);
size_t data_len = st.st_size;
uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
size_t n_tokens = data_len / 2;
printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
// Vocab compaction: map 32K sparse vocab ~9K compact
VocabMap vm = vocab_map_build(token_data, n_tokens, VOCAB);
int CV = vm.compact_vocab;
printf("Vocab compaction: %d → %d active tokens (%.1fx reduction)\n", VOCAB, CV, (float)VOCAB/CV);
// Create compact embedding + adam state
float *cembed = vocab_compact_embed(embed, &vm, DIM);
float *gcembed = (float*)calloc((size_t)CV*DIM, 4);
AdamState acembed = adam_alloc((size_t)CV*DIM);
// ===== Compile all kernels ONCE =====
printf("Compiling %d dynamic kernels (one-time)...\n", 9);
uint64_t tc = mach_absolute_time();
DynLayerKernels dk;
if (!compile_dynamic_kernels(&dk)) {
printf("Compilation failed!\n"); return 1;
}
double compile_ms = tb_ms(mach_absolute_time() - tc);
printf("Compiled 9 kernels in %.0fms (shared across all %d layers)\n\n", compile_ms, NLAYERS);
// Gradient + work buffers
float *dy = (float*)malloc(SEQ*DIM*4);
float *dffn = (float*)malloc(SEQ*DIM*4);
float *dx_ffn = (float*)malloc(SEQ*DIM*4);
float *dx2 = (float*)malloc(SEQ*DIM*4);
float *dx_attn = (float*)malloc(SEQ*DIM*4);
float *dq = (float*)malloc(SEQ*DIM*4);
float *dk_buf = (float*)malloc(SEQ*DIM*4);
float *dv = (float*)malloc(SEQ*DIM*4);
float *x_cur = (float*)malloc(SEQ*DIM*4);
float *x_final = (float*)malloc(SEQ*DIM*4);
float *xnorm_buf = (float*)malloc(SEQ*DIM*4);
float *logits = (float*)malloc(SEQ*CV*4);
float *dlogits = (float*)malloc(SEQ*CV*4);
float *gate_buf = (float*)malloc(SEQ*HIDDEN*4);
float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
float *dsilu = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp = (float*)malloc(SEQ*HIDDEN*4);
float *silu_tmp2 = (float*)malloc(SEQ*HIDDEN*4);
dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
dispatch_group_t dw_grp = dispatch_group_create();
float last_loss = 999.0f;
double total_train_ms = 0;
int total_steps_done = 0;
uint64_t t_wall_start = mach_absolute_time();
srand48(42 + start_step);
for (int step = start_step; step < total_steps; step++) {
uint64_t t0, t1, t_step = mach_absolute_time();
// Sample data
size_t max_pos = n_tokens - SEQ - 1;
size_t pos = (size_t)(drand48() * max_pos);
uint16_t *input_tokens = token_data + pos;
uint16_t *target_tokens_raw = token_data + pos + 1;
// Map targets to compact vocab IDs
uint16_t ctargets[SEQ];
for (int t = 0; t < SEQ; t++) ctargets[t] = (uint16_t)vm.full_to_compact[target_tokens_raw[t]];
// Embedding lookup (uses full embed for now input tokens are full IDs)
embed_lookup(x_cur, embed, input_tokens, DIM, SEQ);
// Timing accumulators (reset each step)
double t_rms=0, t_ane_fwd=0, t_io_fwd=0, t_cblas_wait=0;
double t_ane_bwd=0, t_io_bwd=0, t_silu=0, t_rms_bwd=0, t_cls=0, t_dw_copy=0;
// ===== FORWARD (12 layers) =====
for (int L=0; L<NLAYERS; L++) {
LayerActs *ac = &acts[L];
memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
// RMSNorm1 (CPU)
t0 = mach_absolute_time();
rmsnorm(xnorm_buf, x_cur, lw[L].rms_att, DIM, SEQ);
memcpy(ac->xnorm, xnorm_buf, SEQ*DIM*4);
t_rms += tb_ms(mach_absolute_time() - t0);
// Wait for any pending dW cblas
t0 = mach_absolute_time();
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
t_cblas_wait += tb_ms(mach_absolute_time() - t0);
// SDPA forward (ANE): xnorm + Wq,Wk,Wv,Wo o_out,Q,K,V,attn_out,xnorm
t0 = mach_absolute_time();
write_sdpa_fwd_input(&dk, xnorm_buf, Wqt_buf[L], Wkt_buf[L], Wvt_buf[L], Wot_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaFwd);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read output: [1, 6*DIM, 1, SEQ] fp32
t0 = mach_absolute_time();
IOSurfaceLock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
float *fwd_out = (float*)IOSurfaceGetBaseAddress(dk.sdpaFwd->ioOut);
memcpy(ac->o_out, fwd_out + 0*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->Q, fwd_out + 1*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->K, fwd_out + 2*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->V, fwd_out + 3*DIM*SEQ, DIM*SEQ*4);
memcpy(ac->attn_out, fwd_out + 4*DIM*SEQ, DIM*SEQ*4);
IOSurfaceUnlock(dk.sdpaFwd->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x2 = x_cur + o_out
vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
// RMSNorm2 (CPU)
t0 = mach_absolute_time();
rmsnorm(xnorm_buf, ac->x2, lw[L].rms_ffn, DIM, SEQ);
memcpy(ac->x2norm, xnorm_buf, SEQ*DIM*4);
t_rms += tb_ms(mach_absolute_time() - t0);
// FFN W1+W3 (ANE): xnorm h1, h3, gate
t0 = mach_absolute_time();
write_ffn_w13_input(&dk, xnorm_buf, W1t_buf[L], W3t_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnW13);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
// Read h1, h3, gate from output [1, 3*HIDDEN, 1, SEQ]
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL);
float *ffn13_out = (float*)IOSurfaceGetBaseAddress(dk.ffnW13->ioOut);
memcpy(ac->h1, ffn13_out, HIDDEN*SEQ*4);
memcpy(ac->h3, ffn13_out + HIDDEN*SEQ, HIDDEN*SEQ*4);
memcpy(gate_buf, ffn13_out + 2*HIDDEN*SEQ, HIDDEN*SEQ*4);
memcpy(ac->silu_out, gate_buf, HIDDEN*SEQ*4);
IOSurfaceUnlock(dk.ffnW13->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// FFN W2 (ANE): gate @ W2 ffn_out
t0 = mach_absolute_time();
write_ffn_w2_input(&dk, gate_buf, W2t_buf[L]);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnW2);
t_ane_fwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
IOSurfaceLock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
memcpy(ac->ffn_out, (float*)IOSurfaceGetBaseAddress(dk.ffnW2->ioOut), DIM*SEQ*4);
IOSurfaceUnlock(dk.ffnW2->ioOut, kIOSurfaceLockReadOnly, NULL);
t_io_fwd += tb_ms(mach_absolute_time() - t0);
// Residual: x_cur = x2 + ffn_out
vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
}
// Final RMSNorm + classifier + loss (CPU)
t0 = mach_absolute_time();
rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
t_rms += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
// Classifier: logits[CV, SEQ] = cembed[CV, DIM] @ x_final[DIM, SEQ]
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
CV, SEQ, DIM, 1.0f, cembed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
float loss = cross_entropy_loss(dlogits, logits, ctargets, CV, SEQ);
t_cls += tb_ms(mach_absolute_time() - t0);
last_loss = loss;
// ===== BACKWARD =====
// Classifier backward: dy[DIM, SEQ] = cembed^T[DIM, CV] @ dlogits[CV, SEQ]
t0 = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
DIM, SEQ, CV, 1.0f, cembed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
t_cls += tb_ms(mach_absolute_time() - t0);
// dEmbed async: gcembed[CV, DIM] += dlogits[CV, SEQ] @ x_final^T[SEQ, DIM]
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
CV, DIM, SEQ, 1.0f, dlogits, SEQ, x_final, SEQ, 1.0f, gcembed, DIM);
});
// Final RMSNorm backward
float *dx_rms_final = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms_final, grms_final, dy, x_cur, rms_final, DIM, SEQ);
memcpy(dy, dx_rms_final, SEQ*DIM*4);
free(dx_rms_final);
// ===== BACKWARD (12 layers, reverse) =====
for (int L=NLAYERS-1; L>=0; L--) {
LayerActs *ac = &acts[L];
LayerGrads *gr = &grads[L];
memcpy(dffn, dy, SEQ*DIM*4);
// FFN backward: dffn @ W2^T dsilu_raw
t0 = mach_absolute_time();
io_write_dyn(dk.ffnBwdW2t->ioIn, dffn, DIM, SEQ, lw[L].W2, HIDDEN);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW2t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW2t->ioOut, dsilu, HIDDEN, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// SiLU derivative (vectorized): dsilu dh1, dh3
// silu(h1) = h1*sig(h1), dsilu_dh1 = sig*(1+h1*(1-sig))
// dh1 = dsilu * h3 * dsilu_dh1, dh3 = dsilu * silu(h1)
t0 = mach_absolute_time();
{
int n = HIDDEN*SEQ;
// sig = 1/(1+exp(-h1))
float minus1 = -1.0f, one = 1.0f;
vDSP_vsmul(ac->h1, 1, &minus1, silu_tmp, 1, (vDSP_Length)n);
vvexpf(silu_tmp, silu_tmp, &n);
vDSP_vsadd(silu_tmp, 1, &one, silu_tmp, 1, (vDSP_Length)n);
vvrecf(silu_tmp, silu_tmp, &n); // silu_tmp = sig
// dh3 = dsilu * h1 * sig (= dsilu * silu(h1))
vDSP_vmul(ac->h1, 1, silu_tmp, 1, dh3, 1, (vDSP_Length)n);
vDSP_vmul(dsilu, 1, dh3, 1, dh3, 1, (vDSP_Length)n);
// dsilu_dh1 = sig*(1+h1*(1-sig)), store in silu_tmp2
vDSP_vsadd(silu_tmp, 1, &minus1, silu_tmp2, 1, (vDSP_Length)n); // sig-1
vDSP_vneg(silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // 1-sig
vDSP_vmul(ac->h1, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // h1*(1-sig)
vDSP_vsadd(silu_tmp2, 1, &one, silu_tmp2, 1, (vDSP_Length)n); // 1+h1*(1-sig)
vDSP_vmul(silu_tmp, 1, silu_tmp2, 1, silu_tmp2, 1, (vDSP_Length)n); // full dsilu_dh1
// dh1 = dsilu * h3 * dsilu_dh1
vDSP_vmul(dsilu, 1, ac->h3, 1, dh1, 1, (vDSP_Length)n);
vDSP_vmul(dh1, 1, silu_tmp2, 1, dh1, 1, (vDSP_Length)n);
}
t_silu += tb_ms(mach_absolute_time() - t0);
// dh1@W1^T + dh3@W3^T dx_ffn (ANE)
t0 = mach_absolute_time();
{
IOSurfaceLock(dk.ffnBwdW13t->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.ffnBwdW13t->ioIn);
int sp = 2*SEQ + 2*DIM;
for (int d = 0; d < HIDDEN; d++) {
memcpy(buf + d*sp, dh1 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dh3 + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, lw[L].W1 + d*DIM, DIM*4);
memcpy(buf + d*sp + 2*SEQ + DIM, lw[L].W3 + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.ffnBwdW13t->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.ffnBwdW13t);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.ffnBwdW13t->ioOut, dx_ffn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dW FFN async (cblas)
t0 = mach_absolute_time();
float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4);
float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4);
float *capt_dh1 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh1, dh1, SEQ*HIDDEN*4);
float *capt_dh3 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh3, dh3, SEQ*HIDDEN*4);
float *capt_x2n = (float*)malloc(SEQ*DIM*4); memcpy(capt_x2n, ac->x2norm, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
1.0f, capt_dffn, SEQ, capt_silu, SEQ, 1.0f, gr->W2, HIDDEN);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh1, SEQ, capt_x2n, SEQ, 1.0f, gr->W1, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
1.0f, capt_dh3, SEQ, capt_x2n, SEQ, 1.0f, gr->W3, DIM);
free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
});
// RMSNorm2 backward
t0 = mach_absolute_time();
memset(dx2, 0, SEQ*DIM*4);
rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
// Wo^T backward (ANE): dx2 @ Wo^T da
t0 = mach_absolute_time();
io_write_dyn(dk.wotBwd->ioIn, dx2, DIM, SEQ, lw[L].Wo, DIM);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.wotBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
float *da_buf = (float*)malloc(SEQ*DIM*4);
io_read_dyn(dk.wotBwd->ioOut, da_buf, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dWo async
t0 = mach_absolute_time();
float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, dx2, SEQ*DIM*4);
float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM);
free(capt_do); free(capt_attn);
});
// SDPA backward part 1 (ANE, fp16): Q,K,V,da dV,probs,dp
t0 = mach_absolute_time();
io_write_fp16_at(dk.sdpaBwd1->ioIn, 0, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, DIM, ac->K, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 2*DIM, ac->V, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd1->ioIn, 3*DIM, da_buf, DIM, SEQ);
free(da_buf);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd1);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
// SDPA backward part 2: probs,dp,Q,K dQ,dK
t0 = mach_absolute_time();
io_copy(dk.sdpaBwd2->ioIn, 0, dk.sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH, ac->Q, DIM, SEQ);
io_write_fp16_at(dk.sdpaBwd2->ioIn, 2*SCORE_CH+DIM, ac->K, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.sdpaBwd2);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_fp16(dk.sdpaBwd2->ioOut, dq, 0, DIM, SEQ);
io_read_fp16(dk.sdpaBwd2->ioOut, dk_buf, DIM, DIM, SEQ);
io_read_fp16(dk.sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// dWq/dWk/dWv async
t0 = mach_absolute_time();
float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk_buf, SEQ*DIM*4);
float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4);
float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4);
t_dw_copy += tb_ms(mach_absolute_time() - t0);
dispatch_group_async(dw_grp, dw_q, ^{
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM);
free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
});
// QKV backward (ANE): dq,dk,dv @ Wq^T,Wk^T,Wv^T dx_attn
t0 = mach_absolute_time();
{
IOSurfaceLock(dk.qkvBwd->ioIn, 0, NULL);
float *buf = (float*)IOSurfaceGetBaseAddress(dk.qkvBwd->ioIn);
int sp = 3*SEQ + 3*DIM;
for (int d = 0; d < DIM; d++) {
memcpy(buf + d*sp, dq + d*SEQ, SEQ*4);
memcpy(buf + d*sp + SEQ, dk_buf + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 2*SEQ, dv + d*SEQ, SEQ*4);
memcpy(buf + d*sp + 3*SEQ, lw[L].Wq + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+DIM, lw[L].Wk + d*DIM, DIM*4);
memcpy(buf + d*sp + 3*SEQ+2*DIM, lw[L].Wv + d*DIM, DIM*4);
}
IOSurfaceUnlock(dk.qkvBwd->ioIn, 0, NULL);
}
t_io_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
ane_eval(dk.qkvBwd);
t_ane_bwd += tb_ms(mach_absolute_time() - t0);
t0 = mach_absolute_time();
io_read_dyn(dk.qkvBwd->ioOut, dx_attn, DIM, SEQ);
t_io_bwd += tb_ms(mach_absolute_time() - t0);
// RMSNorm1 backward
t0 = mach_absolute_time();
float *dx_rms1 = (float*)calloc(SEQ*DIM, 4);
rmsnorm_bwd(dx_rms1, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms1[i] + dx2[i];
free(dx_rms1);
t_rms_bwd += tb_ms(mach_absolute_time() - t0);
}
// Embedding backward
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
embed_backward(gembed, dy, input_tokens, DIM, SEQ);
double step_ms = tb_ms(mach_absolute_time() - t_step);
total_train_ms += step_ms;
total_steps_done++;
if (step % 10 == 0 || step == start_step) {
printf(" timing: ane_fwd=%.1f io_fwd=%.1f rms=%.1f ane_bwd=%.1f io_bwd=%.1f silu=%.1f rms_bwd=%.1f cls=%.1f cblas_wait=%.1f dw_copy=%.1f\n",
t_ane_fwd, t_io_fwd, t_rms, t_ane_bwd, t_io_bwd, t_silu, t_rms_bwd, t_cls, t_cblas_wait, t_dw_copy);
float xmx, xmn;
vDSP_maxv(x_cur,1,&xmx,(vDSP_Length)(SEQ*DIM));
vDSP_minv(x_cur,1,&xmn,(vDSP_Length)(SEQ*DIM));
float dmx, dmn;
vDSP_maxv(dy,1,&dmx,(vDSP_Length)(SEQ*DIM));
vDSP_minv(dy,1,&dmn,(vDSP_Length)(SEQ*DIM));
printf("step %-4d loss=%.4f lr=%.2e %.1fms/step x[%.2f,%.2f] dy[%.3e,%.3e]\n",
step, loss, lr, step_ms, xmn, xmx, dmn, dmx);
}
// Adam update every accum_steps
if ((step+1) % accum_steps == 0 || step == total_steps-1) {
dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
float gsc = 1.0f / accum_steps;
adam_t++;
// Scale gradients by 1/accum_steps
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
}
for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
// Merge compact classifier grads into full embed grads
vocab_scatter_grads(gembed, gcembed, &vm, DIM);
for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
// Global gradient norm
float grad_norm_sq = 0;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
float s;
vDSP_dotpr(g->Wq,1,g->Wq,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wk,1,g->Wk,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wv,1,g->Wv,1,&s,(vDSP_Length)WQ_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->Wo,1,g->Wo,1,&s,(vDSP_Length)WO_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W1,1,g->W1,1,&s,(vDSP_Length)W1_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W2,1,g->W2,1,&s,(vDSP_Length)W2_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->W3,1,g->W3,1,&s,(vDSP_Length)W3_SZ); grad_norm_sq+=s;
vDSP_dotpr(g->rms_att,1,g->rms_att,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
vDSP_dotpr(g->rms_ffn,1,g->rms_ffn,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
}
{ float s;
vDSP_dotpr(grms_final,1,grms_final,1,&s,(vDSP_Length)DIM); grad_norm_sq+=s;
vDSP_dotpr(gembed,1,gembed,1,&s,(vDSP_Length)(VOCAB*DIM)); grad_norm_sq+=s;
}
float grad_norm = sqrtf(grad_norm_sq);
if ((step+1) % 10 == 0) printf(" grad_norm=%.4f\n", grad_norm);
// Gradient clipping
if (grad_clip > 0 && grad_norm > grad_clip) {
float clip_scale = grad_clip / grad_norm;
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
vDSP_vsmul(g->Wq,1,&clip_scale,g->Wq,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wk,1,&clip_scale,g->Wk,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wv,1,&clip_scale,g->Wv,1,(vDSP_Length)WQ_SZ);
vDSP_vsmul(g->Wo,1,&clip_scale,g->Wo,1,(vDSP_Length)WO_SZ);
vDSP_vsmul(g->W1,1,&clip_scale,g->W1,1,(vDSP_Length)W1_SZ);
vDSP_vsmul(g->W2,1,&clip_scale,g->W2,1,(vDSP_Length)W2_SZ);
vDSP_vsmul(g->W3,1,&clip_scale,g->W3,1,(vDSP_Length)W3_SZ);
vDSP_vsmul(g->rms_att,1,&clip_scale,g->rms_att,1,(vDSP_Length)DIM);
vDSP_vsmul(g->rms_ffn,1,&clip_scale,g->rms_ffn,1,(vDSP_Length)DIM);
}
vDSP_vsmul(grms_final,1,&clip_scale,grms_final,1,(vDSP_Length)DIM);
vDSP_vsmul(gembed,1,&clip_scale,gembed,1,(vDSP_Length)(VOCAB*DIM));
}
// Cosine LR schedule with warmup
if (step < warmup_steps) {
lr = max_lr * ((float)(step + 1)) / warmup_steps;
} else {
float decay_ratio = (float)(step - warmup_steps) / (float)(total_steps - warmup_steps);
float min_lr = max_lr * min_lr_frac;
lr = min_lr + 0.5f * (1.0f + cosf(M_PI * decay_ratio)) * (max_lr - min_lr);
}
// Adam update
for (int L=0; L<NLAYERS; L++) {
LayerGrads *g = &grads[L];
adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
// Update transposed weight buffers
transpose_weight(Wqt_buf[L], lw[L].Wq, DIM, DIM);
transpose_weight(Wkt_buf[L], lw[L].Wk, DIM, DIM);
transpose_weight(Wvt_buf[L], lw[L].Wv, DIM, DIM);
transpose_weight(Wot_buf[L], lw[L].Wo, DIM, DIM);
transpose_weight(W1t_buf[L], lw[L].W1, HIDDEN, DIM);
transpose_weight(W2t_buf[L], lw[L].W2, DIM, HIDDEN);
transpose_weight(W3t_buf[L], lw[L].W3, HIDDEN, DIM);
}
adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
// Re-extract compact embed from updated full embed
free(cembed);
cembed = vocab_compact_embed(embed, &vm, DIM);
// Zero grads
for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
memset(grms_final, 0, DIM*4);
memset(gembed, 0, (size_t)VOCAB*DIM*4);
memset(gcembed, 0, (size_t)CV*DIM*4);
// Checkpoint
if ((step+1) % 100 == 0) {
double wall = tb_ms(mach_absolute_time() - t_wall_start);
save_checkpoint(CKPT_PATH, step+1, total_steps, lr, last_loss,
total_train_ms+cum_train, wall+cum_wall, total_steps_done+cum_steps, adam_t,
lw, la, rms_final, &arms_final, embed, &aembed);
}
}
}
// Report
double wall = tb_ms(mach_absolute_time() - t_wall_start);
printf("\n=== Efficiency Report ===\n");
printf("Total steps: %d\n", total_steps_done);
printf("Compile: %.0fms (one-time, %.1f%%)\n", compile_ms, 100*compile_ms/(wall+cum_wall));
printf("Train time: %.0fms (%.1fms/step)\n", total_train_ms, total_train_ms/total_steps_done);
printf("Wall time: %.1fs\n", (wall+cum_wall)/1000);
// Cleanup
for (int L=0; L<NLAYERS; L++) {
layer_weights_free(&lw[L]); layer_adam_free(&la[L]);
layer_acts_free(&acts[L]); layer_grads_free(&grads[L]);
free(Wqt_buf[L]); free(Wkt_buf[L]); free(Wvt_buf[L]); free(Wot_buf[L]);
free(W1t_buf[L]); free(W2t_buf[L]); free(W3t_buf[L]);
}
free_kern(dk.sdpaFwd); free_kern(dk.ffnW13); free_kern(dk.ffnW2);
free_kern(dk.ffnBwdW2t); free_kern(dk.ffnBwdW13t); free_kern(dk.wotBwd);
free_kern(dk.sdpaBwd1); free_kern(dk.sdpaBwd2); free_kern(dk.qkvBwd);
munmap(token_data, data_len); close(data_fd);
}
return 0;
}