mirror of https://github.com/maderix/ANE.git
stories110M: 12-layer ANE training with dashboard, 107ms/step
- Scale to full stories110M (109M params, 12 layers) with real TinyStories data - vDSP-vectorized cross-entropy (110ms→14ms), NEON fp16 IO, async dW - TUI dashboard: loss curve, ANE/CPU power, CPU/memory graphs, text generation - Split into modular headers: config, io, mil, cpu_ops
This commit is contained in:
parent
f213c8db68
commit
4d67db1bdb
|
|
@ -3,10 +3,18 @@ CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc
|
|||
FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface
|
||||
LDFLAGS = $(FRAMEWORKS) -ldl
|
||||
|
||||
HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h
|
||||
|
||||
train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h
|
||||
$(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS)
|
||||
|
||||
clean:
|
||||
rm -f train
|
||||
train_large: train_large.m $(HEADERS_LARGE)
|
||||
$(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate
|
||||
|
||||
.PHONY: clean
|
||||
tokenize:
|
||||
python3 tokenize.py
|
||||
|
||||
clean:
|
||||
rm -f train train_large
|
||||
|
||||
.PHONY: clean tokenize
|
||||
|
|
|
|||
|
|
@ -0,0 +1,69 @@
|
|||
# ANE Training — Stories110M on Apple Neural Engine
|
||||
|
||||
Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly on Apple's Neural Engine using private ANE APIs.
|
||||
|
||||

|
||||
|
||||
## Architecture
|
||||
|
||||
- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256
|
||||
- **109.53M params** (84.95M transformer + 24.58M embedding)
|
||||
- **72 ANE kernels** per compile (60 weight-bearing, 12 weight-free sdpaBwd2)
|
||||
- **6 kernel types per layer**: fwdAttn, fwdFFN, ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd
|
||||
|
||||
## Performance
|
||||
|
||||
| Component | Time (ms/step) |
|
||||
|-----------|---------------|
|
||||
| ANE eval | 9.6 |
|
||||
| IO (fp16 conversion) | 4.1 |
|
||||
| Classifier (cblas) | 9.1 |
|
||||
| Cross-entropy + residuals | 14.4 |
|
||||
| RMSNorm | 0.1 |
|
||||
| **Total** | **107 ms/step** |
|
||||
|
||||
## Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `train_large.m` | Main training loop — 12-layer forward/backward, checkpoint, exec() restart |
|
||||
| `stories_config.h` | Model config, structs, alloc helpers |
|
||||
| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval |
|
||||
| `stories_mil.h` | MIL program generators for all 6 ANE kernel types |
|
||||
| `stories_cpu_ops.h` | vDSP-vectorized RMSNorm, cross-entropy, Adam, embedding ops |
|
||||
| `dashboard.py` | TUI dashboard — loss curve, power/CPU/memory graphs, text generation |
|
||||
| `tokenize.py` | Extract pretokenized TinyStories data |
|
||||
| `Makefile` | Build targets |
|
||||
|
||||
## How it works
|
||||
|
||||
1. **Forward pass**: Each layer runs fwdAttn (QKV + SDPA + Wo) and fwdFFN (SwiGLU: W2 · (SiLU(W1·x) ⊙ W3·x)) on ANE via MIL-compiled kernels. The final RMSNorm and the classifier matmul run on the CPU (cblas).
|
||||
|
||||
2. **Backward pass**: Reverse layer order. ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd on ANE. Weight gradients (dW) via async cblas_sgemm on CPU. RMSNorm backward via vDSP.
|
||||
|
||||
3. **Compile budget**: The ANE allows only ~119 kernel compiles per process. Each batch compiles 72 kernels, so we run 10 gradient-accumulation steps per batch, then write a checkpoint and `exec()`-restart the process, resuming from that checkpoint.
|
||||
|
||||
4. **Data**: Real TinyStories text (20M tokens), mmap'd uint16 token IDs, random position sampling per step.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Extract tokenized data
|
||||
python3 tokenize.py
|
||||
|
||||
# Build and train
|
||||
make train_large
|
||||
./train_large # fresh start
|
||||
./train_large --resume # resume from checkpoint
|
||||
|
||||
# Monitor with dashboard
|
||||
pip install blessed psutil numpy
|
||||
python3 dashboard.py --resume # needs sudo for powermetrics
|
||||
```
|
||||
|
||||
## Key techniques
|
||||
|
||||
- **NEON vectorized fp16<->fp32**: ARM NEON intrinsics for fast IOSurface data transfer
|
||||
- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar
|
||||
- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE
|
||||
- **SDPA causal mask workaround**: ANE hardware ignores attn_mask, so we decompose attention into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv)
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 232 KiB |
|
|
@ -0,0 +1,882 @@
|
|||
"""TUI dashboard for ANE training (train_large). Uses blessed for terminal UI."""
|
||||
|
||||
import argparse, fcntl, math, os, re, select, signal, struct, subprocess, sys, time, threading
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from blessed import Terminal
|
||||
except ImportError:
|
||||
print('pip install blessed')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import psutil
|
||||
HAS_PSUTIL = True
|
||||
except ImportError:
|
||||
HAS_PSUTIL = False
|
||||
|
||||
# Stories110M dimensions (same values the README lists for the model).
DIM = 768
HIDDEN = 2048
HEADS = 12
SEQ = 256
VOCAB = 32000
NLAYERS = 12

# Width of a single attention head.
HD = DIM // HEADS

# Checkpoint file the text-generation thread loads weights from.
CKPT_PATH = 'ane_stories110M_ckpt.bin'

# tokenizer.bin is looked up three directories above this script, under assets/models/.
_SCRIPT_PATH = Path(__file__).resolve()
TOKENIZER_PATH = str(_SCRIPT_PATH.parent.parent.parent / 'assets' / 'models' / 'tokenizer.bin')
|
||||
|
||||
|
||||
class State:
    """Mutable bag of everything the dashboard shows.

    A single module-level instance (``S``) is written by the parser and the
    background threads and read by the renderer.
    """

    def __init__(self):
        # Run description parsed from the trainer's output.
        self.model_config = {}
        self.params = {}
        self.kernels = {}
        self.training = {}
        self.flops = {}
        self.efficiency = {}

        # Training progress.
        self.step = 0
        self.total_steps = 0
        self.batch_num = 0
        self.loss = 0.0
        self.best_loss = float('inf')
        self.loss_history = []

        # Timing / compile statistics.
        self.ms_per_step = 0.0
        self.compile_pct = 0.0
        self.compiles = 0
        self.component_timing = {}

        # Power readings plus bounded (300-sample) histories.
        self.power = {'ane': 0.0, 'cpu': 0.0, 'gpu': 0.0}
        self.power_history_ane = deque(maxlen=300)
        self.power_history_cpu = deque(maxlen=300)

        # System metrics histories (300 samples each) and the pid being tracked.
        self.cpu_pct_history = deque(maxlen=300)
        self.mem_mb_history = deque(maxlen=300)
        self.proc_mem_mb_history = deque(maxlen=300)
        self.train_pid = None

        # Log pane: last 2000 lines and scroll position.
        self.logs = deque(maxlen=2000)
        self.log_scroll = 0
        self.auto_scroll = True

        # Text-generation results; gen_lock guards the three gen_* fields.
        self.gen_text = ''
        self.gen_step = 0
        self.gen_status = 'idle'
        self.gen_lock = threading.Lock()


S = State()
|
||||
|
||||
|
||||
class Tokenizer:
    """Vocabulary table loaded from a binary tokenizer file.

    The file is read as a 4-byte header followed by ``VOCAB`` records, each
    made of a float score, an int byte-length and that many bytes of UTF-8
    token text.
    """

    def __init__(self, path):
        """Load all ``VOCAB`` entries from *path*.

        Raises whatever ``open``/``struct.unpack`` raise on a missing or
        truncated file; the caller decides how to report it.
        """
        self.vocab = []
        self.scores = []
        with open(path, 'rb') as f:
            # The leading int (max token length) is not needed here; skip it.
            f.read(4)
            for _ in range(VOCAB):
                score = struct.unpack('f', f.read(4))[0]
                slen = struct.unpack('i', f.read(4))[0]
                tok = f.read(slen).decode('utf-8', errors='replace')
                self.vocab.append(tok)
                self.scores.append(score)

    def decode(self, token_id):
        """Return the text for *token_id*, or ``''`` if the id is out of range.

        Tokens spelled ``<0xNN>`` are turned into the character with that code
        point; if the hex part is malformed the raw token text is returned.
        """
        if not 0 <= token_id < len(self.vocab):
            return ''
        s = self.vocab[token_id]
        if s.startswith('<0x') and s.endswith('>'):
            try:
                return chr(int(s[3:-1], 16))
            except (ValueError, OverflowError):
                # Not valid hex / not a valid code point: show the token verbatim.
                return s
        return s
|
||||
|
||||
_tokenizer = None


def get_tokenizer():
    """Return the shared Tokenizer, loading it from TOKENIZER_PATH on first use.

    On failure the error is appended to the dashboard log and ``None`` is
    returned; nothing is cached, so the next call tries again.
    """
    global _tokenizer
    if _tokenizer is not None:
        return _tokenizer
    try:
        _tokenizer = Tokenizer(TOKENIZER_PATH)
    except Exception as e:
        S.logs.append(f'[gen] tokenizer load failed: {e}')
        return None
    return _tokenizer
|
||||
|
||||
|
||||
def load_weights_from_ckpt(path):
    """Read the float32 model weights out of the checkpoint at *path*.

    Layout consumed here: a 96-byte header, then for each layer the tensors
    Wq, Wk, Wv, Wo, W1, W2, W3, rms1, rms2 followed by that layer's Adam state
    (skipped), then rms_final, its Adam state (skipped) and the embedding.

    Returns a dict of name -> array, or ``None`` if the header is short or any
    error occurs (the error is appended to the dashboard log).
    """

    def read_f32(fh, *shape):
        # Read prod(shape) float32 values; only multi-dimensional tensors are reshaped.
        count = 1
        for extent in shape:
            count *= extent
        flat = np.frombuffer(fh.read(count * 4), dtype=np.float32)
        if len(shape) > 1:
            flat = flat.reshape(*shape)
        return flat.copy()

    try:
        with open(path, 'rb') as f:
            # CkptHdr: 96 bytes (verified with sizeof)
            if len(f.read(96)) < 96:
                return None

            # Adam keeps two buffers (m, v) for every per-layer weight element.
            layer_elems = 4 * DIM * DIM + 3 * HIDDEN * DIM + 2 * DIM
            adam_per_layer = 2 * layer_elems

            W = {}
            for L in range(NLAYERS):
                W[f'Wq{L}'] = read_f32(f, DIM, DIM)
                W[f'Wk{L}'] = read_f32(f, DIM, DIM)
                W[f'Wv{L}'] = read_f32(f, DIM, DIM)
                W[f'Wo{L}'] = read_f32(f, DIM, DIM)
                W[f'W1_{L}'] = read_f32(f, HIDDEN, DIM)
                W[f'W2_{L}'] = read_f32(f, DIM, HIDDEN)
                W[f'W3_{L}'] = read_f32(f, HIDDEN, DIM)
                W[f'rms1_{L}'] = read_f32(f, DIM)
                W[f'rms2_{L}'] = read_f32(f, DIM)
                # Jump over this layer's Adam state.
                f.seek(adam_per_layer * 4, os.SEEK_CUR)

            W['rms_final'] = read_f32(f, DIM)
            f.seek(DIM * 2 * 4, os.SEEK_CUR)  # Adam state of rms_final
            W['embed'] = read_f32(f, VOCAB, DIM)
            return W
    except Exception as e:
        S.logs.append(f'[gen] ckpt load failed: {e}')
        return None
|
||||
|
||||
|
||||
def rmsnorm(x, w):
    """Scale *x* by the inverse of its root-mean-square (eps 1e-5), then by gain *w*."""
    mean_sq = np.mean(x * x) + 1e-5
    inv_rms = 1.0 / math.sqrt(mean_sq)
    return x * inv_rms * w
|
||||
|
||||
def softmax(x):
    """Return the softmax of *x*, shifting by the maximum first for numerical stability."""
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    return weights / np.sum(weights)
|
||||
|
||||
def generate_text(W, tok, max_tokens=64, temperature=0.8):
    """Sample up to *max_tokens* tokens from the checkpoint weights on the CPU.

    Args:
        W: dict of weight arrays as produced by ``load_weights_from_ckpt``.
        tok: tokenizer to decode with; if ``None`` the shared one is fetched
            via ``get_tokenizer()``.
        max_tokens: maximum number of tokens to sample.
        temperature: softmax temperature; below 0.01 decoding is greedy.

    Returns:
        The decoded text, or ``'[no tokenizer]'`` if no tokenizer is available.

    Generation starts from token id 1 and stops early at token id 2 or when
    the context reaches SEQ positions.  Keys and values of every processed
    position are cached per layer so that each new token attends causally to
    the whole prefix (previously the attention score was computed and thrown
    away, which made every step depend on the last token only).
    """
    tokenizer = tok if tok is not None else get_tokenizer()
    if tokenizer is None:
        return '[no tokenizer]'

    half = HD // 2
    # RoPE rotation angles for every (position, frequency) pair.
    inv_freq = (1.0 / (10000.0 ** (2.0 * np.arange(half) / HD))).astype(np.float32)
    angles = np.outer(np.arange(SEQ, dtype=np.float32), inv_freq)
    cos_tab = np.cos(angles)
    sin_tab = np.sin(angles)

    def rope(vec, pos):
        # Rotate consecutive (even, odd) element pairs of every head; returns (HEADS, HD).
        pairs = vec.reshape(HEADS, half, 2)
        even, odd = pairs[..., 0], pairs[..., 1]
        rotated = np.empty_like(pairs)
        rotated[..., 0] = even * cos_tab[pos] - odd * sin_tab[pos]
        rotated[..., 1] = even * sin_tab[pos] + odd * cos_tab[pos]
        return rotated.reshape(HEADS, HD)

    # Only as many cache slots as positions that can actually be visited.
    ctx = max(0, min(SEQ, max_tokens))
    k_cache = np.zeros((NLAYERS, ctx, HEADS, HD), dtype=np.float32)
    v_cache = np.zeros((NLAYERS, ctx, HEADS, HD), dtype=np.float32)
    attn_scale = 1.0 / math.sqrt(HD)

    tokens = [1]
    text_parts = []

    for _ in range(max_tokens):
        pos = len(tokens) - 1
        if pos >= SEQ:
            break

        x = W['embed'][tokens[-1]].copy()

        for L in range(NLAYERS):
            # RMSNorm + QKV projections for the current position
            xn = rmsnorm(x, W[f'rms1_{L}'])
            q = rope(W[f'Wq{L}'] @ xn, pos)
            k_cache[L, pos] = rope(W[f'Wk{L}'] @ xn, pos)
            v_cache[L, pos] = (W[f'Wv{L}'] @ xn).reshape(HEADS, HD)

            # Causal attention over positions 0..pos, independently per head
            keys = k_cache[L, :pos + 1]
            vals = v_cache[L, :pos + 1]
            scores = np.einsum('hd,thd->ht', q, keys) * attn_scale
            scores -= scores.max(axis=1, keepdims=True)
            attn = np.exp(scores)
            attn /= attn.sum(axis=1, keepdims=True)
            o = np.einsum('ht,thd->hd', attn, vals).reshape(DIM)

            # Residual + output projection
            x2 = x + W[f'Wo{L}'] @ o

            # FFN: W2 @ (SiLU(W1 @ x) * (W3 @ x))
            x2n = rmsnorm(x2, W[f'rms2_{L}'])
            h1 = W[f'W1_{L}'] @ x2n
            h3 = W[f'W3_{L}'] @ x2n
            h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3
            x = x2 + W[f'W2_{L}'] @ h1

        x = rmsnorm(x, W['rms_final'])

        # Logits: the embedding matrix doubles as the classifier
        logits = W['embed'] @ x

        if temperature < 0.01:
            next_tok = int(np.argmax(logits))
        else:
            # float64 keeps the probabilities summing to 1 within np.random.choice's tolerance
            probs = softmax(logits.astype(np.float64) / temperature)
            next_tok = int(np.random.choice(VOCAB, p=probs))

        if next_tok == 2:
            break
        tokens.append(next_tok)
        text_parts.append(tokenizer.decode(next_tok))

    return ''.join(text_parts)
|
||||
|
||||
|
||||
def generation_thread():
    """Background loop that periodically samples text from the latest checkpoint.

    Every 5 seconds it checks whether training has advanced at least 100 steps
    past the last generation and whether the checkpoint file exists; if so it
    loads the weights, generates 64 tokens and publishes the result (or the
    error text) through the ``gen_*`` fields of ``S`` under ``S.gen_lock``.
    Never returns.
    """
    last_gen_step = -1
    while True:
        time.sleep(5)
        due = S.step > last_gen_step + 99
        if not (due and os.path.exists(CKPT_PATH)):
            continue

        with S.gen_lock:
            S.gen_status = 'generating'
            S.gen_step = S.step

        try:
            weights = load_weights_from_ckpt(CKPT_PATH)
            if weights is None:
                # Checkpoint unreadable right now; retry on the next wake-up.
                with S.gen_lock:
                    S.gen_status = 'idle'
                continue
            sample = generate_text(weights, get_tokenizer(), max_tokens=64, temperature=0.8)
            with S.gen_lock:
                S.gen_text = sample
                S.gen_step = S.step
                S.gen_status = 'done'
        except Exception as e:
            with S.gen_lock:
                S.gen_text = f'[error: {e}]'
                S.gen_status = 'done'

        last_gen_step = S.step
|
||||
|
||||
|
||||
def _process_tree_rss_mb(pid):
    """Return the combined resident set size, in MiB, of *pid* and all its descendants."""
    root = psutil.Process(pid)
    total = root.memory_info().rss
    for child in root.children(recursive=True):
        try:
            total += child.memory_info().rss
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Child went away (or became unreadable) between listing and sampling.
            continue
    return total / (1024 * 1024)


def sysmetrics_thread():
    """Background loop sampling CPU %, system memory and trainer memory once a second.

    Samples are appended to the bounded histories on ``S``.  ``S.train_pid``
    is the pid of the ``bash -c`` wrapper started by ``spawn_training``, so
    the trainer's memory is measured over that process *and its descendants*;
    sampling the wrapper alone would only chart the shell's own footprint.
    Returns immediately when psutil is not installed; otherwise never returns.
    """
    if not HAS_PSUTIL:
        return
    while True:
        time.sleep(1)
        S.cpu_pct_history.append(psutil.cpu_percent(interval=None))
        S.mem_mb_history.append(psutil.virtual_memory().used / (1024 * 1024))
        pid = S.train_pid
        if not pid:
            continue
        try:
            S.proc_mem_mb_history.append(_process_tree_rss_mb(pid))
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # Trainer already exited or is not inspectable; skip this sample.
            pass
|
||||
|
||||
|
||||
# --- Patterns for the trainer's stdout -------------------------------------

# Startup banner: model shape, parameter counts, kernel counts, optimizer setup.
RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)')
RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)')
RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing')
RE_ACCUM = re.compile(r'Accum (\d+).*LR=([\d.e+-]+)')

# Per-step and per-batch progress lines.
RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)')
RE_BATCH = re.compile(r'\[batch (\d+): compile=([\d.]+)ms train=([\d.]+)ms \(([\d.]+)ms/step\) compiles=(\d+)\]')
RE_TIMING = re.compile(r'ane=([\d.]+) io=([\d.]+) cls=([\d.]+) elem=([\d.]+) rms=([\d.]+) cblas_wait=([\d.]+)')

# Restart / resume markers.
RE_RESTART = re.compile(r'\[exec\(\) restart step (\d+)')
RE_RESUME = re.compile(r'\[RESUMED step (\d+), loss=([\d.]+)\]')

# FLOP accounting and the end-of-run efficiency summary.
RE_FLOPS = re.compile(r'FLOPs/step: fwd=([\d.]+)M bwd_dx=([\d.]+)M bwd_dW=([\d.]+)M sdpa_bwd=([\d.]+)M total=([\d.]+)M')
RE_ANE_FLOPS = re.compile(r'ANE FLOPs/step: ([\d.]+)M')
RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)')
RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%')
RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)')

# --- Patterns for powermetrics output (values in mW) -----------------------
RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW')
RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW')
RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW')
|
||||
|
||||
def parse_line(line):
    """Record one line of trainer output and update ``S`` from the first pattern it matches.

    The line is always appended to ``S.logs``.  Patterns are tried in a fixed
    order and only the first match is applied.
    """
    S.logs.append(line)

    found = RE_CONFIG.search(line)
    if found:
        dim, hidden, heads, seq, vocab, layers = (int(g) for g in found.groups())
        S.model_config = {'dim': dim, 'hidden': hidden, 'heads': heads,
                          'seq': seq, 'vocab': vocab, 'layers': layers}
        return

    found = RE_PARAMS.search(line)
    if found:
        S.params = {
            'total': float(found.group(1)),
            'transformer': float(found.group(2)),
            'embed': float(found.group(3)),
        }
        return

    found = RE_KERNELS.search(line)
    if found:
        S.kernels = {'total': int(found.group(1)), 'weight_bearing': int(found.group(2))}
        return

    found = RE_ACCUM.search(line)
    if found:
        # The learning rate is kept as the text the trainer printed.
        S.training = {'accum': int(found.group(1)), 'lr': found.group(2)}
        return

    found = RE_FLOPS.search(line)
    if found:
        fwd, bwd_dx, bwd_dw, sdpa_bwd, total = (float(g) for g in found.groups())
        S.flops.update(fwd=fwd, bwd_dx=bwd_dx, bwd_dw=bwd_dw, sdpa_bwd=sdpa_bwd, total=total)
        return

    found = RE_ANE_FLOPS.search(line)
    if found:
        S.flops['ane'] = float(found.group(1))
        return

    found = RE_STEP.search(line)
    if found:
        S.step = int(found.group(1))
        S.loss = float(found.group(2))
        S.loss_history.append((S.step, S.loss))
        S.best_loss = min(S.best_loss, S.loss)
        return

    found = RE_BATCH.search(line)
    if found:
        S.batch_num = int(found.group(1))
        compile_ms = float(found.group(2))
        train_ms = float(found.group(3))
        S.ms_per_step = float(found.group(4))
        S.compiles = int(found.group(5))
        elapsed = compile_ms + train_ms
        # Share of the batch's wall time spent compiling kernels.
        S.compile_pct = 100 * compile_ms / elapsed if elapsed > 0 else 0
        return

    found = RE_TIMING.search(line)
    if found:
        ane, io, cls, elem, rms, cblas_wait = (float(g) for g in found.groups())
        S.component_timing = {'ane': ane, 'io': io, 'cls': cls,
                              'elem': elem, 'rms': rms, 'cblas_wait': cblas_wait}
        return

    found = RE_ANE_TFLOPS.search(line)
    if found:
        S.flops['ane_tflops'] = float(found.group(1))
        return

    found = RE_ANE_UTIL.search(line)
    if found:
        S.flops['ane_util'] = float(found.group(1))
        return

    found = RE_EFFICIENCY.search(line)
    if found:
        S.efficiency[found.group(1).strip()] = found.group(2).strip()
        return
|
||||
|
||||
|
||||
def parse_powermetrics_text(text):
    """Pull ANE/CPU/GPU power out of a powermetrics chunk and store it in watts.

    Readings are reported in mW and divided by 1000.  ANE and CPU readings are
    also appended, with a monotonic timestamp, to their history deques; GPU
    power has no history.
    """
    now = time.monotonic()
    sources = (
        ('ane', RE_ANE_POWER, S.power_history_ane),
        ('cpu', RE_CPU_POWER, S.power_history_cpu),
        ('gpu', RE_GPU_POWER, None),
    )
    for key, pattern, history in sources:
        hit = pattern.search(text)
        if not hit:
            continue
        S.power[key] = float(hit[1]) / 1000.0
        if history is not None:
            history.append((now, S.power[key]))
|
||||
|
||||
|
||||
BRAILLE_BASE = 0x2800

# Bit for each dot of a braille cell, indexed [row 0..3][column 0..1].
BRAILLE_MAP = [
    [1, 8],
    [2, 16],
    [4, 32],
    [64, 128],
]

# Width of the y-axis label column; every row must use it so the axis lines up.
_LABEL_W = 5


def braille_chart(values, width, height, label_fmt='{:.1f}', y_range=None):
    """Render *values* as a line chart made of braille characters.

    Args:
        values: sequence of numbers (list, deque, ...); only the most recent
            points that fit are drawn (two per character column).
        width: total width in characters, including a 5-character label
            column and the axis bar.
        height: number of chart rows (each holds 4 vertical dots).
        label_fmt: format string for the top, middle and bottom axis labels.
        y_range: optional ``(lo, hi)`` to fix the vertical scale; otherwise
            the scale is taken from the data.  A 5% margin is added either way.

    Returns:
        ``height + 1`` strings (the rows plus a bottom axis), all ``width``
        characters long, or ``max(1, height)`` copies of ``'(no data)'`` when
        there is nothing to draw or the area is too small.
    """
    if not values or width < 8 or height < 2:
        return ['(no data)'] * max(1, height)

    chart_w = width - (_LABEL_W + 1)
    points_x = chart_w * 2
    points_y = height * 4
    data = list(values)[-points_x:]

    lo, hi = min(data), max(data)
    if y_range:
        lo, hi = y_range
    if hi - lo < 0.001:
        # Flat series: open the range up so the line sits mid-chart.
        lo, hi = lo - 0.5, hi + 0.5
    margin = (hi - lo) * 0.05
    lo -= margin
    hi += margin

    grid = [[0] * chart_w for _ in range(height)]

    def plot(px, py):
        # Clamp so out-of-range values stick to the chart edge instead of raising.
        px = max(0, min(points_x - 1, px))
        py = max(0, min(points_y - 1, py))
        grid[py // 4][px // 2] |= BRAILLE_MAP[py % 4][px % 2]

    def val_to_y(v):
        return int((1 - (v - lo) / (hi - lo)) * (points_y - 1))

    y_prev = None
    for i, value in enumerate(data):
        y0 = val_to_y(value)
        plot(i, y0)
        if y_prev is not None:
            # Fill the vertical gap to the previous point so the line is continuous.
            y_lo, y_hi = min(y_prev, y0), max(y_prev, y0)
            for yy in range(y_lo, y_hi + 1):
                if y_hi != y_lo:
                    t = (yy - y_prev) / (y0 - y_prev)
                    xx = int(i - 1 + t)
                else:
                    xx = i
                plot(xx, yy)
        y_prev = y0

    labelled_rows = {0: hi, height // 2: (hi + lo) / 2, height - 1: lo}
    # Same precedence as top > bottom > middle when rows coincide.
    labelled_rows[height - 1] = lo
    labelled_rows[0] = hi

    lines = []
    for r in range(height):
        if r in labelled_rows:
            label = label_fmt.format(labelled_rows[r])[:_LABEL_W].rjust(_LABEL_W)
        else:
            label = ' ' * _LABEL_W
        row_str = ''.join(chr(BRAILLE_BASE | cell) for cell in grid[r])
        lines.append(f'{label}\u2502{row_str}')

    lines.append(' ' * _LABEL_W + '\u2514' + '\u2500' * chart_w)
    return lines
|
||||
|
||||
|
||||
def draw(term):
    """Render one complete frame of the dashboard from the shared state ``S``.

    Args:
        term: a ``blessed.Terminal``; its ``width``/``height`` decide the layout.

    Panels are laid out top to bottom: model config, loss curve next to the
    training stats, optional power charts, optional CPU/memory charts,
    optional generated text, and the log pane, which takes whatever rows are
    left above the bottom border.  The frame is assembled in a buffer and
    written to stdout in a single call.
    """
    w, h = term.width, term.height
    if w < 40 or h < 15:
        print(term.home + term.clear + 'Terminal too small', end='', flush=True)
        return

    buf = []
    vbar = '\u2502'

    def put(y, x, text, style=''):
        # Queue `text` at (y, x), clipped to the right edge; off-screen rows are dropped.
        if 0 <= y < h and x < w:
            text = text[:w - x]
            if style:
                buf.append(term.move(y, x) + style + text + term.normal)
                return
            buf.append(term.move(y, x) + text)

    def boxed(y, text, style=''):
        # One full-width row: left border, text starting at column 2, right border.
        put(y, 0, vbar, term.cyan)
        put(y, 2, text, style)
        put(y, w - 1, vbar, term.cyan)

    buf.append(term.home + term.clear)

    mid_x = w // 2
    right_w = w - mid_x - 1
    left_w = mid_x - 1

    def side_by_side(top, left_lines, right_lines, left_style, right_style):
        # Draw two chart panels next to each other starting at row `top`; returns rows used.
        rows = max(len(left_lines), len(right_lines))
        left_lines = left_lines + [' ' * (left_w - 1)] * (rows - len(left_lines))
        right_lines = right_lines + [' ' * (right_w - 1)] * (rows - len(right_lines))
        for i in range(rows):
            put(top + i, 0, vbar, term.cyan)
            put(top + i, 1, left_lines[i], left_style)
            put(top + i, mid_x, vbar, term.cyan)
            put(top + i, mid_x + 1, right_lines[i], right_style)
            put(top + i, w - 1, vbar, term.cyan)
        return rows

    row = 0

    # Model Config header
    hdr = '\u2500 Model Config '
    put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan)
    row += 1

    cfg = S.model_config
    if cfg:
        line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}"
        boxed(row, line1)
        row += 1
        p, k, t = S.params, S.kernels, S.training
        line2 = f"{p.get('total', '?')}M params ({p.get('transformer', '?')}M xfmr + {p.get('embed', '?')}M embed)"
        boxed(row, line2)
        row += 1
        line3 = f"{k.get('total', '?')} kernels ({k.get('weight_bearing', '?')} wt-bearing) | Accum {t.get('accum', '?')} | Adam LR={t.get('lr', '?')}"
        boxed(row, line3)
        row += 1
    else:
        boxed(row, 'Waiting for model config...')
        row += 1

    remaining = h - row - 1
    # Allocate: loss curve ~40%, logs ~30%, power/cpu/mem/gen share rest
    power_h = max(3, remaining // 8)
    gen_h = max(2, remaining // 10)
    extra_panels = power_h + power_h + gen_h + 6  # power + cpu/mem + gen + dividers
    log_h_min = max(5, remaining // 5)
    curve_h = max(5, remaining - extra_panels - log_h_min)

    # Loss Curve + Training Stats divider
    put(row, 0, '\u251c\u2500 Loss Curve ' + '\u2500' * max(0, left_w - 13) + '\u252c\u2500 Training Stats ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan)
    row += 1

    # Loss curve (left panel)
    loss_vals = [l for _, l in S.loss_history]
    curve_lines = braille_chart(loss_vals, left_w - 1, curve_h)
    for i, cl in enumerate(curve_lines):
        put(row + i, 0, vbar, term.cyan)
        put(row + i, 1, cl, term.green)
        put(row + i, mid_x, vbar, term.cyan)
        put(row + i, w - 1, vbar, term.cyan)

    # Training stats (right panel)
    sr = row
    step_str = f'{S.step}' + (f'/{S.total_steps}' if S.total_steps and S.total_steps < 999999 else '')
    put(sr, mid_x + 1, f' Step: {step_str} Loss: {S.loss:.4f}' if S.loss else ' Step: --', term.yellow)
    sr += 1
    put(sr, mid_x + 1, f' Best: {S.best_loss:.4f} ms/step: {S.ms_per_step:.1f}' if S.best_loss < float('inf') else ' Best: --')
    sr += 1
    ane_tflops = S.flops.get('ane_tflops', 0)
    ane_util = S.flops.get('ane_util', 0)
    if ane_tflops:
        put(sr, mid_x + 1, f' ANE: {ane_tflops:.2f}T Compile: {S.compile_pct:.0f}% Util: {ane_util:.1f}%')
    else:
        put(sr, mid_x + 1, f' Compile: {S.compile_pct:.0f}%')
    sr += 1
    ct = S.component_timing
    if ct:
        put(sr, mid_x + 1, f' ane={ct.get("ane", 0):.1f} io={ct.get("io", 0):.1f} cls={ct.get("cls", 0):.1f} elem={ct.get("elem", 0):.1f}')
        sr += 1
        put(sr, mid_x + 1, f' rms={ct.get("rms", 0):.1f} cblas_wait={ct.get("cblas_wait", 0):.1f} ms/step')
        sr += 1
    pw = S.power
    if any(pw.values()):
        put(sr, mid_x + 1, '\u2500 Power ' + '\u2500' * max(0, right_w - 9), term.cyan)
        sr += 1
        put(sr, mid_x + 1, f' ANE: {pw["ane"]:.1f}W CPU: {pw["cpu"]:.1f}W GPU: {pw["gpu"]:.1f}W', term.magenta)
        sr += 1
    if S.batch_num:
        put(sr, mid_x + 1, f' Batch: {S.batch_num} Compiles: {S.compiles}')
        sr += 1

    # Fill vertical borders for rows where one panel is taller than the other.
    top_end = row + len(curve_lines)
    for r in range(row, max(top_end, sr)):
        if r >= top_end:
            # Below the curve: the curve loop drew neither the left nor the middle bar here.
            put(r, 0, vbar, term.cyan)
            put(r, mid_x, vbar, term.cyan)
        put(r, w - 1, vbar, term.cyan)
    row = max(top_end, sr)

    # Power charts
    has_power = len(S.power_history_ane) > 1 or len(S.power_history_cpu) > 1
    if has_power:
        put(row, 0, '\u251c\u2500 ANE Power (W) ' + '\u2500' * max(0, left_w - 16) + '\u252c\u2500 CPU Power (W) ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan)
        row += 1
        ane_vals = [v for _, v in S.power_history_ane]
        cpu_vals = [v for _, v in S.power_history_cpu]
        ane_lines = braille_chart(ane_vals, left_w - 1, power_h, label_fmt='{:.1f}')
        cpu_lines = braille_chart(cpu_vals, right_w - 1, power_h, label_fmt='{:.1f}')
        row += side_by_side(row, ane_lines, cpu_lines, term.red, term.blue)

    # CPU / Memory charts
    has_sysmetrics = len(S.cpu_pct_history) > 0
    if has_sysmetrics:
        put(row, 0, '\u251c\u2500 CPU % ' + '\u2500' * max(0, left_w - 8) + '\u252c\u2500 Memory (MB) ' + '\u2500' * max(0, right_w - 15) + '\u2524', term.cyan)
        row += 1
        cpu_vals = list(S.cpu_pct_history)
        # Prefer the trainer's own memory; fall back to system-wide usage.
        mem_vals = list(S.proc_mem_mb_history) if S.proc_mem_mb_history else list(S.mem_mb_history)
        cpu_lines = braille_chart(cpu_vals, left_w - 1, power_h, label_fmt='{:.0f}', y_range=(0, 100))
        mem_lines = braille_chart(mem_vals, right_w - 1, power_h, label_fmt='{:.0f}')
        row += side_by_side(row, cpu_lines, mem_lines, term.yellow, term.magenta)

    # Generated text (snapshot under the lock, draw outside it)
    with S.gen_lock:
        gen_text = S.gen_text
        gen_step = S.gen_step
        gen_status = S.gen_status
    if gen_text or gen_status == 'generating':
        status_tag = ' (generating...)' if gen_status == 'generating' else f' (step {gen_step})'
        put(row, 0, '\u251c\u2500 Generated Text' + status_tag + ' ' + '\u2500' * max(0, w - 20 - len(status_tag)) + '\u2524', term.cyan)
        row += 1
        if gen_text:
            line_w = w - 3
            text = gen_text.replace('\n', ' ')
            wrapped = [text[i:i + line_w] for i in range(0, len(text), line_w)]
            for tl in wrapped[:gen_h]:
                boxed(row, tl, term.white)
                row += 1
        else:
            boxed(row, '...')
            row += 1

    # Logs
    scroll_hint = ' (scroll) ' if not S.auto_scroll else ' '
    put(row, 0, '\u251c\u2500 Logs' + scroll_hint + '\u2500' * max(0, w - 8 - len(scroll_hint)) + '\u2524', term.cyan)
    row += 1
    # Rows row..h-2 are available; row h-1 belongs to the bottom border.  Sizing
    # the pane only after the header row is consumed keeps the newest log line
    # from landing on h-1 and being overwritten by the border.
    log_h = h - row - 1

    logs = list(S.logs)
    visible = []
    if log_h > 0 and logs:
        if S.auto_scroll:
            start = max(0, len(logs) - log_h)
        else:
            start = max(0, min(S.log_scroll, len(logs) - log_h))
        visible = logs[start:start + log_h]
        for i, line in enumerate(visible):
            put(row + i, 0, vbar, term.cyan)
            if RE_STEP.search(line):
                put(row + i, 1, line[:w - 2], term.yellow)
            elif line.strip().startswith('[batch'):
                put(row + i, 1, line[:w - 2], term.blue)
            elif 'FAIL' in line or 'error' in line.lower():
                put(row + i, 1, line[:w - 2], term.red)
            else:
                put(row + i, 1, line[:w - 2])
            put(row + i, w - 1, vbar, term.cyan)
    # Keep the side borders on the unused part of the pane (also when it is empty).
    for i in range(len(visible), max(0, log_h)):
        put(row + i, 0, vbar, term.cyan)
        put(row + i, w - 1, vbar, term.cyan)

    # Bottom border
    put(h - 1, 0, '\u2514' + '\u2500' * (w - 2) + '\u2518', term.cyan)

    sys.stdout.write(''.join(buf))
    sys.stdout.flush()
|
||||
|
||||
|
||||
def set_nonblock(fd):
    """Add O_NONBLOCK to the status flags of file descriptor *fd*, keeping the others."""
    current_flags = fcntl.fcntl(fd, fcntl.F_GETFL)
    wanted_flags = current_flags | os.O_NONBLOCK
    fcntl.fcntl(fd, fcntl.F_SETFL, wanted_flags)
|
||||
|
||||
def spawn_training(resume=False, steps=10000):
    """Build ``train_large`` and start it, returning the ``Popen`` handle.

    The build and the run are chained in one ``bash -c`` command executed in
    this script's directory.  stderr is merged into stdout, and the stdout
    pipe is switched to non-blocking mode for the caller's select loop.

    Args:
        resume: append ``--resume`` to the trainer's command line.
        steps: value passed to the trainer's ``--steps`` option.
    """
    pieces = ['make train_large 2>&1 && ./train_large']
    if resume:
        pieces.append(' --resume')
    pieces.append(f' --steps {steps}')
    cmd = ''.join(pieces)

    script_dir = os.path.dirname(os.path.abspath(__file__)) or '.'
    proc = subprocess.Popen(
        ['bash', '-c', cmd],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        cwd=script_dir,
    )
    set_nonblock(proc.stdout.fileno())
    return proc
|
||||
|
||||
def spawn_powermetrics():
    """Start ``sudo powermetrics`` sampling CPU, GPU and ANE power every 1000 ms.

    Returns the ``Popen`` handle with a non-blocking stdout pipe (stderr is
    discarded), or ``None`` if the command cannot be launched.
    """
    argv = [
        'sudo', 'powermetrics',
        '--samplers', 'cpu_power,gpu_power,ane_power',
        '-i', '1000',
    ]
    try:
        proc = subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        set_nonblock(proc.stdout.fileno())
    except (FileNotFoundError, PermissionError):
        return None
    return proc
|
||||
|
||||
def main():
    """Run the training dashboard until the user quits.

    Parses the command line, spawns the training process (plus, optionally,
    powermetrics and the system-metrics / text-generation threads), then
    multiplexes trainer output, powermetrics output and keyboard input with
    ``select`` while redrawing the terminal UI.

    Keys: ``q`` quit, Up/Down scroll the log, ``p`` toggle auto-scroll,
    ``r`` restart the trainer with ``resume=True``, ``g`` generate text now.
    """
    parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)')
    parser.add_argument('--resume', action='store_true', help='Resume from checkpoint')
    parser.add_argument('--infinite', action='store_true', help='Train indefinitely')
    parser.add_argument('--no-powermetrics', action='store_true')
    parser.add_argument('--no-generate', action='store_true', help='Disable text generation')
    parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)')
    args = parser.parse_args()

    if args.infinite:
        args.steps = 999999999
    S.total_steps = args.steps

    term = Terminal()
    procs = []

    train_proc = spawn_training(resume=args.resume, steps=args.steps)
    S.train_pid = train_proc.pid
    procs.append(train_proc)

    # NOTE(review): indentation was reconstructed; assumes the sysmetrics
    # thread is only started when psutil is available -- confirm.
    if HAS_PSUTIL:
        psutil.cpu_percent(interval=None)  # prime the counter
        sys_t = threading.Thread(target=sysmetrics_thread, daemon=True)
        sys_t.start()

    pm_proc = None
    if not args.no_powermetrics:
        pm_proc = spawn_powermetrics()
        if pm_proc:
            procs.append(pm_proc)

    if not args.no_generate:
        gen_t = threading.Thread(target=generation_thread, daemon=True)
        gen_t.start()

    pm_buf = ''
    train_buf = ''

    def cleanup():
        # Best-effort shutdown: terminate() may fail (e.g. on a process we
        # are not allowed to signal) and that must not block quitting.
        for p in procs:
            try:
                p.terminate()
            except Exception:
                pass

    signal.signal(signal.SIGINT, lambda *a: cleanup())
    signal.signal(signal.SIGTERM, lambda *a: cleanup())

    # One-element list so the signal handler can flip the flag in place.
    resized = [False]

    def on_resize(*a):
        resized[0] = True

    signal.signal(signal.SIGWINCH, on_resize)

    with term.fullscreen(), term.cbreak(), term.hidden_cursor():
        draw(term)
        last_draw = time.monotonic()

        while True:
            # Rebuild the fd set every pass: 'r' can replace train_proc.
            fds = []
            fd_map = {}
            if train_proc and train_proc.stdout:
                fd = train_proc.stdout.fileno()
                fds.append(fd)
                fd_map[fd] = 'train'
            if pm_proc and pm_proc.stdout:
                fd = pm_proc.stdout.fileno()
                fds.append(fd)
                fd_map[fd] = 'pm'
            fds.append(sys.stdin.fileno())
            fd_map[sys.stdin.fileno()] = 'stdin'

            try:
                readable, _, _ = select.select(fds, [], [], 0.25)
            except (ValueError, OSError):
                continue

            need_draw = resized[0]
            resized[0] = False

            train_finished = False

            for fd in readable:
                kind = fd_map.get(fd)
                if kind == 'train':
                    try:
                        data = os.read(fd, 65536)
                    except BlockingIOError:
                        continue
                    except (OSError, ValueError):
                        data = b''
                    if not data:
                        # EOF: once the trainer has exited, drain what is
                        # left and switch to the "finished" state.
                        if train_proc.poll() is not None:
                            try:
                                rest = train_proc.stdout.read()
                                if rest:
                                    for line in rest.decode('utf-8', errors='replace').split('\n'):
                                        if line:
                                            parse_line(line)
                            except Exception:
                                pass
                            S.logs.append('[dashboard] Training finished. Press q to exit.')
                            train_finished = True
                        continue
                    train_buf += data.decode('utf-8', errors='replace')
                    while '\n' in train_buf:
                        line, train_buf = train_buf.split('\n', 1)
                        parse_line(line)
                    need_draw = True

                elif kind == 'pm':
                    try:
                        data = os.read(fd, 65536).decode('utf-8', errors='replace')
                    except BlockingIOError:
                        continue
                    except (OSError, ValueError):
                        data = ''
                    if not data:
                        continue
                    pm_buf += data
                    # Split on the next "*** " banner or blank line. The
                    # search starts at 1 so a separator at the very start of
                    # the buffer does not produce an empty chunk forever.
                    while '\n\n' in pm_buf or '*** ' in pm_buf:
                        end = pm_buf.find('\n*** ', 1)
                        if end < 0:
                            end = pm_buf.find('\n\n', 1)
                        if end < 0:
                            break
                        chunk = pm_buf[:end]
                        pm_buf = pm_buf[end:]
                        parse_powermetrics_text(chunk)
                    # Cap the buffer if no separator ever shows up.
                    if len(pm_buf) > 16384:
                        pm_buf = pm_buf[-8192:]
                    need_draw = True

                elif kind == 'stdin':
                    key = term.inkey(timeout=0)
                    if not key:
                        continue
                    if key == 'q':
                        cleanup()
                        return
                    elif key.name == 'KEY_UP':
                        S.auto_scroll = False
                        S.log_scroll = max(0, S.log_scroll - 1)
                        need_draw = True
                    elif key.name == 'KEY_DOWN':
                        S.log_scroll += 1
                        need_draw = True
                    elif key == 'p':
                        S.auto_scroll = not S.auto_scroll
                        if S.auto_scroll:
                            S.log_scroll = max(0, len(S.logs) - 10)
                        need_draw = True
                    elif key == 'r':
                        if train_proc:
                            train_proc.terminate()
                            train_proc.wait()
                        train_proc = spawn_training(resume=True, steps=args.steps)
                        S.train_pid = train_proc.pid
                        procs = [p for p in procs if p.poll() is None]
                        procs.append(train_proc)
                        # Drop any partial line from the old process so it is
                        # not glued onto the new trainer's first line.
                        train_buf = ''
                        S.logs.append('[dashboard] Restarted with --resume')
                        need_draw = True
                    elif key == 'g':
                        with S.gen_lock:
                            S.gen_status = 'generating'
                            S.gen_step = S.step

                        def force_gen():
                            try:
                                W = load_weights_from_ckpt(CKPT_PATH)
                                if W:
                                    text = generate_text(W, get_tokenizer(), max_tokens=64, temperature=0.8)
                                    with S.gen_lock:
                                        S.gen_text = text
                                        S.gen_step = S.step
                                        S.gen_status = 'done'
                                else:
                                    # No weights loaded: keep the old text but
                                    # leave the 'generating' state so the
                                    # panel does not hang forever.
                                    with S.gen_lock:
                                        S.gen_status = 'done'
                            except Exception as e:
                                with S.gen_lock:
                                    S.gen_text = f'[error: {e}]'
                                    S.gen_status = 'done'

                        threading.Thread(target=force_gen, daemon=True).start()
                        need_draw = True

            now = time.monotonic()
            # Redraw at least once a second, but never faster than ~15 fps.
            if not need_draw and now - last_draw > 1.0:
                need_draw = True
            if need_draw and now - last_draw > 0.066:
                draw(term)
                last_draw = now

            if train_finished:
                draw(term)
                # Training is over: keep the last frame up until 'q'.
                while True:
                    key = term.inkey(timeout=1)
                    if key == 'q':
                        cleanup()
                        return
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
// stories_config.h — Stories110M model config and structures
|
||||
#pragma once
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <objc/runtime.h>
|
||||
#import <objc/message.h>
|
||||
#import <dlfcn.h>
|
||||
#import <IOSurface/IOSurface.h>
|
||||
#import <mach/mach_time.h>
|
||||
#import <Accelerate/Accelerate.h>
|
||||
#include <math.h>
|
||||
#include <unistd.h>
|
||||
#include <dispatch/dispatch.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/stat.h>
|
||||
#include <fcntl.h>
|
||||
|
||||
// Stories110M config
|
||||
#define DIM 768
|
||||
#define HIDDEN 2048
|
||||
#define HEADS 12
|
||||
#define HD (DIM/HEADS)
|
||||
#define SEQ 256
|
||||
#define NLAYERS 12
|
||||
#define VOCAB 32000
|
||||
#define ACCUM_STEPS 10
|
||||
#define MAX_COMPILES 100
|
||||
|
||||
// Compile budget: each layer has 5 weight-bearing kernels
// (fwdAttn, fwdFFN, ffnBwd, sdpaBwd1, qkvBwd), so 5 * 12 = 60 weight-bearing
// compiles per batch. sdpaBwd2 has no weights and is compiled once per layer,
// so it is not counted here.
// With MAX_COMPILES=100, we get 1 batch of ACCUM_STEPS before restart.
|
||||
#define KERNELS_PER_LAYER 5
|
||||
#define TOTAL_WEIGHT_KERNELS (KERNELS_PER_LAYER * NLAYERS)
|
||||
|
||||
// Attention score channels for SDPA backward
|
||||
#define SCORE_CH (HEADS*SEQ)
|
||||
|
||||
// Weight sizes per layer
|
||||
#define WQ_SZ (DIM*DIM)
|
||||
#define WO_SZ (DIM*DIM)
|
||||
#define W1_SZ (HIDDEN*DIM)
|
||||
#define W2_SZ (DIM*HIDDEN)
|
||||
#define W3_SZ (HIDDEN*DIM)
|
||||
#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM)
|
||||
#define TOTAL_PARAMS (NLAYERS * LAYER_PARAMS + DIM + VOCAB*DIM) // +rms_final+embed
|
||||
|
||||
// Weight tensors of one transformer layer, each a separately allocated
// float buffer (see layer_weights_alloc for the sizes).
typedef struct {
    float *Wq;       // attention query projection, WQ_SZ floats
    float *Wk;       // attention key projection,   WQ_SZ floats
    float *Wv;       // attention value projection, WQ_SZ floats
    float *Wo;       // attention output projection, WO_SZ floats
    float *W1;       // FFN weight, W1_SZ floats
    float *W2;       // FFN weight, W2_SZ floats
    float *W3;       // FFN weight, W3_SZ floats
    float *rms_att;  // RMSNorm scale before attention, DIM floats
    float *rms_ffn;  // RMSNorm scale before the FFN, DIM floats
} LayerWeights;
|
||||
|
||||
// Adam optimizer state for one tensor of n elements.
typedef struct {
    float *m;   // first-moment estimate, n floats
    float *v;   // second-moment estimate, n floats
    size_t n;   // element count of m and v
} AdamState;
|
||||
|
||||
typedef struct {
|
||||
AdamState Wq, Wk, Wv, Wo;
|
||||
AdamState W1, W2, W3;
|
||||
AdamState rms_att, rms_ffn;
|
||||
} LayerAdam;
|
||||
|
||||
// Activations of one layer that are kept from the forward pass for use in
// the backward pass. All buffers are channel-first.
typedef struct {
    float *layer_in;  // [DIM, SEQ] input to this layer (for rmsnorm1 bwd)
    float *xnorm;     // [DIM, SEQ] rmsnorm1 output
    float *Q;         // [DIM, SEQ] query projection
    float *K;         // [DIM, SEQ] key projection
    float *V;         // [DIM, SEQ] value projection
    float *attn_out;  // [DIM, SEQ] attention output (before Wo)
    float *o_out;     // [DIM, SEQ] Wo output
    float *x2;        // [DIM, SEQ] residual after attn
    float *x2norm;    // [DIM, SEQ] rmsnorm2 output
    float *h1;        // [HIDDEN, SEQ] FFN intermediate
    float *h3;        // [HIDDEN, SEQ] FFN intermediate
    float *silu_out;  // [HIDDEN, SEQ] SiLU(h1)*h3
    float *ffn_out;   // [DIM, SEQ] FFN output
} LayerActs;
|
||||
|
||||
// Gradient accumulators for one layer; fields and sizes mirror LayerWeights.
typedef struct {
    float *Wq;
    float *Wk;
    float *Wv;
    float *Wo;
    float *W1;
    float *W2;
    float *W3;
    float *rms_att;
    float *rms_ffn;
} LayerGrads;
|
||||
|
||||
// One compiled ANE kernel together with its I/O surfaces. The void* members
// hold retained Objective-C objects (set up in compile_kern_mil_w, released
// in free_kern).
typedef struct {
    void *model;          // retained _ANEInMemoryModel
    IOSurfaceRef ioIn;    // input surface
    IOSurfaceRef ioOut;   // output surface
    void *request;        // retained _ANERequest binding ioIn/ioOut
    void *tmpDir;         // retained NSString: on-disk model directory
} Kern;
|
||||
// The six ANE kernels used by one layer.
typedef struct {
    Kern *fwdAttn;
    Kern *fwdFFN;
    Kern *ffnBwd;
    Kern *sdpaBwd1;
    Kern *sdpaBwd2;
    Kern *qkvBwd;
} LayerKernels;
|
||||
|
||||
// Fixed-size header of a checkpoint file.
typedef struct {
    int magic;        // 0x424C5A54 "BLZT"
    int version;      // 2
    int step;
    int total_steps;
    int n_layers;
    int vocab_size;
    int dim;
    int hidden_dim;
    int n_heads;
    int seq_len;
    float lr;
    float loss;
    double cum_compile;
    double cum_train;
    double cum_wall;
    int cum_steps;
    int cum_batches;
    int adam_t;
    int pad[3];       // alignment
} CkptHdr;
|
||||
|
||||
// Header of a llama2.c model file: seven consecutive ints.
typedef struct {
    int dim;
    int hidden_dim;
    int n_layers;
    int n_heads;
    int n_kv_heads;
    int vocab_size;
    int seq_len;
} Llama2Config;
|
||||
|
||||
// Globals: private ANE classes resolved by ane_init(), the mach timebase used
// by tb_ms(), and the number of kernels compiled so far.
static Class g_D;    // _ANEInMemoryModelDescriptor
static Class g_I;    // _ANEInMemoryModel
static Class g_AR;   // _ANERequest
static Class g_AIO;  // _ANEIOSurfaceObject
static mach_timebase_info_data_t g_tb;
static int g_compile_count = 0;

// Load the private AppleNeuralEngine framework and look up the classes the
// runtime needs. Failures are not reported here; a missing class simply
// leaves the corresponding global nil.
static void ane_init(void) {
    static const char *const kFrameworkPath =
        "/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine";
    (void)dlopen(kFrameworkPath, RTLD_NOW);

    g_D   = NSClassFromString(@"_ANEInMemoryModelDescriptor");
    g_I   = NSClassFromString(@"_ANEInMemoryModel");
    g_AR  = NSClassFromString(@"_ANERequest");
    g_AIO = NSClassFromString(@"_ANEIOSurfaceObject");
}
|
||||
// Convert a mach_absolute_time() tick count to milliseconds using g_tb.
// g_tb must have been filled in before the first call.
static double tb_ms(uint64_t t) {
    const double nanos = (double)t * g_tb.numer / g_tb.denom;
    return nanos / 1e6;
}
|
||||
|
||||
// Alloc helpers
|
||||
static AdamState adam_alloc(size_t n) { AdamState s; s.m=(float*)calloc(n,4); s.v=(float*)calloc(n,4); s.n=n; return s; }
|
||||
// Release the two moment buffers of s (the struct itself is not freed).
static void adam_free(AdamState *s) {
    free(s->m);
    free(s->v);
}
|
||||
|
||||
static LayerWeights layer_weights_alloc(void) {
|
||||
LayerWeights w;
|
||||
w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4);
|
||||
w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4);
|
||||
w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4);
|
||||
w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4);
|
||||
return w;
|
||||
}
|
||||
// Free every buffer owned by w (the struct itself is not freed).
static void layer_weights_free(LayerWeights *w) {
    float *const bufs[] = {
        w->Wq, w->Wk, w->Wv, w->Wo,
        w->W1, w->W2, w->W3,
        w->rms_att, w->rms_ffn,
    };
    for (size_t i = 0; i < sizeof(bufs) / sizeof(bufs[0]); i++) {
        free(bufs[i]);
    }
}
|
||||
static LayerAdam layer_adam_alloc(void) {
|
||||
LayerAdam a;
|
||||
a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WQ_SZ); a.Wv=adam_alloc(WQ_SZ); a.Wo=adam_alloc(WO_SZ);
|
||||
a.W1=adam_alloc(W1_SZ); a.W2=adam_alloc(W2_SZ); a.W3=adam_alloc(W3_SZ);
|
||||
a.rms_att=adam_alloc(DIM); a.rms_ffn=adam_alloc(DIM);
|
||||
return a;
|
||||
}
|
||||
// Release the Adam state of every tensor in a.
static void layer_adam_free(LayerAdam *a) {
    AdamState *const states[] = {
        &a->Wq, &a->Wk, &a->Wv, &a->Wo,
        &a->W1, &a->W2, &a->W3,
        &a->rms_att, &a->rms_ffn,
    };
    for (size_t i = 0; i < sizeof(states) / sizeof(states[0]); i++) {
        adam_free(states[i]);
    }
}
|
||||
static LayerActs layer_acts_alloc(void) {
|
||||
LayerActs a;
|
||||
a.layer_in=(float*)malloc(SEQ*DIM*4);
|
||||
a.xnorm=(float*)malloc(SEQ*DIM*4); a.Q=(float*)malloc(SEQ*DIM*4);
|
||||
a.K=(float*)malloc(SEQ*DIM*4); a.V=(float*)malloc(SEQ*DIM*4);
|
||||
a.attn_out=(float*)malloc(SEQ*DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4);
|
||||
a.x2=(float*)malloc(SEQ*DIM*4); a.x2norm=(float*)malloc(SEQ*DIM*4);
|
||||
a.h1=(float*)malloc(SEQ*HIDDEN*4); a.h3=(float*)malloc(SEQ*HIDDEN*4);
|
||||
a.silu_out=(float*)malloc(SEQ*HIDDEN*4); a.ffn_out=(float*)malloc(SEQ*DIM*4);
|
||||
return a;
|
||||
}
|
||||
// Free every activation buffer owned by a.
static void layer_acts_free(LayerActs *a) {
    float *const bufs[] = {
        a->layer_in, a->xnorm, a->Q, a->K, a->V,
        a->attn_out, a->o_out, a->x2, a->x2norm,
        a->h1, a->h3, a->silu_out, a->ffn_out,
    };
    for (size_t i = 0; i < sizeof(bufs) / sizeof(bufs[0]); i++) {
        free(bufs[i]);
    }
}
|
||||
static LayerGrads layer_grads_alloc(void) {
|
||||
LayerGrads g;
|
||||
g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WQ_SZ,4);
|
||||
g.Wv=(float*)calloc(WQ_SZ,4); g.Wo=(float*)calloc(WO_SZ,4);
|
||||
g.W1=(float*)calloc(W1_SZ,4); g.W2=(float*)calloc(W2_SZ,4); g.W3=(float*)calloc(W3_SZ,4);
|
||||
g.rms_att=(float*)calloc(DIM,4); g.rms_ffn=(float*)calloc(DIM,4);
|
||||
return g;
|
||||
}
|
||||
// Reset every gradient accumulator of g to zero.
static void layer_grads_zero(LayerGrads *g) {
    memset(g->Wq, 0, WQ_SZ * sizeof(float));
    memset(g->Wk, 0, WQ_SZ * sizeof(float));
    memset(g->Wv, 0, WQ_SZ * sizeof(float));
    memset(g->Wo, 0, WO_SZ * sizeof(float));
    memset(g->W1, 0, W1_SZ * sizeof(float));
    memset(g->W2, 0, W2_SZ * sizeof(float));
    memset(g->W3, 0, W3_SZ * sizeof(float));
    memset(g->rms_att, 0, DIM * sizeof(float));
    memset(g->rms_ffn, 0, DIM * sizeof(float));
}
|
||||
// Free every gradient buffer owned by g.
static void layer_grads_free(LayerGrads *g) {
    float *const bufs[] = {
        g->Wq, g->Wk, g->Wv, g->Wo,
        g->W1, g->W2, g->W3,
        g->rms_att, g->rms_ffn,
    };
    for (size_t i = 0; i < sizeof(bufs) / sizeof(bufs[0]); i++) {
        free(bufs[i]);
    }
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax
|
||||
#pragma once
|
||||
#include "stories_config.h"
|
||||
|
||||
static float *g_rms_tmp = NULL;
|
||||
|
||||
static void rmsnorm(float *out, const float *x, const float *w, int d, int S) {
|
||||
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
|
||||
float *ss = (float*)calloc(S, sizeof(float));
|
||||
for (int i=0; i<d; i++) {
|
||||
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
||||
}
|
||||
float invd = 1.0f/d, eps=1e-5f;
|
||||
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
||||
int n = S; vvrsqrtf(ss, ss, &n);
|
||||
for (int i=0; i<d; i++) {
|
||||
vDSP_vmul(x+i*S, 1, ss, 1, out+i*S, 1, (vDSP_Length)S);
|
||||
vDSP_vsmul(out+i*S, 1, &w[i], out+i*S, 1, (vDSP_Length)S);
|
||||
}
|
||||
free(ss);
|
||||
}
|
||||
|
||||
static void rmsnorm_bwd(float *dx, float *dw, const float *dy, const float *x, const float *w, int d, int S) {
|
||||
if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4);
|
||||
float *ss = (float*)calloc(S, sizeof(float));
|
||||
for (int i=0; i<d; i++) {
|
||||
vDSP_vmul(x+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
vDSP_vadd(g_rms_tmp, 1, ss, 1, ss, 1, (vDSP_Length)S);
|
||||
}
|
||||
float invd = 1.0f/d, eps=1e-5f;
|
||||
vDSP_vsmsa(ss, 1, &invd, &eps, ss, 1, (vDSP_Length)S);
|
||||
float *rrms = (float*)malloc(S*4);
|
||||
int n = S; vvrsqrtf(rrms, ss, &n);
|
||||
float *dot = (float*)calloc(S, sizeof(float));
|
||||
for (int i=0; i<d; i++) {
|
||||
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
vDSP_vsma(g_rms_tmp, 1, &w[i], dot, 1, dot, 1, (vDSP_Length)S);
|
||||
}
|
||||
vDSP_vmul(rrms, 1, rrms, 1, ss, 1, (vDSP_Length)S);
|
||||
vDSP_vsmul(ss, 1, &invd, ss, 1, (vDSP_Length)S);
|
||||
vDSP_vmul(dot, 1, ss, 1, dot, 1, (vDSP_Length)S);
|
||||
for (int i=0; i<d; i++) {
|
||||
vDSP_vmul(x+i*S, 1, dot, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
vDSP_vsub(g_rms_tmp, 1, dy+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
vDSP_vsmul(g_rms_tmp, 1, &w[i], dx+i*S, 1, (vDSP_Length)S);
|
||||
vDSP_vmul(dy+i*S, 1, x+i*S, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
vDSP_vmul(g_rms_tmp, 1, rrms, 1, g_rms_tmp, 1, (vDSP_Length)S);
|
||||
float s; vDSP_sve(g_rms_tmp, 1, &s, (vDSP_Length)S);
|
||||
dw[i] += s;
|
||||
}
|
||||
free(ss); free(rrms); free(dot);
|
||||
}
|
||||
|
||||
// One Adam step over s->n elements: update the moment estimates in s from
// gradient g and apply the bias-corrected update to w in place.
// t is the step number used for bias correction.
static void adam_update(float *w, const float *g, AdamState *s, int t, float lr, float b1, float b2, float eps) {
    const float corr1 = 1.0f - powf(b1, t);
    const float corr2 = 1.0f - powf(b2, t);
    float *m = s->m;
    float *v = s->v;
    const size_t count = s->n;

    for (size_t k = 0; k < count; k++) {
        m[k] = b1*m[k] + (1-b1)*g[k];
        v[k] = b2*v[k] + (1-b2)*g[k]*g[k];
        const float m_hat = m[k]/corr1;
        const float v_hat = v[k]/corr2;
        w[k] -= lr * m_hat / (sqrtf(v_hat) + eps);
    }
}
|
||||
|
||||
// Cross-entropy loss + gradient for logits (column-major: [VOCAB, SEQ])
|
||||
// logits[v*SEQ+t] = logit for vocab v, position t
|
||||
// targets[t] = target token id for position t
|
||||
// Returns mean CE loss, writes dlogits = softmax(logits) - one_hot(targets)
|
||||
// Data is column-major [V, S], but we process per-column (stride=1 within col is v*S+t, stride between v's is S)
|
||||
// For vDSP: transpose to row-major scratch [S, V] to vectorize softmax per position
|
||||
static float cross_entropy_loss(float *dlogits, const float *logits, const uint16_t *targets, int V, int S) {
|
||||
// Work in transposed layout [S, V] where each row is one position's logits (contiguous)
|
||||
float *buf = (float*)malloc(S * V * 4);
|
||||
// Transpose [V,S] → [S,V]: buf[t*V+v] = logits[v*S+t]
|
||||
vDSP_mtrans(logits, 1, buf, 1, (vDSP_Length)S, (vDSP_Length)V);
|
||||
|
||||
float total_loss = 0;
|
||||
float invS = 1.0f / S;
|
||||
for (int t = 0; t < S; t++) {
|
||||
float *row = buf + t * V;
|
||||
// max
|
||||
float maxv;
|
||||
vDSP_maxv(row, 1, &maxv, (vDSP_Length)V);
|
||||
// row -= maxv
|
||||
float neg_max = -maxv;
|
||||
vDSP_vsadd(row, 1, &neg_max, row, 1, (vDSP_Length)V);
|
||||
// exp in-place
|
||||
int n = V;
|
||||
vvexpf(row, row, &n);
|
||||
// sum
|
||||
float sum;
|
||||
vDSP_sve(row, 1, &sum, (vDSP_Length)V);
|
||||
// normalize
|
||||
float inv_sum = 1.0f / sum;
|
||||
vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V);
|
||||
// loss
|
||||
int tgt = targets[t];
|
||||
total_loss -= logf(row[tgt] + 1e-10f);
|
||||
// gradient: softmax - one_hot, then /S
|
||||
row[tgt] -= 1.0f;
|
||||
vDSP_vsmul(row, 1, &invS, row, 1, (vDSP_Length)V);
|
||||
}
|
||||
// Transpose back [S,V] → [V,S]
|
||||
vDSP_mtrans(buf, 1, dlogits, 1, (vDSP_Length)V, (vDSP_Length)S);
|
||||
free(buf);
|
||||
return total_loss / S;
|
||||
}
|
||||
|
||||
// Embedding lookup into a channel-first activation:
//   x[d*seq + t] = embed[tokens[t]*dim + d]
// embed is row-major with one row of dim floats per token id.
static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) {
    for (int t = 0; t < seq; t++) {
        const float *row = embed + tokens[t]*dim;
        float *col = x + t;
        for (int d = 0; d < dim; d++) {
            col[d*seq] = row[d];
        }
    }
}
|
||||
|
||||
// Embedding backward: for every position t, add column t of dx
// (channel-first, dx[d*seq + t]) onto row tokens[t] of d_embed.
// d_embed is accumulated into, not overwritten.
static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) {
    for (int t = 0; t < seq; t++) {
        float *grad_row = d_embed + tokens[t]*dim;
        const float *col = dx + t;
        for (int d = 0; d < dim; d++) {
            grad_row[d] += col[d*seq];
        }
    }
}
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
// stories_io.h — IOSurface helpers, blob builders, NEON conversion
|
||||
#pragma once
|
||||
#include "stories_config.h"
|
||||
#include <arm_neon.h>
|
||||
|
||||
// Create a flat, one-row IOSurface of `bytes` bytes (1 byte per element,
// pixel format 0). The caller owns the returned surface.
static IOSurfaceRef make_surface(size_t bytes) {
    NSNumber *size = @(bytes);
    NSDictionary *props = @{
        (id)kIOSurfaceWidth:           size,
        (id)kIOSurfaceHeight:          @1,
        (id)kIOSurfaceBytesPerElement: @1,
        (id)kIOSurfaceBytesPerRow:     size,
        (id)kIOSurfaceAllocSize:       size,
        (id)kIOSurfacePixelFormat:     @0,
    };
    return IOSurfaceCreate((__bridge CFDictionaryRef)props);
}
|
||||
|
||||
// Pack a rows x cols float matrix into a weight blob: a 128-byte header
// followed by the values converted to fp16 in the same order.
// The returned NSData owns the buffer.
static NSData *build_blob(const float *w, int rows, int cols) {
    const int count = rows*cols;
    const int ws = count*2;
    const int tot = 128+ws;
    uint8_t *buf = (uint8_t*)calloc(tot, 1);

    // File header.
    buf[0] = 1;
    buf[4] = 2;

    // Chunk header at byte 64: magic bytes EF BE AD DE, a flag byte,
    // payload size at +8 and payload offset at +16.
    uint8_t *chunk = buf + 64;
    chunk[0] = 0xEF; chunk[1] = 0xBE; chunk[2] = 0xAD; chunk[3] = 0xDE;
    chunk[4] = 1;
    *(uint32_t*)(chunk + 8) = ws;
    *(uint32_t*)(chunk + 16) = 128;

    _Float16 *payload = (_Float16*)(buf + 128);
    for (int k = 0; k < count; k++) {
        payload[k] = (_Float16)w[k];
    }
    return [NSData dataWithBytesNoCopy:buf length:tot freeWhenDone:YES];
}
|
||||
// Like build_blob, but stores the transpose: the rows x cols row-major input
// is written as a cols x rows fp16 matrix (payload[j*rows + i] = w[i*cols + j]).
// The returned NSData owns the buffer.
static NSData *build_blob_t(const float *w, int rows, int cols) {
    const int ws = cols*rows*2;
    const int tot = 128+ws;
    uint8_t *buf = (uint8_t*)calloc(tot, 1);

    // File header.
    buf[0] = 1;
    buf[4] = 2;

    // Chunk header at byte 64: magic bytes EF BE AD DE, a flag byte,
    // payload size at +8 and payload offset at +16.
    uint8_t *chunk = buf + 64;
    chunk[0] = 0xEF; chunk[1] = 0xBE; chunk[2] = 0xAD; chunk[3] = 0xDE;
    chunk[4] = 1;
    *(uint32_t*)(chunk + 8) = ws;
    *(uint32_t*)(chunk + 16) = 128;

    _Float16 *payload = (_Float16*)(buf + 128);
    for (int r = 0; r < rows; r++) {
        const float *src_row = w + r*cols;
        for (int c = 0; c < cols; c++) {
            payload[c*rows + r] = (_Float16)src_row[c];
        }
    }
    return [NSData dataWithBytesNoCopy:buf length:tot freeWhenDone:YES];
}
|
||||
// Wrap cnt already-converted fp16 values in a weight blob (128-byte header
// followed by a verbatim copy of d). The returned NSData owns the buffer.
static NSData *build_blob_fp16(_Float16 *d, int cnt) {
    const int ws = cnt*2;
    const int tot = 128+ws;
    uint8_t *buf = (uint8_t*)calloc(tot, 1);

    // File header.
    buf[0] = 1;
    buf[4] = 2;

    // Chunk header at byte 64: magic bytes EF BE AD DE, a flag byte,
    // payload size at +8 and payload offset at +16.
    uint8_t *chunk = buf + 64;
    chunk[0] = 0xEF; chunk[1] = 0xBE; chunk[2] = 0xAD; chunk[3] = 0xDE;
    chunk[4] = 1;
    *(uint32_t*)(chunk + 8) = ws;
    *(uint32_t*)(chunk + 16) = 128;

    memcpy(buf + 128, d, ws);
    return [NSData dataWithBytesNoCopy:buf length:tot freeWhenDone:YES];
}
|
||||
|
||||
// NEON vectorized conversion

// Convert n fp16 values to fp32. Full groups of 8 go through NEON, the
// remaining 0..7 elements through a scalar loop.
static void cvt_f16_f32(float *dst, const _Float16 *src, int n) {
    int pos = 0;
    while (n - pos >= 8) {
        const float16x8_t half8 = vld1q_f16((const __fp16*)(src + pos));
        const float32x4_t lo = vcvt_f32_f16(vget_low_f16(half8));
        const float32x4_t hi = vcvt_f32_f16(vget_high_f16(half8));
        vst1q_f32(dst + pos, lo);
        vst1q_f32(dst + pos + 4, hi);
        pos += 8;
    }
    while (pos < n) {
        dst[pos] = (float)src[pos];
        pos++;
    }
}
|
||||
// Convert n fp32 values to fp16. Full groups of 8 go through NEON, the
// remaining 0..7 elements through a scalar loop.
static void cvt_f32_f16(_Float16 *dst, const float *src, int n) {
    int pos = 0;
    while (n - pos >= 8) {
        const float16x4_t lo = vcvt_f16_f32(vld1q_f32(src + pos));
        const float16x4_t hi = vcvt_f16_f32(vld1q_f32(src + pos + 4));
        vst1q_f16((__fp16*)(dst + pos), vcombine_f16(lo, hi));
        pos += 8;
    }
    while (pos < n) {
        dst[pos] = (_Float16)src[pos];
        pos++;
    }
}
|
||||
|
||||
// IOSurface I/O (channel-first [C,S] layout)
|
||||
static void io_write_fp16(IOSurfaceRef s, const float *data, int channels, int sp) {
|
||||
IOSurfaceLock(s, 0, NULL);
|
||||
cvt_f32_f16((_Float16*)IOSurfaceGetBaseAddress(s), data, channels * sp);
|
||||
IOSurfaceUnlock(s, 0, NULL);
|
||||
}
|
||||
// Read channels*sp fp16 values from s, starting at channel ch_off, and
// convert them to floats in data. The surface is locked read-only.
static void io_read_fp16(IOSurfaceRef s, float *data, int ch_off, int channels, int sp) {
    const int count = channels * sp;
    IOSurfaceLock(s, kIOSurfaceLockReadOnly, NULL);
    const _Float16 *first = (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp;
    cvt_f16_f32(data, first, count);
    IOSurfaceUnlock(s, kIOSurfaceLockReadOnly, NULL);
}
|
||||
// Copy `channels` channels of sp fp16 values from src (starting at channel
// src_ch) into dst (starting at channel dst_ch) without any conversion.
static void io_copy(IOSurfaceRef dst, int dst_ch, IOSurfaceRef src, int src_ch, int channels, int sp) {
    const size_t nbytes = channels * sp * sizeof(_Float16);

    IOSurfaceLock(dst, 0, NULL);
    IOSurfaceLock(src, kIOSurfaceLockReadOnly, NULL);

    _Float16 *to = (_Float16*)IOSurfaceGetBaseAddress(dst) + dst_ch*sp;
    const _Float16 *from = (_Float16*)IOSurfaceGetBaseAddress(src) + src_ch*sp;
    memcpy(to, from, nbytes);

    // Unlock in reverse order of locking.
    IOSurfaceUnlock(src, kIOSurfaceLockReadOnly, NULL);
    IOSurfaceUnlock(dst, 0, NULL);
}
|
||||
// Convert channels*sp floats to fp16 and store them in s starting at
// channel ch_off.
static void io_write_fp16_at(IOSurfaceRef s, int ch_off, const float *data, int channels, int sp) {
    const int count = channels * sp;
    IOSurfaceLock(s, 0, NULL);
    _Float16 *first = (_Float16*)IOSurfaceGetBaseAddress(s) + ch_off * sp;
    cvt_f32_f16(first, data, count);
    IOSurfaceUnlock(s, 0, NULL);
}
|
||||
|
||||
// Kernel compile/eval

// Compile and load one ANE kernel from MIL text plus its weight blobs.
//
//   mil       MIL program text
//   weights   dictionary keyed by "@model_path/..." path; each value is a
//             dictionary whose @"data" entry is the blob to write
//   ic_bytes  size of the input IOSurface in bytes
//   oc_bytes  size of the output IOSurface in bytes
//
// The model files are written to a per-model directory under
// NSTemporaryDirectory(). On success returns a heap-allocated Kern that must
// be released with free_kern(); g_compile_count is incremented once the
// model has compiled and loaded. On any failure a message is printed, the
// temporary directory is removed again and NULL is returned.
static Kern *compile_kern_mil_w(NSString *mil, NSDictionary *weights, int ic_bytes, int oc_bytes) {
    @autoreleasepool {
        NSFileManager *fm = [NSFileManager defaultManager];
        NSData *md = [mil dataUsingEncoding:NSUTF8StringEncoding];
        id desc = ((id(*)(Class,SEL,id,id,id))objc_msgSend)(g_D, @selector(modelWithMILText:weights:optionsPlist:), md, weights, nil);
        if (!desc) { printf(" [compile] desc=NULL\n"); return NULL; }
        id mdl = ((id(*)(Class,SEL,id))objc_msgSend)(g_I, @selector(inMemoryModelWithDescriptor:), desc);
        id hx = mdl ? ((id(*)(id,SEL))objc_msgSend)(mdl, @selector(hexStringIdentifier)) : nil;
        // A nil identifier would make stringByAppendingPathComponent: raise.
        if (!mdl || !hx) { printf(" [compile] model=NULL\n"); return NULL; }

        // Lay the model out on disk the way the compiler expects it.
        NSString *td = [NSTemporaryDirectory() stringByAppendingPathComponent:hx];
        [fm createDirectoryAtPath:[td stringByAppendingPathComponent:@"weights"] withIntermediateDirectories:YES attributes:nil error:nil];
        [md writeToFile:[td stringByAppendingPathComponent:@"model.mil"] atomically:YES];
        for (NSString *path in weights) {
            NSString *rel = [path stringByReplacingOccurrencesOfString:@"@model_path/" withString:@""];
            [weights[path][@"data"] writeToFile:[td stringByAppendingPathComponent:rel] atomically:YES];
        }

        NSError *e = nil;
        if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(compileWithQoS:options:error:), 21, @{}, &e)) {
            printf(" [compile] FAIL: %s\n", e ? [[e description] UTF8String] : "no error");
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }
        if (!((BOOL(*)(id,SEL,unsigned int,id,NSError**))objc_msgSend)(mdl, @selector(loadWithQoS:options:error:), 21, @{}, &e)) {
            printf(" [compile] load FAIL\n");
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }
        __sync_fetch_and_add(&g_compile_count, 1);

        IOSurfaceRef ioIn = make_surface(ic_bytes);
        IOSurfaceRef ioOut = make_surface(oc_bytes);
        id wI = ioIn ? ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), ioIn) : nil;
        id wO = ioOut ? ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), ioOut) : nil;
        if (!wI || !wO) {
            // A nil element in the @[...] literals below would raise, so
            // undo the load and bail out cleanly instead.
            printf(" [compile] IOSurface FAIL\n");
            if (ioIn) CFRelease(ioIn);
            if (ioOut) CFRelease(ioOut);
            ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
            [fm removeItemAtPath:td error:nil];
            return NULL;
        }

        Kern *k = (Kern*)calloc(1, sizeof(Kern));
        k->model = (void*)CFBridgingRetain(mdl);
        k->ioIn = ioIn;
        k->ioOut = ioOut;
        k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR,
            @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:),
            @[wI], @[@0], @[wO], @[@0], nil, nil, @0));
        k->tmpDir = (void*)CFBridgingRetain(td);
        return k;
    }
}
|
||||
// Unload a kernel and release everything it owns: the model, both
// IOSurfaces, the request, the on-disk model directory and the Kern itself.
// Accepts NULL, and tolerates members that were never set, because
// CFRelease() must not be called with NULL.
static void free_kern(Kern *k) {
    if (!k) return;
    if (k->model) {
        id mdl = (__bridge id)k->model; NSError *e = nil;
        ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e);
    }
    if (k->ioIn) CFRelease(k->ioIn);
    if (k->ioOut) CFRelease(k->ioOut);
    if (k->tmpDir) {
        [[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil];
    }
    if (k->model) CFRelease(k->model);
    if (k->request) CFRelease(k->request);
    if (k->tmpDir) CFRelease(k->tmpDir);
    free(k);
}
|
||||
// Run kernel k once on its bound request (ioIn -> ioOut). The BOOL result
// and the NSError are not inspected.
static void ane_eval(Kern *k) {
    typedef BOOL (*EvalFn)(id, SEL, unsigned int, id, id, NSError **);
    NSError *err = nil;
    id model = (__bridge id)k->model;
    id request = (__bridge id)k->request;
    ((EvalFn)objc_msgSend)(model, @selector(evaluateWithQoS:options:request:error:), 21, @{}, request, &err);
}
|
||||
|
|
@ -0,0 +1,286 @@
|
|||
// stories_mil.h — MIL program generators for ANE kernels
|
||||
// Same architecture as single-layer train_large.m but parameterized
|
||||
#pragma once
|
||||
#include "stories_io.h"
|
||||
|
||||
#define MIL_HDR \
|
||||
@"program(1.3)\n[buildInfo = dict<string, string>({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \
|
||||
"{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \
|
||||
"{\"coremltools-version\", \"9.0\"}})]\n{\n"
|
||||
#define CONV_CONST \
|
||||
" string pt = const()[name=string(\"pt\"), val=string(\"valid\")];\n" \
|
||||
" tensor<int32, [2]> st = const()[name=string(\"st\"), val=tensor<int32, [2]>([1,1])];\n" \
|
||||
" tensor<int32, [4]> pd = const()[name=string(\"pd\"), val=tensor<int32, [4]>([0,0,0,0])];\n" \
|
||||
" tensor<int32, [2]> dl = const()[name=string(\"dl\"), val=tensor<int32, [2]>([1,1])];\n" \
|
||||
" int32 gr = const()[name=string(\"gr\"), val=int32(1)];\n"
|
||||
|
||||
// Build the MIL program for the attention forward kernel with "taps":
// x -> rmsnorm -> Q/K/V projections -> scaled, masked softmax attention -> Wo.
// The single output concatenates, along the channel axis,
// (o_out, Q, K, V, attn_out, xnorm), i.e. 6*DIM channels.
static NSString *gen_sdpa_fwd_taps(void) {
    const float attn_scale = 1.0f/sqrtf((float)HD);
    const float inv_dim = 1.0f/(float)DIM;
    static const char *const kWeights[4][2] = {
        {"Wq", "wq"}, {"Wk", "wk"}, {"Wv", "wv"}, {"Wo", "wo"},
    };
    static const char kQKV[3] = {'q', 'k', 'v'};

    NSMutableString *mil = [NSMutableString string];
    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, SEQ];

    // RMSNorm over the channel axis.
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([1])];\n"];
    [mil appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
    [mil appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", inv_dim];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
    [mil appendString:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
    [mil appendString:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,1]> rw = const()[name=string(\"rw\"), val=tensor<fp16, [1,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/rms1.bin\"), offset=uint64(64)))];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ];

    // Projection weights (Wq, Wk, Wv, Wo) and the Q/K/V 1x1 convolutions.
    [mil appendString:@CONV_CONST];
    for (int i = 0; i < 4; i++) {
        [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> %s = const()[name=string(\"%s\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/%s.bin\"), offset=uint64(64)))];\n", DIM,DIM, kWeights[i][0], kWeights[i][0], DIM,DIM, kWeights[i][1]];
    }
    for (int i = 0; i < 3; i++) {
        const char c = kQKV[i];
        [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> %cf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W%c,x=xn)[name=string(\"c%c\")];\n", DIM,SEQ, c, c, c];
    }

    // Split into heads: [1,DIM,1,SEQ] -> [1,HEADS,HD,SEQ] -> [1,HEADS,SEQ,HD].
    [mil appendFormat:@" tensor<int32, [4]> qsh = const()[name=string(\"qsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS,HD,SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    for (int i = 0; i < 3; i++) {
        const char c = kQKV[i];
        [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> %c4 = reshape(shape=qsh,x=%cf)[name=string(\"r%c\")];\n", HEADS,HD,SEQ, c, c, c];
        [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> %c = transpose(perm=pm,x=%c4)[name=string(\"t%c\")];\n", HEADS,SEQ,HD, c, c, c];
    }

    // Scores = softmax(scale * Q K^T + mask), then weighted sum of V.
    [mil appendString:@" bool tx = const()[name=string(\"tx\"), val=bool(false)];\n"];
    [mil appendString:@" bool ty = const()[name=string(\"ty\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", attn_scale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ];
    [mil appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS,SEQ,HD];

    // Merge heads back to [1,DIM,1,SEQ] and apply the output projection.
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS,HD,SEQ];
    [mil appendFormat:@" tensor<int32, [4]> os = const()[name=string(\"os\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM,SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af)[name=string(\"co\")];\n", DIM,SEQ];

    // Output plus taps, concatenated on the channel axis.
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM,SEQ];
    [mil appendString:@" } -> (out);\n}\n"];
    return mil;
}
|
||||
|
||||
// Builds the MIL program text for the FFN forward pass with activation taps.
//
// Input : x, fp16 [1, DIM, 1, SEQ].
// Steps : RMSNorm over axis 1 (sum of squares * 1/DIM, + 0.00001, pow -0.5,
//         then scaled by the weight read from weights/rms2.bin), followed by
//         y = W2 * (silu(W1*xn) * (W3*xn)) expressed as convs.
// Output: fp16 [1, 2*DIM+3*HIDDEN, 1, SEQ], the axis-1 concat of
//         (y, h1, h3, gate, xn) where gate = (h1 * sigmoid(h1)) * h3.
//
// NOTE(review): the conv attributes dl/gr/pd/pt/st are not declared in this
// function; presumably CONV_CONST emits them — confirm against its definition.
static NSString *gen_ffn_fwd_taps(void) {
    const float inv_dim = 1.0f/(float)DIM;
    NSMutableString *mil = [NSMutableString string];

    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM, SEQ];

    // RMSNorm: xn = x * (sum(x^2)/DIM + eps)^-0.5 * rw
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([1])];\n"];
    [mil appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ];
    [mil appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", inv_dim];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ];
    [mil appendString:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ];
    [mil appendString:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,1,1,%d]> rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,1]> rw = const()[name=string(\"rw\"), val=tensor<fp16, [1,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/rms2.bin\"), offset=uint64(64)))];\n", DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ];

    // Projection weights (W1/W3: [HIDDEN,DIM,1,1], W2: [DIM,HIDDEN,1,1]).
    [mil appendString:@CONV_CONST];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W1 = const()[name=string(\"W1\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w1.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W3 = const()[name=string(\"W3\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w3.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W2 = const()[name=string(\"W2\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w2.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];

    // Gated FFN: y = W2 * ((h1 * sigmoid(h1)) * h3)
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn)[name=string(\"c1\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn)[name=string(\"c3\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate)[name=string(\"c2\")];\n", DIM, SEQ];

    // Pack the result and the taps into a single output along axis 1.
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn))[name=string(\"cat\")];\n", 2*DIM+3*HIDDEN, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];

    return mil;
}
|
||||
|
||||
// Builds the MIL program text for the FFN backward pass.
//
// Input : fp16 [1, DIM+2*HIDDEN, 1, SEQ], sliced along axis 1 into
//         dffn (DIM channels), h1 (HIDDEN) and h3 (HIDDEN).
// Steps : dsilu = W2t * dffn
//         dh1   = dsilu * h3 * sig * (1 + h1 * (1 - sig)),  sig = sigmoid(h1)
//         dh3   = dsilu * (h1 * sig)
//         dx    = W1t * dh1 + W3t * dh3
// Output: fp16 [1, DIM+2*HIDDEN, 1, SEQ] = axis-1 concat of (dx, dh1, dh3).
//
// NOTE(review): the conv attributes dl/gr/pd/pt/st are not declared in this
// function; presumably CONV_CONST emits them — confirm against its definition.
static NSString *gen_ffn_bwd(void) {
    NSMutableString *mil = [NSMutableString string];

    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", DIM+2*HIDDEN, SEQ];
    [mil appendString:@CONV_CONST];

    // Unpack the three inputs from the packed tensor.
    [mil appendString:@" tensor<int32, [4]> bd = const()[name=string(\"bd\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<int32, [4]> sd = const()[name=string(\"sd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dffn = slice_by_size(x=x,begin=bd,size=sd)[name=string(\"s0\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [mil appendFormat:@" tensor<int32, [4]> s1 = const()[name=string(\"s1\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h1 = slice_by_size(x=x,begin=b1,size=s1)[name=string(\"s1x\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM+HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> h3 = slice_by_size(x=x,begin=b3,size=s1)[name=string(\"s3x\")];\n", HIDDEN, SEQ];

    // Back through W2: gradient w.r.t. the gated activation.
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W2t = const()[name=string(\"W2t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w2t.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn)[name=string(\"cw2\")];\n", HIDDEN, SEQ];

    // dsd = sig * (1 + h1 * (1 - sig)) is the derivative of h1 * sigmoid(h1).
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ];
    [mil appendString:@" fp16 one = const()[name=string(\"one\"), val=fp16(1.0)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> oms = sub(x=one,y=sig)[name=string(\"oms\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> homs = mul(x=h1,y=oms)[name=string(\"homs\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> brk = add(x=one,y=homs)[name=string(\"brk\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dsd = mul(x=sig,y=brk)[name=string(\"dsd\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> t1 = mul(x=dsilu,y=h3)[name=string(\"t1\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh1 = mul(x=t1,y=dsd)[name=string(\"dh1\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> slh = mul(x=h1,y=sig)[name=string(\"slh\")];\n", HIDDEN, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dh3 = mul(x=dsilu,y=slh)[name=string(\"dh3\")];\n", HIDDEN, SEQ];

    // Back through W1 and W3, summed into the input gradient.
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W1t = const()[name=string(\"W1t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w1t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> W3t = const()[name=string(\"W3t\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/w3t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1)[name=string(\"cw1\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3)[name=string(\"cw3\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx = add(x=dx1,y=dx3)[name=string(\"adx\")];\n", DIM, SEQ];

    // Pack the three gradients along axis 1.
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3))[name=string(\"cat\")];\n", DIM+2*HIDDEN, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];

    return mil;
}
|
||||
|
||||
// Builds the MIL program text for the QKV projection backward pass.
//
// Input : fp16 [1, 3*DIM, 1, SEQ], sliced along axis 1 into dq, dk, dv
//         (DIM channels each, in that order).
// Output: fp16 [1, DIM, 1, SEQ] = Wqt*dq + Wkt*dk + Wvt*dv, each product a
//         conv with a [DIM,DIM,1,1] weight read from wqt/wkt/wvt.bin.
//
// NOTE(review): the conv attributes dl/gr/pd/pt/st are not declared in this
// function; presumably CONV_CONST emits them — confirm against its definition.
static NSString *gen_qkvb(void) {
    NSMutableString *mil = [NSMutableString string];

    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 3*DIM, SEQ];
    [mil appendString:@CONV_CONST];

    // Split the packed input into the three upstream gradients.
    [mil appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dq = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dk = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dv = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ];

    // Weights for the three backward projections.
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wqt = const()[name=string(\"Wqt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wqt.bin\"), offset=uint64(64)))];\n", DIM, DIM, DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wkt = const()[name=string(\"Wkt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wkt.bin\"), offset=uint64(64)))];\n", DIM, DIM, DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wvt = const()[name=string(\"Wvt\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wvt.bin\"), offset=uint64(64)))];\n", DIM, DIM, DIM, DIM];

    // Project each gradient and accumulate into the single output.
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq)[name=string(\"cq\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxk = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk)[name=string(\"ck\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv)[name=string(\"cv\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = add(x=dxqk,y=dxv)[name=string(\"out\")];\n", DIM, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];

    return mil;
}
|
||||
|
||||
// Builds the MIL program text for part 1 of the attention (SDPA) backward
// pass, including the backward output projection (Wot).
//
// Input : fp16 [1, 4*DIM, 1, SEQ], sliced along axis 1 into qf, kf, vf, dx2f
//         (DIM channels each, in that order).
// Steps : df    = Wot * dx2f
//         q,k,v,da = each of qf,kf,vf,df reshaped to [1,HEADS,HD,SEQ] and
//                    transposed to [1,HEADS,SEQ,HD]
//         probs = softmax(q @ k^T * 1/sqrt(HD) + mask, axis=-1)
//         dv4   = probs^T @ da
//         dp4   = da @ v^T
// Output: fp16 [1, DIM+2*SCORE_CH, 1, SEQ] = axis-1 concat of
//         (dV flattened, probs flattened, dP flattened).
//
// NOTE(review): reshaping [1,HEADS,SEQ,SEQ] to [1,SCORE_CH,1,SEQ] only works
// if SCORE_CH == HEADS*SEQ; SCORE_CH is defined elsewhere — confirm.
// NOTE(review): the conv attributes dl/gr/pd/pt/st are not declared in this
// function; presumably CONV_CONST emits them — confirm against its definition.
static NSString *gen_sdpa_bwd1(void) {
    const float scale = 1.0f/sqrtf((float)HD);
    NSMutableString *mil = [NSMutableString string];

    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", 4*DIM, SEQ];
    [mil appendString:@CONV_CONST];

    // Unpack Q, K, V and the upstream gradient.
    [mil appendFormat:@" tensor<int32, [4]> sz = const()[name=string(\"sz\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 3*DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dx2f = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM, SEQ];

    // Back through the output projection.
    [mil appendFormat:@" tensor<fp16, [%d,%d,1,1]> Wot = const()[name=string(\"Wot\"), val=tensor<fp16, [%d,%d,1,1]>(BLOBFILE(path=string(\"@model_path/weights/wot.bin\"), offset=uint64(64)))];\n", DIM, DIM, DIM, DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f)[name=string(\"cwo\")];\n", DIM, SEQ];

    // Split heads: [1,DIM,1,SEQ] -> [1,HEADS,HD,SEQ] -> [1,HEADS,SEQ,HD].
    [mil appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dr = reshape(shape=rsh,x=df)[name=string(\"rd\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> da = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS, SEQ, HD];

    // Recompute the attention probabilities from Q and K.
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", scale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,1,%d,%d]> cm = const()[name=string(\"cm\"), val=tensor<fp16, [1,1,%d,%d]>(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ, SEQ, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS, SEQ, SEQ];
    [mil appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS, SEQ, SEQ];

    // Gradients w.r.t. V and w.r.t. the probabilities.
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da)[name=string(\"dv\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v)[name=string(\"dp\")];\n", HEADS, SEQ, SEQ];

    // Flatten everything back to [1,C,1,SEQ] so it can be concatenated.
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> dvs = const()[name=string(\"dvs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> scs = const()[name=string(\"scs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH, SEQ];

    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];

    return mil;
}
|
||||
|
||||
// Builds the MIL program text for part 2 of the attention (SDPA) backward
// pass. This program references no BLOBFILE weights.
//
// Input : fp16 [1, 2*SCORE_CH+2*DIM, 1, SEQ], sliced along axis 1 into
//         pf (SCORE_CH), dpf (SCORE_CH), qf (DIM), kf (DIM), in that order.
// Steps : probs, dp = pf, dpf reshaped to [1,HEADS,SEQ,SEQ]
//         q, k      = qf, kf reshaped/transposed to [1,HEADS,SEQ,HD]
//         ds  = probs * (dp - sum(probs*dp, axis=-1, keep_dims)) * 1/sqrt(HD)
//         dq4 = ds @ k
//         dk4 = ds^T @ q
// Output: fp16 [1, 2*DIM, 1, SEQ] = axis-1 concat of (dQ, dK), each
//         flattened back to [1,DIM,1,SEQ].
static NSString *gen_sdpa_bwd2(void) {
    const float scale = 1.0f/sqrtf((float)HD);
    const int in_channels = 2*SCORE_CH + 2*DIM;
    NSMutableString *mil = [NSMutableString string];

    [mil appendString:MIL_HDR];
    [mil appendFormat:@" func main<ios18>(tensor<fp16, [1, %d, 1, %d]> x) {\n", in_channels, SEQ];

    // Unpack probs, dP, Q and K from the packed input.
    [mil appendFormat:@" tensor<int32, [4]> sz_sc = const()[name=string(\"szsc\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", SCORE_CH, SEQ];
    [mil appendString:@" tensor<int32, [4]> b0 = const()[name=string(\"b0\"), val=tensor<int32, [4]>([0,0,0,0])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b1 = const()[name=string(\"b1\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", SCORE_CH];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> sz_d = const()[name=string(\"szd\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b2 = const()[name=string(\"b2\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> b3 = const()[name=string(\"b3\"), val=tensor<int32, [4]>([0,%d,0,0])];\n", 2*SCORE_CH+DIM];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM, SEQ];

    // Restore the per-head score layout [1,HEADS,SEQ,SEQ].
    [mil appendFormat:@" tensor<int32, [4]> ssh = const()[name=string(\"ssh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS, SEQ, SEQ];

    // Split heads for Q and K: [1,DIM,1,SEQ] -> [1,HEADS,HD,SEQ] -> [1,HEADS,SEQ,HD].
    [mil appendFormat:@" tensor<int32, [4]> rsh = const()[name=string(\"rsh\"), val=tensor<int32, [4]>([1,%d,%d,%d])];\n", HEADS, HD, SEQ];
    [mil appendString:@" tensor<int32, [4]> pm = const()[name=string(\"pm\"), val=tensor<int32, [4]>([0,1,3,2])];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS, SEQ, HD];

    // Softmax backward: ds = probs * (dp - rowsum(probs * dp)), then scaled.
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS, SEQ, SEQ];
    [mil appendString:@" tensor<int32, [1]> rax = const()[name=string(\"rax\"), val=tensor<int32, [1]>([-1])];\n"];
    [mil appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,1]> spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS, SEQ, SEQ];
    [mil appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", scale];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS, SEQ, SEQ];

    // Gradients w.r.t. Q and K.
    [mil appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"];
    [mil appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS, SEQ, HD];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS, SEQ, HD];

    // Merge heads back to [1,DIM,1,SEQ] and pack dQ, dK along axis 1.
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,%d,%d]> dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS, HD, SEQ];
    [mil appendFormat:@" tensor<int32, [4]> fs = const()[name=string(\"fs\"), val=tensor<int32, [4]>([1,%d,1,%d])];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM, SEQ];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM, SEQ];
    [mil appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"];
    [mil appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"];
    [mil appendFormat:@" tensor<fp16, [1,%d,1,%d]> out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM, SEQ];
    [mil appendString:@" } -> (out);\n}\n"];

    return mil;
}
|
||||
|
||||
// Cached causal-mask blob; nil until get_mask_blob() is first called.
static NSData *g_mask_blob = nil;

// Returns the causal mask as a blob, building and caching it on first use.
//
// The mask is a row-major SEQ x SEQ fp16 matrix: entry (t, t2) is 0 when
// t2 <= t and -65504 (the most negative finite fp16 value) when t2 > t.
// The raw buffer is wrapped by build_blob_fp16() and freed here afterwards.
// NOTE(review): presumably this is the payload served as weights/mask.bin to
// the SDPA programs — confirm against the compile call sites.
static NSData *get_mask_blob(void) {
    if (g_mask_blob) {
        return g_mask_blob;
    }

    // calloc zero-fills the buffer, so only the strictly-upper triangle
    // (t2 > t) has to be written explicitly.
    _Float16 *causal = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16));
    for (int row = 0; row < SEQ; row++) {
        for (int col = row + 1; col < SEQ; col++) {
            causal[row*SEQ + col] = (_Float16)(-65504.0f);
        }
    }

    g_mask_blob = build_blob_fp16(causal, SEQ*SEQ);
    free(causal);
    return g_mask_blob;
}
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Extract pretokenized TinyStories data from zip.
|
||||
Data format: flat uint16 token IDs (llama2.c BPE, 32K vocab).
|
||||
Source: ~/tiny_stories_data_pretokenized.zip"""
|
||||
|
||||
import os, struct, zipfile
|
||||
from pathlib import Path
|
||||
|
||||
ZIP_PATH = os.path.expanduser('~/tiny_stories_data_pretokenized.zip')
|
||||
OUTPUT_PATH = str(Path(__file__).resolve().parent / 'tinystories_data00.bin')
|
||||
|
||||
def main():
    """Extract ``data00.bin`` from the pretokenized TinyStories zip.

    If OUTPUT_PATH already exists, only its size is reported and nothing is
    extracted. Otherwise the archive member is streamed to a temporary file
    next to OUTPUT_PATH and renamed into place once it is completely
    written, so a failed or interrupted extraction never leaves a truncated
    file that a later run would report as "already exists".

    Raises:
        FileNotFoundError: ZIP_PATH does not exist.
        KeyError: the archive has no ``data00.bin`` member.
        zipfile.BadZipFile: the archive is corrupt (e.g. CRC mismatch).
    """
    import shutil  # function-local: this is the only user in the script

    if os.path.exists(OUTPUT_PATH):
        size = os.path.getsize(OUTPUT_PATH)
        print(f"{OUTPUT_PATH} already exists ({size // 2} tokens, {size/1e6:.1f} MB)")
        return

    print(f"Extracting data00.bin from {ZIP_PATH}...")
    tmp_path = OUTPUT_PATH + '.tmp'
    try:
        with zipfile.ZipFile(ZIP_PATH, 'r') as z:
            with z.open('data00.bin') as src, open(tmp_path, 'wb') as dst:
                shutil.copyfileobj(src, dst, 1 << 20)
        # Publish the file only after every byte was written and verified.
        os.replace(tmp_path, OUTPUT_PATH)
    except BaseException:
        # Covers Ctrl-C as well: never leave a partial file behind.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise

    size = os.path.getsize(OUTPUT_PATH)
    print(f"Written {OUTPUT_PATH} ({size // 2} tokens, {size/1e6:.1f} MB)")

    # Sanity check: show up to the first 10 uint16 tokens. A file holding
    # fewer than 10 tokens must not crash the unpack.
    with open(OUTPUT_PATH, 'rb') as f:
        head = f.read(20)
    count = len(head) // 2
    tokens = struct.unpack(f'<{count}H', head[:count * 2])
    print(f"First {count} tokens: {tokens}")


if __name__ == '__main__':
    main()
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue