stories110M: 12-layer ANE training with dashboard, 107ms/step

- Scale to full stories110M (109M params, 12 layers) with real TinyStories data
- vDSP-vectorized cross-entropy (110ms→14ms), NEON fp16 IO, async dW
- TUI dashboard: loss curve, ANE/CPU power, CPU/memory graphs, text generation
- Split into modular headers: config, io, mil, cpu_ops
maderix 2026-03-01 03:14:39 -08:00
parent f213c8db68
commit 4d67db1bdb
10 changed files with 2279 additions and 887 deletions


@ -3,10 +3,18 @@ CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc
 FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface
 LDFLAGS = $(FRAMEWORKS) -ldl

+HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
+
 train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
 	$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS)

-clean:
-	rm -f train
+train_large: train_large.m $(HEADERS_LARGE)
+	$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate

-.PHONY: clean
+tokenize:
+	python3 tokenize.py
+
+clean:
+	rm -f train train_large
+
+.PHONY: clean tokenize

training/README.md Normal file

@ -0,0 +1,69 @@
# ANE Training — Stories110M on Apple Neural Engine
Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly on Apple's Neural Engine using private ANE APIs.
![Dashboard](dashboard.gif)
## Architecture
- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256
- **109.53M params** (84.95M transformer + 24.58M embedding)
- **72 ANE kernels** per compile (60 weight-bearing, 12 weight-free sdpaBwd2)
- **6 kernel types per layer**: fwdAttn, fwdFFN, ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd
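The parameter and kernel counts above can be reproduced from the listed dimensions. The sketch below mirrors the `LAYER_PARAMS` and `TOTAL_PARAMS` macros in `stories_config.h`; the function names are illustrative and not part of the repo, and it assumes the classifier reuses the embedding matrix (as the dashboard's generation code does), so no separate output weights are counted.

```python
# Dimensions copied from the Architecture list.
DIM, HIDDEN, NLAYERS, VOCAB = 768, 2048, 12, 32000


def layer_params(dim=DIM, hidden=HIDDEN):
    """Parameters in one transformer layer."""
    attn = 4 * dim * dim       # Wq, Wk, Wv, Wo
    ffn = 3 * hidden * dim     # W1, W2, W3
    norms = 2 * dim            # two RMSNorm weight vectors
    return attn + ffn + norms


def param_counts():
    """Return (transformer, embedding, total) parameter counts."""
    transformer = NLAYERS * layer_params() + DIM   # + final RMSNorm
    embedding = VOCAB * DIM                        # assumed shared with the classifier
    return transformer, embedding, transformer + embedding


transformer, embedding, total = param_counts()
kernels_per_compile = 6 * NLAYERS   # six kernel types per layer
weight_bearing = 5 * NLAYERS        # all but sdpaBwd2

# prints: 109.53M params = 84.95M transformer + 24.58M embedding
print(f"{total / 1e6:.2f}M params = "
      f"{transformer / 1e6:.2f}M transformer + {embedding / 1e6:.2f}M embedding")
print(f"{kernels_per_compile} kernels per compile, {weight_bearing} weight-bearing")
```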
## Performance
| Component | Time (ms/step) |
|-----------|---------------|
| ANE eval | 9.6 |
| IO (fp16 conversion) | 4.1 |
| Classifier (cblas) | 9.1 |
| Cross-entropy + residuals | 14.4 |
| RMSNorm | 0.1 |
| **Total** | **107 ms/step** |
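The itemized rows account for only part of the total. A quick check, using the numbers in the table above and assuming one sequence of `seq=256` tokens per step (an assumption, not stated in the table), gives the unlisted remainder and a rough throughput. Attributing the remainder to any specific stage would be a guess; the training log also reports a `cblas_wait` field that this table does not list.

```python
# Per-step timings copied from the Performance table (milliseconds).
COMPONENTS_MS = {
    "ANE eval": 9.6,
    "IO (fp16 conversion)": 4.1,
    "Classifier (cblas)": 9.1,
    "Cross-entropy + residuals": 14.4,
    "RMSNorm": 0.1,
}
TOTAL_MS = 107.0
SEQ = 256  # tokens per step, assuming one sequence per step

listed_ms = sum(COMPONENTS_MS.values())
unlisted_ms = TOTAL_MS - listed_ms          # time not itemized in the table
tokens_per_s = SEQ / (TOTAL_MS / 1000.0)    # rough training throughput

print(f"listed:     {listed_ms:.1f} ms")
print(f"unlisted:   {unlisted_ms:.1f} ms")
print(f"throughput: {tokens_per_s:.0f} tokens/s")
```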
## Files
| File | Description |
|------|-------------|
| `train_large.m` | Main training loop — 12-layer forward/backward, checkpoint, exec() restart |
| `stories_config.h` | Model config, structs, alloc helpers |
| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval |
| `stories_mil.h` | MIL program generators for all 6 ANE kernel types |
| `stories_cpu_ops.h` | vDSP-vectorized RMSNorm, cross-entropy, Adam, embedding ops |
| `dashboard.py` | TUI dashboard — loss curve, power/CPU/memory graphs, text generation |
| `tokenize.py` | Extract pretokenized TinyStories data |
| `Makefile` | Build targets |
## How it works
1. **Forward pass**: Each layer runs fwdAttn (QKV + SDPA + Wo) and fwdFFN (W2 applied to SiLU(W1·x) ⊙ W3·x) on ANE via MIL-compiled kernels. The final RMSNorm and classifier matmul run on CPU (cblas).
2. **Backward pass**: Reverse layer order. ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd on ANE. Weight gradients (dW) via async cblas_sgemm on CPU. RMSNorm backward via vDSP.
3. **Compile budget**: The ANE allows roughly 119 compiles per process. Each batch compiles 72 kernels, so only one batch fits in a process: we run 10 accumulation steps, save a checkpoint, then restart via `exec()` and resume from the checkpoint.
4. **Data**: Real TinyStories text (20M tokens), mmap'd uint16 token IDs, random position sampling per step.
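The compile budget in step 3 is integer arithmetic. The sketch below uses the README's figures (about 119 compiles per process, 72 kernels per batch, 10 accumulation steps) and assumes every restarted process recompiles all 72 kernels; the helper names are hypothetical.

```python
import math

COMPILE_LIMIT = 119      # approximate per-process limit quoted in the README
KERNELS_PER_BATCH = 72   # 6 kernel types x 12 layers
ACCUM_STEPS = 10         # steps run on one set of compiled kernels


def batches_per_process(limit=COMPILE_LIMIT, per_batch=KERNELS_PER_BATCH):
    """How many full kernel sets fit under the compile limit."""
    return limit // per_batch


def processes_needed(total_steps, accum_steps=ACCUM_STEPS):
    """Process lifetimes required: the initial launch plus exec() restarts."""
    steps_per_process = batches_per_process() * accum_steps
    return math.ceil(total_steps / steps_per_process)


print(batches_per_process())     # 1: a second batch would need 144 compiles
print(processes_needed(10_000))  # 1000 process lifetimes for 10,000 steps
```

`stories_config.h` sets `MAX_COMPILES` to 100, which gives the same result: `batches_per_process(limit=100)` is also 1.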
## Usage
```bash
# Extract tokenized data
python3 tokenize.py
# Build and train
make train_large
./train_large # fresh start
./train_large --resume # resume from checkpoint
# Monitor with dashboard
pip install blessed psutil numpy
python3 dashboard.py --resume # needs sudo for powermetrics
```
## Key techniques
- **NEON vectorized fp16<->fp32**: ARM NEON intrinsics for fast IOSurface data transfer
- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve`, about 8x faster than scalar (110 ms → 14 ms)
- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask, so we decompose attention into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
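The decomposition in the last bullet can be sketched in plain Python for a single head. This is a CPU-only illustration of the three stages, not the repo's MIL or conv formulation; `matmul` and `causal_attention` are hypothetical helper names.

```python
import math


def matmul(a, b):
    """Multiply nested-list matrices a [n][m] and b [m][p]."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def causal_attention(q, k, v):
    """Single-head causal attention; q, k, v are [seq][head_dim] nested lists."""
    hd = len(q[0])
    # Stage 1 (an ANE conv in the README): raw scores Q @ K^T, scaled.
    kt = [list(col) for col in zip(*k)]
    scores = [[s / math.sqrt(hd) for s in row] for row in matmul(q, kt)]
    # Stage 2 (CPU in the README): causal mask, then row-wise softmax.
    probs = []
    for t, row in enumerate(scores):
        visible = row[: t + 1]                  # position t sees positions 0..t
        m = max(visible)                        # subtract max for stability
        e = [math.exp(s - m) for s in visible]
        z = sum(e)
        probs.append([x / z for x in e] + [0.0] * (len(row) - t - 1))
    # Stage 3 (an ANE conv in the README): weighted sum of values.
    return matmul(probs, v)


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[0.5, 0.5], [1.0, 0.0], [0.0, 2.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out[0])  # the first position can only attend to itself
```

Because masked positions get probability zero before stage 3, changing a later token's key or value leaves earlier outputs unchanged, which is the property the hardware `attn_mask` would otherwise provide.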

training/dashboard.gif Normal file

Binary file not shown (232 KiB).

training/dashboard.py Normal file

@ -0,0 +1,882 @@
"""TUI dashboard for ANE training (train_large). Uses blessed for terminal UI."""
import argparse, fcntl, math, os, re, select, signal, struct, subprocess, sys, time, threading
from collections import deque
from pathlib import Path
import numpy as np
try:
from blessed import Terminal
except ImportError:
print('pip install blessed')
sys.exit(1)
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12
HD = DIM // HEADS
CKPT_PATH = 'ane_stories110M_ckpt.bin'
TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent.parent / 'assets' / 'models' / 'tokenizer.bin')
class State:
def __init__(self):
self.model_config = {}
self.params = {}
self.kernels = {}
self.training = {}
self.flops = {}
self.step = 0
self.total_steps = 0
self.loss = 0.0
self.best_loss = float('inf')
self.loss_history = []
self.ms_per_step = 0.0
self.compile_pct = 0.0
self.compiles = 0
self.component_timing = {}
self.power = {'ane': 0.0, 'cpu': 0.0, 'gpu': 0.0}
self.power_history_ane = deque(maxlen=300)
self.power_history_cpu = deque(maxlen=300)
self.logs = deque(maxlen=2000)
self.log_scroll = 0
self.auto_scroll = True
self.batch_num = 0
self.efficiency = {}
self.gen_text = ''
self.gen_step = 0
self.gen_status = 'idle'
self.gen_lock = threading.Lock()
self.cpu_pct_history = deque(maxlen=300)
self.mem_mb_history = deque(maxlen=300)
self.proc_mem_mb_history = deque(maxlen=300)
self.train_pid = None
S = State()
class Tokenizer:
    def __init__(self, path):
        self.vocab = []
        self.scores = []
        with open(path, 'rb') as f:
            max_len = struct.unpack('i', f.read(4))[0]  # max token length (read to advance the file, otherwise unused)
            for _ in range(VOCAB):
                score = struct.unpack('f', f.read(4))[0]
                slen = struct.unpack('i', f.read(4))[0]
                tok = f.read(slen).decode('utf-8', errors='replace')
                self.vocab.append(tok)
                self.scores.append(score)

    def decode(self, token_id):
        if 0 <= token_id < len(self.vocab):
            s = self.vocab[token_id]
            # Byte-fallback tokens look like '<0x0A>'; map them to the character they encode.
            if s.startswith('<0x') and s.endswith('>'):
                try:
                    return chr(int(s[3:-1], 16))
                except ValueError:
                    return s
            return s
        return ''
_tokenizer = None
def get_tokenizer():
global _tokenizer
if _tokenizer is None:
try:
_tokenizer = Tokenizer(TOKENIZER_PATH)
except Exception as e:
S.logs.append(f'[gen] tokenizer load failed: {e}')
return None
return _tokenizer
def load_weights_from_ckpt(path):
try:
with open(path, 'rb') as f:
# CkptHdr: 96 bytes (verified with sizeof)
hdr = f.read(96)
if len(hdr) < 96:
return None
wq_sz = DIM * DIM
wo_sz = DIM * DIM
w1_sz = HIDDEN * DIM
w2_sz = DIM * HIDDEN
w3_sz = HIDDEN * DIM
# Per-layer: weights + adam state (m,v for each)
adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 +
w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2)
W = {}
for L in range(NLAYERS):
W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy()
W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy()
W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy()
W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
# Skip adam state for this layer
f.seek(adam_per_layer * 4, 1)
W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy()
f.seek(DIM * 2 * 4, 1) # skip rms_final adam
W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy()
return W
except Exception as e:
S.logs.append(f'[gen] ckpt load failed: {e}')
return None
def rmsnorm(x, w):
    ss = np.mean(x * x) + 1e-5
    return x * (1.0 / math.sqrt(ss)) * w
def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / np.sum(e)
def generate_text(W, tok, max_tokens=64, temperature=0.8):
    tokenizer = tok if tok is not None else get_tokenizer()
    if tokenizer is None:
        return '[no tokenizer]'
    tokens = [1]  # start from token id 1 (BOS in the Llama 2 tokenizer)
    text_parts = []
    # Precompute RoPE angles: freqs[pos, i] = pos / 10000^(2i/HD)
    freqs = np.zeros((SEQ, HD // 2), dtype=np.float32)
    for pos in range(SEQ):
        for i in range(HD // 2):
            freq = 1.0 / (10000.0 ** (2.0 * i / HD))
            freqs[pos, i] = pos * freq
    for step in range(max_tokens):
        seq_len = len(tokens)
        if seq_len > SEQ:
            break
        # Only the most recent token is fed through the model (no KV cache).
        x = W['embed'][tokens[-1]].copy()
        for L in range(NLAYERS):
            # RMSNorm + QKV
            xn = rmsnorm(x, W[f'rms1_{L}'])
            q = W[f'Wq{L}'] @ xn
            k = W[f'Wk{L}'] @ xn
            v = W[f'Wv{L}'] @ xn
            # RoPE
            pos = seq_len - 1
            for h in range(HEADS):
                for i in range(HD // 2):
                    freq = freqs[pos, i]
                    cos_v, sin_v = math.cos(freq), math.sin(freq)
                    qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1]
                    q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v
                    q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v
                    ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1]
                    k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v
                    k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v
            # Attention (single token): with one key per head, softmax over the
            # single score is 1, so each head's output is just its value vector.
            o = np.zeros(DIM, dtype=np.float32)
            for h in range(HEADS):
                qh = q[h * HD:(h + 1) * HD]
                kh = k[h * HD:(h + 1) * HD]
                vh = v[h * HD:(h + 1) * HD]
                score = np.dot(qh, kh) / math.sqrt(HD)  # computed but unused (see above)
                o[h * HD:(h + 1) * HD] = vh
            # Residual + output projection
            x2 = x + W[f'Wo{L}'] @ o
            # FFN
            x2n = rmsnorm(x2, W[f'rms2_{L}'])
            h1 = W[f'W1_{L}'] @ x2n
            h3 = W[f'W3_{L}'] @ x2n
            # SiLU(h1) gated by h3
            h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
            ffn_out = W[f'W2_{L}'] @ h1
            x = x2 + ffn_out
        x = rmsnorm(x, W['rms_final'])
        # Logits (classifier shares the embedding matrix)
        logits = W['embed'] @ x
        if temperature < 0.01:
            next_tok = int(np.argmax(logits))
        else:
            logits = logits / temperature
            probs = softmax(logits)
            next_tok = int(np.random.choice(VOCAB, p=probs))
        if next_tok == 2:  # stop on token id 2 (EOS)
            break
        tokens.append(next_tok)
        piece = tokenizer.decode(next_tok)
        text_parts.append(piece)
    return ''.join(text_parts)
def generation_thread():
    last_gen_step = -1
    while True:
        time.sleep(5)
        # Regenerate at most once every 100 training steps.
        if S.step <= last_gen_step + 99:
            continue
        if not os.path.exists(CKPT_PATH):
            continue
        with S.gen_lock:
            S.gen_status = 'generating'
            S.gen_step = S.step
        try:
            W = load_weights_from_ckpt(CKPT_PATH)
            if W is None:
                with S.gen_lock:
                    S.gen_status = 'idle'
                continue
            text = generate_text(W, get_tokenizer(), max_tokens=64, temperature=0.8)
            with S.gen_lock:
                S.gen_text = text
                S.gen_step = S.step
                S.gen_status = 'done'
        except Exception as e:
            with S.gen_lock:
                S.gen_text = f'[error: {e}]'
                S.gen_status = 'done'
        last_gen_step = S.step
def sysmetrics_thread():
    while True:
        time.sleep(1)
        if not HAS_PSUTIL:
            continue
        S.cpu_pct_history.append(psutil.cpu_percent(interval=None))
        mem = psutil.virtual_memory()
        S.mem_mb_history.append(mem.used / (1024 * 1024))
        pid = S.train_pid
        if pid:
            try:
                p = psutil.Process(pid)
                S.proc_mem_mb_history.append(p.memory_info().rss / (1024 * 1024))
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                pass
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
RE_ACCUM = re.compile(r'Accum (\d+).*LR=([\d.e+-]+)')
RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)')
RE_BATCH = re.compile(r'\[batch (\d+): compile=([\d.]+)ms train=([\d.]+)ms \(([\d.]+)ms/step\) compiles=(\d+)\]')
RE_TIMING = re.compile(r'ane=([\d.]+) io=([\d.]+) cls=([\d.]+) elem=([\d.]+) rms=([\d.]+) cblas_wait=([\d.]+)')
RE_RESTART = re.compile(r'\[exec\(\) restart step (\d+)')
RE_RESUME = re.compile(r'\[RESUMED step (\d+), loss=([\d.]+)\]')
RE_FLOPS = re.compile(r'FLOPs/step: fwd=([\d.]+)M bwd_dx=([\d.]+)M bwd_dW=([\d.]+)M sdpa_bwd=([\d.]+)M total=([\d.]+)M')
RE_ANE_FLOPS = re.compile(r'ANE FLOPs/step: ([\d.]+)M')
RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
def parse_line(line):
S.logs.append(line)
m = RE_CONFIG.search(line)
if m:
S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups())))
return
m = RE_PARAMS.search(line)
if m:
S.params = {'total': float(m[1]), 'transformer': float(m[2]), 'embed': float(m[3])}
return
m = RE_KERNELS.search(line)
if m:
S.kernels = {'total': int(m[1]), 'weight_bearing': int(m[2])}
return
m = RE_ACCUM.search(line)
if m:
S.training = {'accum': int(m[1]), 'lr': m[2]}
return
m = RE_FLOPS.search(line)
if m:
S.flops.update(fwd=float(m[1]), bwd_dx=float(m[2]), bwd_dw=float(m[3]),
sdpa_bwd=float(m[4]), total=float(m[5]))
return
m = RE_ANE_FLOPS.search(line)
if m:
S.flops['ane'] = float(m[1])
return
m = RE_STEP.search(line)
if m:
S.step, S.loss = int(m[1]), float(m[2])
S.loss_history.append((S.step, S.loss))
S.best_loss = min(S.best_loss, S.loss)
return
m = RE_BATCH.search(line)
if m:
S.batch_num = int(m[1])
compile_ms, train_ms = float(m[2]), float(m[3])
S.ms_per_step = float(m[4])
S.compiles = int(m[5])
S.compile_pct = 100 * compile_ms / (compile_ms + train_ms) if compile_ms + train_ms > 0 else 0
return
m = RE_TIMING.search(line)
if m:
S.component_timing = dict(zip(['ane', 'io', 'cls', 'elem', 'rms', 'cblas_wait'], map(float, m.groups())))
return
m = RE_ANE_TFLOPS.search(line)
if m:
S.flops['ane_tflops'] = float(m[1])
return
m = RE_ANE_UTIL.search(line)
if m:
S.flops['ane_util'] = float(m[1])
return
m = RE_EFFICIENCY.search(line)
if m:
S.efficiency[m[1].strip()] = m[2].strip()
return
def parse_powermetrics_text(text):
now = time.monotonic()
m = RE_ANE_POWER.search(text)
if m:
S.power['ane'] = float(m[1]) / 1000.0
S.power_history_ane.append((now, S.power['ane']))
m = RE_CPU_POWER.search(text)
if m:
S.power['cpu'] = float(m[1]) / 1000.0
S.power_history_cpu.append((now, S.power['cpu']))
m = RE_GPU_POWER.search(text)
if m:
S.power['gpu'] = float(m[1]) / 1000.0
BRAILLE_BASE = 0x2800
BRAILLE_MAP = [
[1, 8],
[2, 16],
[4, 32],
[64, 128],
]
def braille_chart(values, width, height, label_fmt='{:.1f}', y_range=None):
if not values or width < 8 or height < 2:
return ['(no data)'] * max(1, height)
chart_w = width - 6
if chart_w < 2:
return ['(no data)'] * max(1, height)
points_x = chart_w * 2
points_y = height * 4
data = values[-points_x:] if len(values) > points_x else values
lo, hi = min(data), max(data)
if y_range:
lo, hi = y_range
if hi - lo < 0.001:
lo, hi = lo - 0.5, hi + 0.5
margin = (hi - lo) * 0.05
lo -= margin
hi += margin
grid = [[0] * chart_w for _ in range(height)]
def plot(px, py):
px = max(0, min(points_x - 1, px))
py = max(0, min(points_y - 1, py))
grid[py // 4][px // 2] |= BRAILLE_MAP[py % 4][px % 2]
def val_to_y(v):
return int((1 - (v - lo) / (hi - lo)) * (points_y - 1))
for i in range(len(data)):
if i >= points_x:
break
y0 = val_to_y(data[i])
plot(i, y0)
if i > 0:
y_prev = val_to_y(data[i - 1])
y_lo, y_hi = min(y_prev, y0), max(y_prev, y0)
for yy in range(y_lo, y_hi + 1):
if y_hi != y_lo:
t = (yy - y_prev) / (y0 - y_prev)
xx = int(i - 1 + t)
else:
xx = i
plot(xx, yy)
lines = []
for r in range(height):
if r == 0:
label = label_fmt.format(hi)[:5].rjust(5)
elif r == height - 1:
label = label_fmt.format(lo)[:5].rjust(5)
elif r == height // 2:
label = label_fmt.format((hi + lo) / 2)[:5].rjust(5)
else:
label = ' '
row_str = ''.join(chr(BRAILLE_BASE | grid[r][c]) for c in range(chart_w))
lines.append(f'{label}\u2502{row_str}')
lines.append(' \u2514' + '\u2500' * chart_w)
return lines
def draw(term):
w, h = term.width, term.height
if w < 40 or h < 15:
print(term.home + term.clear + 'Terminal too small', end='', flush=True)
return
buf = []
def put(y, x, text, style=''):
if 0 <= y < h and x < w:
text = text[:w - x]
if style:
buf.append(term.move(y, x) + style + text + term.normal)
return
buf.append(term.move(y, x) + text)
buf.append(term.home + term.clear)
mid_x = w // 2
right_w = w - mid_x - 1
left_w = mid_x - 1
row = 0
# Model Config header
hdr = '\u2500 Model Config '
put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan)
row += 1
cfg = S.model_config
if cfg:
line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
put(row, 0, '\u2502', term.cyan)
put(row, 2, line1)
put(row, w - 1, '\u2502', term.cyan)
row += 1
p, k, t = S.params, S.kernels, S.training
line2 = f"{p.get('total', '?')}M params ({p.get('transformer', '?')}M xfmr + {p.get('embed', '?')}M embed)"
put(row, 0, '\u2502', term.cyan)
put(row, 2, line2)
put(row, w - 1, '\u2502', term.cyan)
row += 1
line3 = f"{k.get('total', '?')} kernels ({k.get('weight_bearing', '?')} wt-bearing) | Accum {t.get('accum', '?')} | Adam LR={t.get('lr', '?')}"
put(row, 0, '\u2502', term.cyan)
put(row, 2, line3)
put(row, w - 1, '\u2502', term.cyan)
row += 1
else:
put(row, 0, '\u2502', term.cyan)
put(row, 2, 'Waiting for model config...')
put(row, w - 1, '\u2502', term.cyan)
row += 1
remaining = h - row - 1
# Allocate: loss curve ~40%, logs ~30%, power/cpu/mem/gen share rest
power_h = max(3, remaining // 8)
gen_h = max(2, remaining // 10)
extra_panels = power_h + power_h + gen_h + 6 # power + cpu/mem + gen + dividers
log_h_min = max(5, remaining // 5)
curve_h = max(5, remaining - extra_panels - log_h_min)
# Loss Curve + Training Stats divider
put(row, 0, '\u251c\u2500 Loss Curve ' + '\u2500' * max(0, left_w - 13) + '\u252c\u2500 Training Stats ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan)
row += 1
# Loss curve
loss_vals = [l for _, l in S.loss_history]
curve_lines = braille_chart(loss_vals, left_w - 1, curve_h)
for i, cl in enumerate(curve_lines):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, 1, cl, term.green)
put(row + i, mid_x, '\u2502', term.cyan)
put(row + i, w - 1, '\u2502', term.cyan)
# Training stats (right panel)
sr = row
step_str = f'{S.step}' + (f'/{S.total_steps}' if S.total_steps and S.total_steps < 999999 else '')
put(sr, mid_x + 1, f' Step: {step_str} Loss: {S.loss:.4f}' if S.loss else ' Step: --', term.yellow)
sr += 1
put(sr, mid_x + 1, f' Best: {S.best_loss:.4f} ms/step: {S.ms_per_step:.1f}' if S.best_loss < float('inf') else ' Best: --')
sr += 1
ane_tflops = S.flops.get('ane_tflops', 0)
ane_util = S.flops.get('ane_util', 0)
if ane_tflops:
put(sr, mid_x + 1, f' ANE: {ane_tflops:.2f}T Compile: {S.compile_pct:.0f}% Util: {ane_util:.1f}%')
else:
put(sr, mid_x + 1, f' Compile: {S.compile_pct:.0f}%')
sr += 1
ct = S.component_timing
if ct:
put(sr, mid_x + 1, f' ane={ct.get("ane", 0):.1f} io={ct.get("io", 0):.1f} cls={ct.get("cls", 0):.1f} elem={ct.get("elem", 0):.1f}')
sr += 1
put(sr, mid_x + 1, f' rms={ct.get("rms", 0):.1f} cblas_wait={ct.get("cblas_wait", 0):.1f} ms/step')
sr += 1
pw = S.power
if any(pw.values()):
put(sr, mid_x + 1, '\u2500 Power ' + '\u2500' * max(0, right_w - 9), term.cyan)
sr += 1
put(sr, mid_x + 1, f' ANE: {pw["ane"]:.1f}W CPU: {pw["cpu"]:.1f}W GPU: {pw["gpu"]:.1f}W', term.magenta)
sr += 1
if S.batch_num:
put(sr, mid_x + 1, f' Batch: {S.batch_num} Compiles: {S.compiles}')
sr += 1
# Fill vertical borders between loss curve and stats
top_end = row + len(curve_lines)
for r in range(row, max(top_end, sr)):
if r >= top_end:
put(r, 0, '\u2502', term.cyan)
if r >= sr:
put(r, mid_x, '\u2502', term.cyan)
put(r, w - 1, '\u2502', term.cyan)
row = max(top_end, sr)
# Power charts
has_power = len(S.power_history_ane) > 1 or len(S.power_history_cpu) > 1
if has_power:
put(row, 0, '\u251c\u2500 ANE Power (W) ' + '\u2500' * max(0, left_w - 16) + '\u252c\u2500 CPU Power (W) ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan)
row += 1
ane_vals = [v for _, v in S.power_history_ane]
cpu_vals = [v for _, v in S.power_history_cpu]
ane_lines = braille_chart(ane_vals, left_w - 1, power_h, label_fmt='{:.1f}')
cpu_lines = braille_chart(cpu_vals, right_w - 1, power_h, label_fmt='{:.1f}')
max_lines = max(len(ane_lines), len(cpu_lines))
while len(ane_lines) < max_lines:
ane_lines.append(' ' * (left_w - 1))
while len(cpu_lines) < max_lines:
cpu_lines.append(' ' * (right_w - 1))
for i in range(max_lines):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, 1, ane_lines[i], term.red)
put(row + i, mid_x, '\u2502', term.cyan)
put(row + i, mid_x + 1, cpu_lines[i], term.blue)
put(row + i, w - 1, '\u2502', term.cyan)
row += max_lines
# CPU / Memory charts
has_sysmetrics = len(S.cpu_pct_history) > 0
if has_sysmetrics:
put(row, 0, '\u251c\u2500 CPU % ' + '\u2500' * max(0, left_w - 8) + '\u252c\u2500 Memory (MB) ' + '\u2500' * max(0, right_w - 15) + '\u2524', term.cyan)
row += 1
cpu_vals = list(S.cpu_pct_history)
mem_vals = list(S.proc_mem_mb_history) if S.proc_mem_mb_history else list(S.mem_mb_history)
mem_label = 'proc' if S.proc_mem_mb_history else 'sys'
cpu_lines = braille_chart(cpu_vals, left_w - 1, power_h, label_fmt='{:.0f}', y_range=(0, 100))
mem_lines = braille_chart(mem_vals, right_w - 1, power_h, label_fmt='{:.0f}')
max_lines = max(len(cpu_lines), len(mem_lines))
while len(cpu_lines) < max_lines:
cpu_lines.append(' ' * (left_w - 1))
while len(mem_lines) < max_lines:
mem_lines.append(' ' * (right_w - 1))
for i in range(max_lines):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, 1, cpu_lines[i], term.yellow)
put(row + i, mid_x, '\u2502', term.cyan)
put(row + i, mid_x + 1, mem_lines[i], term.magenta)
put(row + i, w - 1, '\u2502', term.cyan)
row += max_lines
# Generated text
with S.gen_lock:
gen_text = S.gen_text
gen_step = S.gen_step
gen_status = S.gen_status
if gen_text or gen_status == 'generating':
status_tag = ' (generating...)' if gen_status == 'generating' else f' (step {gen_step})'
put(row, 0, '\u251c\u2500 Generated Text' + status_tag + ' ' + '\u2500' * max(0, w - 20 - len(status_tag)) + '\u2524', term.cyan)
row += 1
if gen_text:
line_w = w - 3
text = gen_text.replace('\n', ' ')
wrapped = [text[i:i + line_w] for i in range(0, len(text), line_w)]
for i, tl in enumerate(wrapped[:gen_h]):
put(row, 0, '\u2502', term.cyan)
put(row, 2, tl, term.white)
put(row, w - 1, '\u2502', term.cyan)
row += 1
else:
put(row, 0, '\u2502', term.cyan)
put(row, 2, '...')
put(row, w - 1, '\u2502', term.cyan)
row += 1
# Logs
log_h = h - row - 1
scroll_hint = ' (scroll) ' if not S.auto_scroll else ' '
put(row, 0, '\u251c\u2500 Logs' + scroll_hint + '\u2500' * max(0, w - 8 - len(scroll_hint)) + '\u2524', term.cyan)
row += 1
logs = list(S.logs)
if log_h > 0 and logs:
if S.auto_scroll:
start = max(0, len(logs) - log_h)
else:
start = max(0, min(S.log_scroll, len(logs) - log_h))
visible = logs[start:start + log_h]
for i, line in enumerate(visible):
put(row + i, 0, '\u2502', term.cyan)
if RE_STEP.search(line):
put(row + i, 1, line[:w - 2], term.yellow)
elif line.strip().startswith('[batch'):
put(row + i, 1, line[:w - 2], term.blue)
elif 'FAIL' in line or 'error' in line.lower():
put(row + i, 1, line[:w - 2], term.red)
else:
put(row + i, 1, line[:w - 2])
put(row + i, w - 1, '\u2502', term.cyan)
for i in range(len(visible), log_h):
put(row + i, 0, '\u2502', term.cyan)
put(row + i, w - 1, '\u2502', term.cyan)
# Bottom border
put(h - 1, 0, '\u2514' + '\u2500' * (w - 2) + '\u2518', term.cyan)
sys.stdout.write(''.join(buf))
sys.stdout.flush()
def set_nonblock(fd):
fl = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
def spawn_training(resume=False, steps=10000):
cmd = 'make train_large 2>&1 && ./train_large'
if resume:
cmd += ' --resume'
cmd += f' --steps {steps}'
proc = subprocess.Popen(
['bash', '-c', cmd],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
cwd=os.path.dirname(os.path.abspath(__file__)) or '.')
set_nonblock(proc.stdout.fileno())
return proc
def spawn_powermetrics():
try:
proc = subprocess.Popen(
['sudo', 'powermetrics', '--samplers', 'cpu_power,gpu_power,ane_power', '-i', '1000'],
stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
set_nonblock(proc.stdout.fileno())
return proc
except (FileNotFoundError, PermissionError):
return None
def main():
parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
parser.add_argument('--infinite', action='store_true', help='Train indefinitely')
parser.add_argument('--no-powermetrics', action='store_true')
parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
args = parser.parse_args()
if args.infinite:
args.steps = 999999999
S.total_steps = args.steps
term = Terminal()
procs = []
train_proc = spawn_training(resume=args.resume, steps=args.steps)
S.train_pid = train_proc.pid
procs.append(train_proc)
if HAS_PSUTIL:
psutil.cpu_percent(interval=None) # prime the counter
sys_t = threading.Thread(target=sysmetrics_thread, daemon=True)
sys_t.start()
pm_proc = None
if not args.no_powermetrics:
pm_proc = spawn_powermetrics()
if pm_proc:
procs.append(pm_proc)
if not args.no_generate:
gen_t = threading.Thread(target=generation_thread, daemon=True)
gen_t.start()
pm_buf = ''
train_buf = ''
def cleanup():
for p in procs:
try:
p.terminate()
except Exception:
pass
signal.signal(signal.SIGINT, lambda *a: cleanup())
signal.signal(signal.SIGTERM, lambda *a: cleanup())
resized = [False]
def on_resize(*a):
resized[0] = True
signal.signal(signal.SIGWINCH, on_resize)
with term.fullscreen(), term.cbreak(), term.hidden_cursor():
draw(term)
last_draw = time.monotonic()
while True:
fds = []
fd_map = {}
if train_proc and train_proc.stdout:
fd = train_proc.stdout.fileno()
fds.append(fd)
fd_map[fd] = 'train'
if pm_proc and pm_proc.stdout:
fd = pm_proc.stdout.fileno()
fds.append(fd)
fd_map[fd] = 'pm'
fds.append(sys.stdin.fileno())
fd_map[sys.stdin.fileno()] = 'stdin'
try:
readable, _, _ = select.select(fds, [], [], 0.25)
except (ValueError, OSError):
continue
need_draw = resized[0]
resized[0] = False
train_finished = False
for fd in readable:
kind = fd_map.get(fd)
if kind == 'train':
try:
data = os.read(fd, 65536)
except BlockingIOError:
continue
except (OSError, ValueError):
data = b''
if not data:
if train_proc.poll() is not None:
try:
rest = train_proc.stdout.read()
if rest:
for line in rest.decode('utf-8', errors='replace').split('\n'):
if line:
parse_line(line)
except Exception:
pass
S.logs.append('[dashboard] Training finished. Press q to exit.')
train_finished = True
continue
train_buf += data.decode('utf-8', errors='replace')
while '\n' in train_buf:
line, train_buf = train_buf.split('\n', 1)
parse_line(line)
need_draw = True
elif kind == 'pm':
try:
data = os.read(fd, 65536).decode('utf-8', errors='replace')
except BlockingIOError:
continue
except (OSError, ValueError):
data = ''
if not data:
continue
pm_buf += data
while '\n\n' in pm_buf or '*** ' in pm_buf:
end = pm_buf.find('\n*** ', 1)
if end < 0:
end = pm_buf.find('\n\n', 1)
if end < 0:
break
chunk = pm_buf[:end]
pm_buf = pm_buf[end:]
parse_powermetrics_text(chunk)
if len(pm_buf) > 16384:
pm_buf = pm_buf[-8192:]
need_draw = True
elif kind == 'stdin':
key = term.inkey(timeout=0)
if not key:
continue
if key == 'q':
cleanup()
return
elif key.name == 'KEY_UP':
S.auto_scroll = False
S.log_scroll = max(0, S.log_scroll - 1)
need_draw = True
elif key.name == 'KEY_DOWN':
S.log_scroll += 1
need_draw = True
elif key == 'p':
S.auto_scroll = not S.auto_scroll
if S.auto_scroll:
S.log_scroll = max(0, len(S.logs) - 10)
need_draw = True
elif key == 'r':
if train_proc:
train_proc.terminate()
train_proc.wait()
train_proc = spawn_training(resume=True, steps=args.steps)
S.train_pid = train_proc.pid
procs = [p for p in procs if p.poll() is None]
procs.append(train_proc)
S.logs.append('[dashboard] Restarted with --resume')
need_draw = True
elif key == 'g':
with S.gen_lock:
S.gen_status = 'generating'
S.gen_step = S.step
def force_gen():
try:
W = load_weights_from_ckpt(CKPT_PATH)
if W:
text = generate_text(W, get_tokenizer(), max_tokens=64, temperature=0.8)
with S.gen_lock:
S.gen_text = text
S.gen_step = S.step
S.gen_status = 'done'
except Exception as e:
with S.gen_lock:
S.gen_text = f'[error: {e}]'
S.gen_status = 'done'
threading.Thread(target=force_gen, daemon=True).start()
need_draw = True
now = time.monotonic()
if not need_draw and now - last_draw > 1.0:
need_draw = True
if need_draw and now - last_draw > 0.066:
draw(term)
last_draw = now
if train_finished:
draw(term)
while True:
key = term.inkey(timeout=1)
if key == 'q':
cleanup()
return
if __name__ == '__main__':
main()

training/stories_config.h Normal file

@ -0,0 +1,189 @@
// stories_config.h — Stories110M model config and structures
#pragma once
#import <Foundation/Foundation.h>
#import <objc/runtime.h>
#import <objc/message.h>
#import <dlfcn.h>
#import <IOSurface/IOSurface.h>
#import <mach/mach_time.h>
#import <Accelerate/Accelerate.h>
#include <math.h>
#include <unistd.h>
#include <dispatch/dispatch.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
// Stories110M config
#define DIM 768
#define HIDDEN 2048
#define HEADS 12
#define HD (DIM/HEADS)
#define SEQ 256
#define NLAYERS 12
#define VOCAB 32000
#define ACCUM_STEPS 10
#define MAX_COMPILES 100
// Weight-bearing kernels per layer: fwdAttn, fwdFFN, ffnBwd, sdpaBwd1, qkvBwd = 5
// 5 * 12 = 60 weight-bearing compiles per batch
// sdpaBwd2 has no weights, so it is compiled once per layer (12 weight-free kernels)
// With MAX_COMPILES=100, we get 1 batch of ACCUM_STEPS before restart
#define KERNELS_PER_LAYER 5
#define TOTAL_WEIGHT_KERNELS (KERNELS_PER_LAYER * NLAYERS)
// Attention score channels for SDPA backward
#define SCORE_CH (HEADS*SEQ)
// Weight sizes per layer
#define WQ_SZ (DIM*DIM)
#define WO_SZ (DIM*DIM)
#define W1_SZ (HIDDEN*DIM)
#define W2_SZ (DIM*HIDDEN)
#define W3_SZ (HIDDEN*DIM)
#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
#define TOTAL_PARAMS (NLAYERS * LAYER_PARAMS + DIM + VOCAB*DIM) // +rms_final+embed
// Per-layer weight and optimizer state
typedef struct {
float *Wq, *Wk, *Wv, *Wo;
float *W1, *W2, *W3;
float *rms_att, *rms_ffn;
} LayerWeights;
typedef struct {
float *m, *v;
size_t n;
} AdamState;
typedef struct {
AdamState Wq, Wk, Wv, Wo;
AdamState W1, W2, W3;
AdamState rms_att, rms_ffn;
} LayerAdam;
// Per-layer activation buffers (saved for backward)
typedef struct {
float *layer_in; // [DIM, SEQ] input to this layer (for rmsnorm1 bwd)
float *xnorm; // [DIM, SEQ] rmsnorm1 output
float *Q, *K, *V; // [DIM, SEQ] QKV projections
float *attn_out; // [DIM, SEQ] attention output (before Wo)
float *o_out; // [DIM, SEQ] Wo output
float *x2; // [DIM, SEQ] residual after attn
float *x2norm; // [DIM, SEQ] rmsnorm2 output
float *h1, *h3; // [HIDDEN, SEQ] FFN intermediates
float *silu_out; // [HIDDEN, SEQ] SiLU(h1)*h3
float *ffn_out; // [DIM, SEQ] FFN output
} LayerActs;
// Per-layer gradient accumulators
typedef struct {
float *Wq, *Wk, *Wv, *Wo;
float *W1, *W2, *W3;
float *rms_att, *rms_ffn;
} LayerGrads;
// ANE kernels per layer
typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern;
typedef struct {
Kern *fwdAttn, *fwdFFN, *ffnBwd, *sdpaBwd1, *sdpaBwd2, *qkvBwd;
} LayerKernels;
// Checkpoint header
typedef struct {
int magic; // 0x424C5A54 "BLZT"
int version; // 2
int step, total_steps;
int n_layers, vocab_size, dim, hidden_dim, n_heads, seq_len;
float lr, loss;
double cum_compile, cum_train, cum_wall;
int cum_steps, cum_batches;
int adam_t;
int pad[3]; // alignment
} CkptHdr;
// llama2.c model file header
typedef struct {
int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len;
} Llama2Config;
// Globals
static Class g_D, g_I, g_AR, g_AIO;
static mach_timebase_info_data_t g_tb;
static int g_compile_count = 0;
static void ane_init(void) {
dlopen("/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", RTLD_NOW);
g_D = NSClassFromString(@"_ANEInMemoryModelDescriptor");
g_I = NSClassFromString(@"_ANEInMemoryModel");
g_AR = NSClassFromString(@"_ANERequest");
g_AIO= NSClassFromString(@"_ANEIOSurfaceObject");
}
static double tb_ms(uint64_t t) { return (double)t * g_tb.numer / g_tb.denom / 1e6; }
// Alloc helpers
static AdamState adam_alloc(size_t n) { AdamState s; s.m=(float*)calloc(n,4); s.v=(float*)calloc(n,4); s.n=n; return s; }
static void adam_free(AdamState *s) { free(s->m); free(s->v); }
static LayerWeights layer_weights_alloc(void) {
LayerWeights w;
w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4);
w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4);
w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4);
return w;
}
static void layer_weights_free(LayerWeights *w) {
free(w->Wq);free(w->Wk);free(w->Wv);free(w->Wo);
free(w->W1);free(w->W2);free(w->W3);
free(w->rms_att);free(w->rms_ffn);
}
static LayerAdam layer_adam_alloc(void) {
LayerAdam a;
a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WQ_SZ); a.Wv=adam_alloc(WQ_SZ); a.Wo=adam_alloc(WO_SZ);
a.W1=adam_alloc(W1_SZ); a.W2=adam_alloc(W2_SZ); a.W3=adam_alloc(W3_SZ);
a.rms_att=adam_alloc(DIM); a.rms_ffn=adam_alloc(DIM);
return a;
}
static void layer_adam_free(LayerAdam *a) {
adam_free(&a->Wq);adam_free(&a->Wk);adam_free(&a->Wv);adam_free(&a->Wo);
adam_free(&a->W1);adam_free(&a->W2);adam_free(&a->W3);
adam_free(&a->rms_att);adam_free(&a->rms_ffn);
}
static LayerActs layer_acts_alloc(void) {
LayerActs a;
a.layer_in=(float*)malloc(SEQ*DIM*4);
a.xnorm=(float*)malloc(SEQ*DIM*4); a.Q=(float*)malloc(SEQ*DIM*4);
a.K=(float*)malloc(SEQ*DIM*4); a.V=(float*)malloc(SEQ*DIM*4);
a.attn_out=(float*)malloc(SEQ*DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4);
a.x2=(float*)malloc(SEQ*DIM*4); a.x2norm=(float*)malloc(SEQ*DIM*4);
a.h1=(float*)malloc(SEQ*HIDDEN*4); a.h3=(float*)malloc(SEQ*HIDDEN*4);
a.silu_out=(float*)malloc(SEQ*HIDDEN*4); a.ffn_out=(float*)malloc(SEQ*DIM*4);
return a;
}
static void layer_acts_free(LayerActs *a) {
free(a->layer_in);free(a->xnorm);free(a->Q);free(a->K);free(a->V);
free(a->attn_out);free(a->o_out);free(a->x2);free(a->x2norm);
free(a->h1);free(a->h3);free(a->silu_out);free(a->ffn_out);
}
static LayerGrads layer_grads_alloc(void) {
LayerGrads g;
g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WQ_SZ,4);
g.Wv=(float*)calloc(WQ_SZ,4); g.Wo=(float*)calloc(WO_SZ,4);
g.W1=(float*)calloc(W1_SZ,4); g.W2=(float*)calloc(W2_SZ,4); g.W3=(float*)calloc(W3_SZ,4);
g.rms_att=(float*)calloc(DIM,4); g.rms_ffn=(float*)calloc(DIM,4);
return g;
}
static void layer_grads_zero(LayerGrads *g) {
memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WQ_SZ*4);
memset(g->Wv,0,WQ_SZ*4);memset(g->Wo,0,WO_SZ*4);
memset(g->W1,0,W1_SZ*4);memset(g->W2,0,W2_SZ*4);memset(g->W3,0,W3_SZ*4);
memset(g->rms_att,0,DIM*4);memset(g->rms_ffn,0,DIM*4);
}
static void layer_grads_free(LayerGrads *g) {
free(g->Wq);free(g->Wk);free(g->Wv);free(g->Wo);
free(g->W1);free(g->W2);free(g->W3);
free(g->rms_att);free(g->rms_ffn);
}
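
The `LAYER_PARAMS` and `TOTAL_PARAMS` macros can be checked against the README figures with a little arithmetic. The sketch below assumes the Stories110M dimensions quoted in the README (dim=768, hidden=2048, layers=12, vocab=32000); the `X_*` constants and function names are hypothetical and are not part of the commit.

```c
#include <assert.h>
#include <stddef.h>

/* Assumed dimensions, copied from the README; not read from stories_config.h. */
enum { X_DIM = 768, X_HIDDEN = 2048, X_NLAYERS = 12, X_VOCAB = 32000 };

/* Mirrors LAYER_PARAMS: four DIM x DIM attention matrices, three FFN
   matrices, and two RMSNorm weight vectors. */
static size_t layer_params(void) {
    size_t attn = (size_t)4 * X_DIM * X_DIM;      /* Wq, Wk, Wv, Wo */
    size_t ffn  = (size_t)3 * X_HIDDEN * X_DIM;   /* W1, W2, W3 */
    size_t norm = (size_t)2 * X_DIM;              /* rms_att, rms_ffn */
    return attn + ffn + norm;
}

/* Mirrors TOTAL_PARAMS: all layers + final RMSNorm + embedding table. */
static size_t total_params(void) {
    return (size_t)X_NLAYERS * layer_params() + X_DIM + (size_t)X_VOCAB * X_DIM;
}

/* Cross-check against the README split (84.95M transformer + 24.58M embedding). */
static void params_selfcheck(void) {
    assert((size_t)X_NLAYERS * layer_params() + X_DIM == 84953856u);
    assert((size_t)X_VOCAB * X_DIM == 24576000u);
}
```

Under these assumptions each layer holds 7,079,424 parameters and the model holds 109,529,856, which rounds to the 109.53M stated in the README.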

129
training/stories_cpu_ops.h Normal file

@ -0,0 +1,129 @@
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
#pragma once
#include "stories_config.h"
static float *g_rms_tmp = NULL;
// RMSNorm forward, channel-first x[d, S]: out[i,t] = x[i,t] * w[i] / sqrt(mean_i(x[i,t]^2) + eps)
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
float *ss = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
}
float invd = 1.0f/d, eps=1e-5f;
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
int n = S; vvrsqrtf(ss, ss, &n);
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, (vDSP_Length)S);
vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, (vDSP_Length)S);
}
free(ss);
}
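// RMSNorm backward: overwrites dx, accumulates into dw (dw[i] += sum_t dy*x*rrms); rrms is recomputed from x
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {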
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
float *ss = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
}
float invd = 1.0f/d, eps=1e-5f;
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
float *rrms = (float*)malloc(S*4);
int n = S; vvrsqrtf(rrms, ss, &n);
float *dot = (float*)calloc(S, sizeof(float));
for (int i=0; i<d; i++) {
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsma(g_rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
}
vDSP_vmul(rrms, 1, rrms, 1, ss, 1, (vDSP_Length)S);
vDSP_vsmul(ss, 1, &invd, ss, 1, (vDSP_Length)S);
vDSP_vmul(dot, 1, ss, 1, dot, 1, (vDSP_Length)S);
for (int i=0; i<d; i++) {
// dx_i = rrms * (w_i*dy_i - x_i*dot): w_i scales only the dy term, not the x*dot correction
vDSP_vmul(x+i*S, 1, dot, 1, dx+i*S, 1, (vDSP_Length)S);
vDSP_vsmul(dy+i*S, 1, &w[i], g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vsub(dx+i*S, 1, g_rms_tmp, 1, dx+i*S, 1, (vDSP_Length)S); // vsub computes B - A
vDSP_vmul(dx+i*S, 1, rrms, 1, dx+i*S, 1, (vDSP_Length)S);
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
float s; vDSP_sve(g_rms_tmp, 1, &s, (vDSP_Length)S);
dw[i] += s;
}
free(ss); free(rrms); free(dot);
}
// Adam step with bias correction; t is the 1-based step count (t=0 would make bc1 and bc2 zero)
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
float bc1 = 1.0f - powf(b1, t), bc2 = 1.0f - powf(b2, t);
for (size_t i=0; i<s->n; i++) {
s->m[i] = b1*s->m[i] + (1-b1)*g[i];
s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i];
float mh = s->m[i]/bc1, vh = s->v[i]/bc2;
w[i] -= lr * mh / (sqrtf(vh) + eps);
}
}
// Cross-entropy loss + gradient for logits stored channel-first as [VOCAB, SEQ]
// logits[v*SEQ+t] = logit for vocab v, position t
// targets[t] = target token id for position t
// Returns mean CE loss, writes dlogits = (softmax(logits) - one_hot(targets)) / SEQ
// In this layout one position's logits sit at v*S+t, i.e. stride S apart, so they are not contiguous
// For vDSP: transpose to scratch [S, V] so each position's logits form one contiguous row
static float cross_entropy_loss(float *dlogits, const float *logits, const uint16_t *targets, int V, int S) {
// Work in transposed layout [S, V] where each row is one position's logits (contiguous)
float *buf = (float*)malloc(S * V * 4);
// Transpose [V,S] → [S,V]: buf[t*V+v] = logits[v*S+t]
vDSP_mtrans(logits, 1, buf, 1, (vDSP_Length)S, (vDSP_Length)V);
float total_loss = 0;
float invS = 1.0f / S;
for (int t = 0; t < S; t++) {
float *row = buf + t * V;
// max
float maxv;
vDSP_maxv(row, 1, &maxv, (vDSP_Length)V);
// row -= maxv
float neg_max = -maxv;
vDSP_vsadd(row, 1, &neg_max, row, 1, (vDSP_Length)V);
// exp in-place
int n = V;
vvexpf(row, row, &n);
// sum
float sum;
vDSP_sve(row, 1, &sum, (vDSP_Length)V);
// normalize
float inv_sum = 1.0f / sum;
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
// loss
int tgt = targets[t];
total_loss -= logf(row[tgt] + 1e-10f);
// gradient: softmax - one_hot, then /S
row[tgt] -= 1.0f;
vDSP_vsmul(row, 1, &invS, row, 1, (vDSP_Length)V);
}
// Transpose back [S,V] → [V,S]
vDSP_mtrans(buf, 1, dlogits, 1, (vDSP_Length)V, (vDSP_Length)S);
free(buf);
return total_loss / S;
}
// Embedding lookup: token_ids → x [DIM, SEQ] (channel-first)
// embed is [VOCAB, DIM] row-major (vocab_size rows, dim cols)
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
for (int d = 0; d < dim; d++) {
x[d*seq + t] = embed[tok*dim + d];
}
}
}
// Embedding backward: accumulate dE[tok] += dx[:,t] for each position
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
for (int t = 0; t < seq; t++) {
int tok = tokens[t];
for (int d = 0; d < dim; d++) {
d_embed[tok*dim + d] += dx[d*seq + t];
}
}
}

134
training/stories_io.h Normal file

@ -0,0 +1,134 @@
// stories_io.h — IOSurface helpers, blob builders, NEON conversion
#pragma once
#include "stories_config.h"
#include <arm_neon.h>
static IOSurfaceRef make_surface(size_t bytes) {
return IOSurfaceCreate((__bridge CFDictionaryRef)@{
(id)kIOSurfaceWidth:@(bytes), (id)kIOSurfaceHeight:@1,
(id)kIOSurfaceBytesPerElement:@1, (id)kIOSurfaceBytesPerRow:@(bytes),
(id)kIOSurfaceAllocSize:@(bytes), (id)kIOSurfacePixelFormat:@0});
}
// fp32 [rows, cols] → fp16 blob: 128-byte header followed by the row-major payload
static NSData *build_blob(const float *w, int rows, int cols) {
int ws=rows*cols*2, tot=128+ws;
uint8_t *b=(uint8_t*)calloc(tot,1);
b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1;
*(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128;
_Float16 *fp16=(_Float16*)(b+128);
for(int i=0;i<rows*cols;i++) fp16[i]=(_Float16)w[i];
return [NSData dataWithBytesNoCopy:b length:tot freeWhenDone:YES];
}
// Same as build_blob, but stores the transpose: the payload is laid out as [cols, rows]
static NSData *build_blob_t(const float *w, int rows, int cols) {
int ws=cols*rows*2, tot=128+ws;
uint8_t *b=(uint8_t*)calloc(tot,1);
b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1;
*(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128;
_Float16 *fp16=(_Float16*)(b+128);
for(int i=0;i<rows;i++) for(int j=0;j<cols;j++) fp16[j*rows+i]=(_Float16)w[i*cols+j];
return [NSData dataWithBytesNoCopy:b length:tot freeWhenDone:YES];
}
static NSData *build_blob_fp16(_Float16 *d, int cnt) {
int ws=cnt*2, tot=128+ws;
uint8_t *b=(uint8_t*)calloc(tot,1);
b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1;
*(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128;
memcpy(b+128,d,ws);
return [NSData dataWithBytesNoCopy:b length:tot freeWhenDone:YES];
}
// NEON vectorized conversion
static void cvt_f16_f32(float *dst, const _Float16 *src, int n) {
int i = 0;
for (; i+7 < n; i += 8) {
float16x8_t h = vld1q_f16((const __fp16*)(src+i));
vst1q_f32(dst+i, vcvt_f32_f16(vget_low_f16(h)));
vst1q_f32(dst+i+4, vcvt_f32_f16(vget_high_f16(h)));
}
for (; i < n; i++) dst[i] = (float)src[i];
}
static void cvt_f32_f16(_Float16 *dst, const float *src, int n) {
int i = 0;
for (; i+7 < n; i += 8) {
float16x8_t h = vcombine_f16(vcvt_f16_f32(vld1q_f32(src+i)),
vcvt_f16_f32(vld1q_f32(src+i+4)));
vst1q_f16((__fp16*)(dst+i), h);
}
for (; i < n; i++) dst[i] = (_Float16)src[i];
}
// IOSurface I/O (channel-first [C,S] layout)
static void io_write_fp16(IOSurfaceRef s, const float *data, int channels, int sp) {
IOSurfaceLock(s, 0, NULL);
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s), data, channels * sp);
IOSurfaceUnlock(s, 0, NULL);
}
static void io_read_fp16(IOSurfaceRef s, float *data, int ch_off, int channels, int sp) {
IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
cvt_f16_f32(data, (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, channels * sp);
IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
static void io_copy(IOSurfaceRef dst, int dst_ch, IOSurfaceRef src, int src_ch, int channels, int sp) {
IOSurfaceLock(dst, 0, NULL);
IOSurfaceLock(src, kIOSurfaceLockReadOnly, NULL);
memcpy((_Float16*)IOSurfaceGetBaseAddress(dst) + dst_ch*sp,
(_Float16*)IOSurfaceGetBaseAddress(src) + src_ch*sp,
channels * sp * sizeof(_Float16));
IOSurfaceUnlock(src, kIOSurfaceLockReadOnly, NULL);
IOSurfaceUnlock(dst, 0, NULL);
}
static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int channels, int sp) {
IOSurfaceLock(s, 0, NULL);
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp, data, channels * sp);
IOSurfaceUnlock(s, 0, NULL);
}
// Kernel compile/eval
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
@autoreleasepool {
NSData *md = [mil dataUsingEncoding:NSUTF8StringEncoding];
id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(g_D, @selector(modelWithMILText:weights:optionsPlist:), md, weights, nil);
if (!desc) { printf(" [compile] desc=NULL\n"); return NULL; }
id mdl = ((id(*)(Class,SEL,id))objc_msgSend)(g_I, @selector(inMemoryModelWithDescriptor:), desc);
if (!mdl) { printf(" [compile] model=NULL\n"); return NULL; }
id hx = ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier));
NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx];
[[NSFileManager defaultManager] createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"] withIntermediateDirectories:YES attributes:nil error:nil];
[md writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES];
for (NSString *path in weights) {
NSString *rel = [path stringByReplacingOccurrencesOfString:@"@model_path/" withString:@""];
[weights[path][@"data"] writeToFile:[td stringByAppendingPathComponent:rel] atomically:YES];
}
NSError *e = nil;
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(compileWithQoS:options:error:), 21, @{}, &e)) {
printf(" [compile] FAIL: %s\n", e ? [[e description] UTF8String] : "no error"); return NULL;
}
if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e)) {
printf(" [compile] load FAIL: %s\n", e ? [[e description] UTF8String] : "no error"); return NULL;
}
__sync_fetch_and_add(&g_compile_count, 1);
Kern *k = (Kern*)calloc(1, sizeof(Kern));
k->model = (void*)CFBridgingRetain(mdl);
k->ioIn = make_surface(ic_bytes);
k->ioOut = make_surface(oc_bytes);
id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn);
id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut);
k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
@selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
@[wI], @[@0], @[wO], @[@0], nil, nil, @0));
k->tmpDir = (void*)CFBridgingRetain(td);
return k;
}
}
static void free_kern(Kern *k) {
if (!k) return;
id mdl = (__bridge id)k->model; NSError *e = nil;
((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
CFRelease(k->ioIn); CFRelease(k->ioOut);
[[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
CFRelease(k->model); CFRelease(k->request); CFRelease(k->tmpDir);
free(k);
}
static void ane_eval(Kern *k) {
id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil;
BOOL ok = ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e);
if (!ok) printf(" [eval] FAIL: %s\n", e ? [[e description] UTF8String] : "no error");
}
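
The three `build_blob*` helpers share one header layout. The portable sketch below writes the same bytes at the same offsets, taking the fp16 payload as raw 16-bit patterns so it does not depend on `_Float16`. The function name `blob_from_bits` is hypothetical, and the field descriptions in the comments are inferred from the values written, not from any documented format.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Returns a malloc'd buffer: 128 header bytes, then n 16-bit values.
   The caller frees the result. Field meanings are assumptions. */
static uint8_t *blob_from_bits(const uint16_t *bits, size_t n, size_t *out_len) {
    assert(bits && out_len);
    uint32_t ws = (uint32_t)(n * 2);          /* payload size in bytes */
    uint32_t off = 128;                       /* matches where the payload starts */
    size_t tot = 128 + (size_t)ws;
    uint8_t *b = calloc(tot, 1);
    if (!b) return NULL;
    b[0] = 1; b[4] = 2;                       /* same constants as build_blob */
    const uint8_t tag[4] = {0xEF, 0xBE, 0xAD, 0xDE};  /* 0xDEADBEEF, little-endian */
    memcpy(b + 64, tag, sizeof tag);
    b[68] = 1;
    memcpy(b + 72, &ws, sizeof ws);           /* native byte order, as in the original */
    memcpy(b + 80, &off, sizeof off);
    memcpy(b + 128, bits, ws);
    *out_len = tot;
    return b;
}
```

Using `memcpy` for the 32-bit fields avoids the unaligned pointer casts of the original while producing the same bytes on a little-endian machine.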

286
training/stories_mil.h Normal file

@ -0,0 +1,286 @@
// stories_mil.h — MIL program generators for ANE kernels
// Same architecture as single-layer train_large.m but parameterized
#pragma once
#include "stories_io.h"
#define MIL_HDR \
@"program(1.3)\n[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
#define CONV_CONST \
" string pt = const()[name=string(\"pt\"), val=string(\"valid\")];\n" \
" tensor<int32, [2]> st = const()[name=string(\"st\"), val=tensor<int32, [2]>([1,1])];\n" \
" tensor<int32, [4]> pd = const()[name=string(\"pd\"), val=tensor<int32, [4]>([0,0,0,0])];\n" \
" tensor<int32, [2]> dl = const()[name=string(\"dl\"), val=tensor<int32, [2]>([1,1])];\n" \
" int32 gr = const()[name=string(\"gr\"), val=int32(1)];\n"
// SDPA forward + taps: x_in → rmsnorm → QKV+SDPA+Wo → concat(o_out, Q, K, V, attn_out, xnorm)
static NSString *gen_sdpa_fwd_taps(void) {
float sc = 1.0f/sqrtf((float)HD);
float invd = 1.0f/(float)DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([1])];\n"];
[m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
[m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
[m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
[m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,1]> rw = const()[name=string(\"rw\"), val=tensor<fp16, [1,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/rms1.bin\"), offset=uint64(64)))];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ];
[m appendString:@CONV_CONST];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wq = const()[name=string(\"Wq\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wq.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wk = const()[name=string(\"Wk\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wk.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wv = const()[name=string(\"Wv\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wv.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wo = const()[name=string(\"Wo\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wo.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wq,x=xn)[name=string(\"cq\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wk,x=xn)[name=string(\"ck\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wv,x=xn)[name=string(\"cv\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS,SEQ,HD];
[m appendString:@" bool tx = const()[name=string(\"tx\"), val=bool(false)];\n"];
[m appendString:@" bool ty = const()[name=string(\"ty\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ];
[m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<int32, [4]> os = const()[name=string(\"os\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af)[name=string(\"co\")];\n", DIM,SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM,SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// FFN forward + taps: x2 → rmsnorm → FFN → concat(ffn_out, h1, h3, silu_out, x2norm)
static NSString *gen_ffn_fwd_taps(void) {
float invd = 1.0f/(float)DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([1])];\n"];
[m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
[m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
[m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
[m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
[m appendFormat:@" tensor<fp16, [1,1,1,%d]> rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,1]> rw = const()[name=string(\"rw\"), val=tensor<fp16, [1,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/rms2.bin\"), offset=uint64(64)))];\n", DIM, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ];
[m appendString:@CONV_CONST];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> W1 = const()[name=string(\"W1\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w1.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> W3 = const()[name=string(\"W3\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w3.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> W2 = const()[name=string(\"W2\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w2.bin\"), offset=uint64(64)))];\n", DIM,HIDDEN,DIM,HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn)[name=string(\"c1\")];\n", HIDDEN,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn)[name=string(\"c3\")];\n", HIDDEN,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate)[name=string(\"c2\")];\n", DIM,SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn))[name=string(\"cat\")];\n", 2*DIM+3*HIDDEN,SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// FFN backward: concat(dffn,h1,h3) → concat(dx,dh1,dh3)
static NSString *gen_ffn_bwd(void) {
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM+2*HIDDEN, SEQ];
[m appendString:@CONV_CONST];
[m appendString:@" tensor<int32, [4]> bd = const()[name=string(\"bd\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dffn = slice_by_size(x=x,begin=bd,size=sd)[name=string(\"s0\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
[m appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = slice_by_size(x=x,begin=b1,size=s1)[name=string(\"s1x\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM+HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = slice_by_size(x=x,begin=b3,size=s1)[name=string(\"s3x\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> W2t = const()[name=string(\"W2t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w2t.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn)[name=string(\"cw2\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
[m appendString:@" fp16 one = const()[name=string(\"one\"), val=fp16(1.0)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> oms = sub(x=one,y=sig)[name=string(\"oms\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> homs = mul(x=h1,y=oms)[name=string(\"homs\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> brk = add(x=one,y=homs)[name=string(\"brk\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dsd = mul(x=sig,y=brk)[name=string(\"dsd\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> t1 = mul(x=dsilu,y=h3)[name=string(\"t1\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = mul(x=t1,y=dsd)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> slh = mul(x=h1,y=sig)[name=string(\"slh\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = mul(x=dsilu,y=slh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> W1t = const()[name=string(\"W1t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w1t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> W3t = const()[name=string(\"W3t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w3t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1)[name=string(\"cw1\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3)[name=string(\"cw3\")];\n", DIM, SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = add(x=dx1,y=dx3)[name=string(\"adx\")];\n", DIM, SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3))[name=string(\"cat\")];\n", DIM+2*HIDDEN, SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// QKV backward: concat(dq,dk,dv) → dx
static NSString *gen_qkvb(void) {
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 3*DIM, SEQ];
[m appendString:@CONV_CONST];
[m appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wqt = const()[name=string(\"Wqt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wqt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wkt = const()[name=string(\"Wkt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wkt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wvt = const()[name=string(\"Wvt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wvt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq)[name=string(\"cq\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxk = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk)[name=string(\"ck\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv)[name=string(\"cv\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = add(x=dxqk,y=dxv)[name=string(\"out\")];\n", DIM,SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
// SDPA backward part 1 + Wo^T
static NSString *gen_sdpa_bwd1(void) {
float sc = 1.0f/sqrtf((float)HD);
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 4*DIM, SEQ];
[m appendString:@CONV_CONST];
[m appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx2f = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wot = const()[name=string(\"Wot\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wot.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f)[name=string(\"cwo\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dr = reshape(shape=rsh,x=df)[name=string(\"rd\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> da = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS,SEQ,HD];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ];
[m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da)[name=string(\"dv\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v)[name=string(\"dp\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> scs = const()[name=string(\"scs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH,SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH,SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
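The channel packing shared by the two SDPA backward kernels can be tabulated with a short script. This is only a sketch under stated assumptions: `DIM`, `HEADS` and `SEQ` are the values quoted in the README, `SCORE_CH = HEADS*SEQ` is inferred from the reshape of `probs` to `[1,SCORE_CH,1,SEQ]` (its real definition is in a header not shown here), and `pack_offsets` is a hypothetical helper, not part of the commit.

```python
# Assumed constants: taken from the README, not from the headers themselves.
DIM, HEADS, SEQ = 768, 12, 256
HD = DIM // HEADS          # per-head dimension implied by reshape DIM -> HEADS*HD
SCORE_CH = HEADS * SEQ     # inferred: probs [1,HEADS,SEQ,SEQ] -> [1,SCORE_CH,1,SEQ]


def pack_offsets(parts):
    """Map each named tensor to (begin_channel, channel_count) along axis 1.

    `parts` is a list of (name, channels) in concat order. Returns the
    mapping and the total channel count of the packed tensor.
    """
    layout, offset = {}, 0
    for name, channels in parts:
        layout[name] = (offset, channels)
        offset += channels
    return layout, offset


# Input of gen_sdpa_bwd1: slices s0..s3 begin at 0, DIM, 2*DIM, 3*DIM.
bwd1_in, bwd1_in_ch = pack_offsets(
    [("Q", DIM), ("K", DIM), ("V", DIM), ("dx2", DIM)])

# Output of gen_sdpa_bwd1: concat(dvf, pf, dpf).
bwd1_out, bwd1_out_ch = pack_offsets(
    [("dV", DIM), ("probs", SCORE_CH), ("dp", SCORE_CH)])

# Input of gen_sdpa_bwd2: slices begin at 0, SCORE_CH, 2*SCORE_CH, 2*SCORE_CH+DIM.
bwd2_in, bwd2_in_ch = pack_offsets(
    [("probs", SCORE_CH), ("dp", SCORE_CH), ("Q", DIM), ("K", DIM)])

if __name__ == "__main__":
    for label, layout, total in [("bwd1 in", bwd1_in, bwd1_in_ch),
                                 ("bwd1 out", bwd1_out, bwd1_out_ch),
                                 ("bwd2 in", bwd2_in, bwd2_in_ch)]:
        print(label, total, layout)
```

Note that part 1 emits `(dV, probs, dp)` while part 2 consumes `(probs, dp, Q, K)`, so the two packed tensors are not interchangeable; how the host code re-packs them is not shown in this part of the diff.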
// SDPA backward part 2: concat(probs,dp,Q,K) → concat(dQ,dK)
static NSString *gen_sdpa_bwd2(void) {
float sc = 1.0f/sqrtf((float)HD);
int bwd2_in = 2*SCORE_CH + 2*DIM;
NSMutableString *m = [NSMutableString string];
[m appendString:MIL_HDR];
[m appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", bwd2_in, SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_sc = const()[name=string(\"szsc\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
[m appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH,SEQ];
[m appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH,SEQ];
[m appendFormat:@" tensor<int32, [4]> sz_d = const()[name=string(\"szd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
[m appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<int32, [4]> ssh = const()[name=string(\"ssh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
[m appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS,SEQ,SEQ];
[m appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([-1])];\n"];
[m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,1]> spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS,SEQ,SEQ];
[m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS,SEQ,SEQ];
[m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
[m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS,SEQ,HD];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS,HD,SEQ];
[m appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM,SEQ];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM,SEQ];
[m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
[m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
[m appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM,SEQ];
[m appendString:@" } -> (out);\n}\n"];
return m;
}
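The op chain `pdp → rs → dps → ds0 → ds` in `gen_sdpa_bwd2` appears to implement the usual softmax backward rule, `ds = probs * (dp - sum(probs*dp)) * scale`, per row. The sketch below checks that formula against central finite differences for a single row in plain Python. It is an illustration only: it runs in float64 rather than fp16, omits the additive mask (a constant offset that does not change the derivative), and the helper names are hypothetical.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [v / total for v in exps]


def softmax_bwd(probs, dp, scale):
    """Gradient w.r.t. the unscaled scores: probs * (dp - sum(probs*dp)) * scale."""
    spdp = sum(p * g for p, g in zip(probs, dp))      # reduce_sum over the last axis
    return [p * (g - spdp) * scale for p, g in zip(probs, dp)]


def numeric_grad(raw, dp, scale, eps=1e-6):
    """Central differences of L = sum(dp * softmax(scale * raw)) w.r.t. raw."""
    def loss(scores):
        probs = softmax([scale * v for v in scores])
        return sum(g * p for g, p in zip(dp, probs))

    grads = []
    for i in range(len(raw)):
        hi = list(raw)
        lo = list(raw)
        hi[i] += eps
        lo[i] -= eps
        grads.append((loss(hi) - loss(lo)) / (2 * eps))
    return grads


if __name__ == "__main__":
    raw = [0.3, -1.2, 2.0, 0.5]        # unscaled q.k^T scores for one row
    dp = [0.1, -0.4, 0.7, 0.2]         # upstream gradient w.r.t. probs
    scale = 0.125                      # 1/sqrt(64), assuming HD = 64
    probs = softmax([scale * v for v in raw])
    print(softmax_bwd(probs, dp, scale))
    print(numeric_grad(raw, dp, scale))
```

A useful side check is that each row of `ds` sums to zero, which follows directly from the formula because the probabilities sum to one.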
// Causal mask blob [SEQ,SEQ]: 0 where t2<=t, -65504 (most negative finite fp16) elsewhere; built once and cached
static NSData *g_mask_blob = nil;
static NSData *get_mask_blob(void) {
if (!g_mask_blob) {
_Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16));
for(int t=0;t<SEQ;t++) for(int t2=0;t2<SEQ;t2++)
mask[t*SEQ+t2] = (t2<=t) ? (_Float16)0.0f : (_Float16)(-65504.0f);
g_mask_blob = build_blob_fp16(mask, SEQ*SEQ);
free(mask);
}
return g_mask_blob;
}
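The effect of the additive mask can be reproduced in a few lines of Python. This is a hedged sketch, not the commit's code: it builds the same 0 / -65504 pattern, packs it as raw little-endian fp16 with the standard `struct` format `e`, and shows that masked positions get zero probability after softmax. It does not reproduce whatever header `build_blob_fp16` writes (the MIL text reads the payload at offset 64), and all function names here are hypothetical.

```python
import math
import struct

FP16_LOWEST = -65504.0  # most negative finite IEEE half-precision value


def causal_mask(seq):
    """Row t allows columns t2 <= t (0.0) and blocks the rest (FP16_LOWEST)."""
    return [[0.0 if t2 <= t else FP16_LOWEST for t2 in range(seq)]
            for t in range(seq)]


def to_fp16_payload(mask):
    """Row-major little-endian fp16 bytes. The blob header is NOT reproduced."""
    flat = [v for row in mask for v in row]
    return struct.pack(f"<{len(flat)}e", *flat)


def masked_softmax(scores, mask_row):
    """Add the mask row to the scores, then apply a stable softmax."""
    shifted = [s + m for s, m in zip(scores, mask_row)]
    peak = max(shifted)
    exps = [math.exp(v - peak) for v in shifted]
    total = sum(exps)
    return [v / total for v in exps]


if __name__ == "__main__":
    mask = causal_mask(4)
    for row in mask:
        print(row)
    # Position t=1 may attend to columns 0 and 1 only.
    print(masked_softmax([1.0, 2.0, 3.0, 4.0], mask[1]))
    print(len(to_fp16_payload(mask)), "payload bytes")
```

The diagonal is always unmasked, so every row keeps at least one finite score and the softmax denominator never becomes zero.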

training/tokenize.py Normal file

@ -0,0 +1,36 @@
#!/usr/bin/env python3
"""Extract pretokenized TinyStories data from zip.

Data format: flat uint16 token IDs (llama2.c BPE, 32K vocab).
Source: ~/tiny_stories_data_pretokenized.zip"""

import os, struct, zipfile
from pathlib import Path

ZIP_PATH = os.path.expanduser('~/tiny_stories_data_pretokenized.zip')
OUTPUT_PATH = str(Path(__file__).resolve().parent / 'tinystories_data00.bin')

def main():
    if os.path.exists(OUTPUT_PATH):
        n = os.path.getsize(OUTPUT_PATH) // 2
        print(f"{OUTPUT_PATH} already exists ({n} tokens, {os.path.getsize(OUTPUT_PATH)/1e6:.1f} MB)")
        return

    print(f"Extracting data00.bin from {ZIP_PATH}...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as z:
        with z.open('data00.bin') as src, open(OUTPUT_PATH, 'wb') as dst:
            while True:
                chunk = src.read(1 << 20)
                if not chunk:
                    break
                dst.write(chunk)

    n = os.path.getsize(OUTPUT_PATH) // 2
    print(f"Written {OUTPUT_PATH} ({n} tokens, {os.path.getsize(OUTPUT_PATH)/1e6:.1f} MB)")

    # Sanity check
    with open(OUTPUT_PATH, 'rb') as f:
        tokens = struct.unpack('<10H', f.read(20))
    print(f"First 10 tokens: {tokens}")

if __name__ == '__main__':
    main()
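Since the extracted file is a flat array of little-endian uint16 token IDs, reading it back and cutting a training window takes only a few lines. The sketch below is illustrative: `load_tokens` and `make_example` are hypothetical helpers, and the next-token split (inputs are the first `seq` tokens of a window, targets the same window shifted by one) is the common convention, assumed here rather than taken from `train_large.m`, whose diff is not shown.

```python
import struct


def load_tokens(path, count=None):
    """Read little-endian uint16 token IDs; `count` limits how many are read."""
    with open(path, "rb") as f:
        data = f.read() if count is None else f.read(2 * count)
    n = len(data) // 2                       # ignore a trailing odd byte, if any
    return list(struct.unpack(f"<{n}H", data[:2 * n]))


def make_example(tokens, start, seq):
    """Return (inputs, targets) of length `seq`, targets shifted by one token."""
    window = tokens[start:start + seq + 1]
    if len(window) < seq + 1:
        raise ValueError("not enough tokens after `start` for one window")
    return window[:-1], window[1:]
```

For example, `make_example(load_tokens('tinystories_data00.bin', 1024), 0, 256)` would yield one 256-token input and its shifted target, assuming the file has already been extracted next to the script.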

File diff suppressed because it is too large