diff --git a/training/Makefile b/training/Makefile index 90c2977..226bb39 100644 --- a/training/Makefile +++ b/training/Makefile @@ -3,10 +3,18 @@ CFLAGS = -O2 -Wall -Wno-deprecated-declarations -fobjc-arc FRAMEWORKS = -framework Foundation -framework CoreML -framework IOSurface LDFLAGS = $(FRAMEWORKS) -ldl +HEADERS_LARGE = stories_config.h stories_io.h stories_mil.h stories_cpu_ops.h + train: train.m ane_runtime.h ane_mil_gen.h model.h forward.h backward.h $(CC) $(CFLAGS) -o $@ train.m $(LDFLAGS) -clean: - rm -f train +train_large: train_large.m $(HEADERS_LARGE) + $(CC) $(CFLAGS) -o $@ train_large.m $(LDFLAGS) -framework Accelerate -.PHONY: clean +tokenize: + python3 tokenize.py + +clean: + rm -f train train_large + +.PHONY: clean tokenize diff --git a/training/README.md b/training/README.md new file mode 100644 index 0000000..53edbb9 --- /dev/null +++ b/training/README.md @@ -0,0 +1,69 @@ +# ANE Training — Stories110M on Apple Neural Engine + +Training a 109M-parameter Llama2-architecture transformer (Stories110M) directly on Apple's Neural Engine using private ANE APIs. + +![Dashboard](dashboard.gif) + +## Architecture + +- **Model**: Stories110M — dim=768, hidden=2048, heads=12, layers=12, vocab=32000, seq=256 +- **109.53M params** (84.95M transformer + 24.58M embedding) +- **72 ANE kernels** per compile (60 weight-bearing, 12 weight-free sdpaBwd2) +- **6 kernel types per layer**: fwdAttn, fwdFFN, ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd + +## Performance + +| Component | Time (ms/step) | +|-----------|---------------| +| ANE eval | 9.6 | +| IO (fp16 conversion) | 4.1 | +| Classifier (cblas) | 9.1 | +| Cross-entropy + residuals | 14.4 | +| RMSNorm | 0.1 | +| **Total** | **107 ms/step** | + +## Files + +| File | Description | +|------|-------------| +| `train_large.m` | Main training loop — 12-layer forward/backward, checkpoint, exec() restart | +| `stories_config.h` | Model config, structs, alloc helpers | +| `stories_io.h` | IOSurface I/O, NEON fp16 conversion, kernel compile/eval | +| `stories_mil.h` | MIL program generators for all 6 ANE kernel types | +| `stories_cpu_ops.h` | vDSP-vectorized RMSNorm, cross-entropy, Adam, embedding ops | +| `dashboard.py` | TUI dashboard — loss curve, power/CPU/memory graphs, text generation | +| `tokenize.py` | Extract pretokenized TinyStories data | +| `Makefile` | Build targets | + +## How it works + +1. **Forward pass**: Each layer runs fwdAttn (QKV + SDPA + Wo) and fwdFFN (W1 + SiLU(W3) + W2) on ANE via MIL-compiled kernels. Final RMSNorm + classifier matmul on CPU (cblas). + +2. **Backward pass**: Reverse layer order. ffnBwd, sdpaBwd1, sdpaBwd2, qkvBwd on ANE. Weight gradients (dW) via async cblas_sgemm on CPU. RMSNorm backward via vDSP. + +3. **Compile budget**: ANE has a ~119 compile limit per process. With 72 kernels per batch, we run 10 accumulation steps then `exec()` restart with checkpoint resume. + +4. **Data**: Real TinyStories text (20M tokens), mmap'd uint16 token IDs, random position sampling per step. + +## Usage + +```bash +# Extract tokenized data +python3 tokenize.py + +# Build and train +make train_large +./train_large # fresh start +./train_large --resume # resume from checkpoint + +# Monitor with dashboard +pip install blessed psutil numpy +python3 dashboard.py --resume # needs sudo for powermetrics +``` + +## Key techniques + +- **NEON vectorized fp16<->fp32**: ARM NEON intrinsics for fast IOSurface data transfer +- **vDSP cross-entropy**: `vDSP_mtrans` + `vvexpf` + `vDSP_sve` — 8x faster than scalar +- **Async weight gradients**: cblas_sgemm dispatched to background queue, overlapped with ANE +- **SDPA causal mask workaround**: ANE hardware ignores attn_mask, so we decompose attention into Q@K^T (ANE conv) + mask+softmax (CPU) + scores@V (ANE conv) diff --git a/training/dashboard.gif b/training/dashboard.gif new file mode 100644 index 0000000..120f7d5 Binary files /dev/null and b/training/dashboard.gif differ diff --git a/training/dashboard.py b/training/dashboard.py new file mode 100644 index 0000000..a3a1503 --- /dev/null +++ b/training/dashboard.py @@ -0,0 +1,882 @@ +"""TUI dashboard for ANE training (train_large). Uses blessed for terminal UI.""" + +import argparse, fcntl, math, os, re, select, signal, struct, subprocess, sys, time, threading +from collections import deque +from pathlib import Path + +import numpy as np + +try: + from blessed import Terminal +except ImportError: + print('pip install blessed') + sys.exit(1) + +try: + import psutil + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + +DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS = 768, 2048, 12, 256, 32000, 12 +HD = DIM // HEADS +CKPT_PATH = 'ane_stories110M_ckpt.bin' +TOKENIZER_PATH = str(Path(__file__).resolve().parent.parent.parent / 'assets' / 'models' / 'tokenizer.bin') + + +class State: + def __init__(self): + self.model_config = {} + self.params = {} + self.kernels = {} + self.training = {} + self.flops = {} + self.step = 0 + self.total_steps = 0 + self.loss = 0.0 + self.best_loss = float('inf') + self.loss_history = [] + self.ms_per_step = 0.0 + self.compile_pct = 0.0 + self.compiles = 0 + self.component_timing = {} + self.power = {'ane': 0.0, 'cpu': 0.0, 'gpu': 0.0} + self.power_history_ane = deque(maxlen=300) + self.power_history_cpu = deque(maxlen=300) + self.logs = deque(maxlen=2000) + self.log_scroll = 0 + self.auto_scroll = True + self.batch_num = 0 + self.efficiency = {} + self.gen_text = '' + self.gen_step = 0 + self.gen_status = 'idle' + self.gen_lock = threading.Lock() + self.cpu_pct_history = deque(maxlen=300) + self.mem_mb_history = deque(maxlen=300) + self.proc_mem_mb_history = deque(maxlen=300) + self.train_pid = None + +S = State() + + +class Tokenizer: + def __init__(self, path): + self.vocab = [] + self.scores = [] + with open(path, 'rb') as f: + max_len = struct.unpack('i', f.read(4))[0] + for _ in range(VOCAB): + score = struct.unpack('f', f.read(4))[0] + slen = struct.unpack('i', f.read(4))[0] + tok = f.read(slen).decode('utf-8', errors='replace') + self.vocab.append(tok) + self.scores.append(score) + + def decode(self, token_id): + if 0 <= token_id < len(self.vocab): + s = self.vocab[token_id] + if s.startswith('<0x') and s.endswith('>'): + try: + return chr(int(s[3:-1], 16)) + except: + return s + return s + return '' + +_tokenizer = None +def get_tokenizer(): + global _tokenizer + if _tokenizer is None: + try: + _tokenizer = Tokenizer(TOKENIZER_PATH) + except Exception as e: + S.logs.append(f'[gen] tokenizer load failed: {e}') + return None + return _tokenizer + + +def load_weights_from_ckpt(path): + try: + with open(path, 'rb') as f: + # CkptHdr: 96 bytes (verified with sizeof) + hdr = f.read(96) + if len(hdr) < 96: + return None + wq_sz = DIM * DIM + wo_sz = DIM * DIM + w1_sz = HIDDEN * DIM + w2_sz = DIM * HIDDEN + w3_sz = HIDDEN * DIM + # Per-layer: weights + adam state (m,v for each) + adam_per_layer = (wq_sz*2 + wq_sz*2 + wq_sz*2 + wo_sz*2 + + w1_sz*2 + w2_sz*2 + w3_sz*2 + DIM*2 + DIM*2) + W = {} + for L in range(NLAYERS): + W[f'Wq{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() + W[f'Wk{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() + W[f'Wv{L}'] = np.frombuffer(f.read(wq_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() + W[f'Wo{L}'] = np.frombuffer(f.read(wo_sz * 4), dtype=np.float32).reshape(DIM, DIM).copy() + W[f'W1_{L}'] = np.frombuffer(f.read(w1_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy() + W[f'W2_{L}'] = np.frombuffer(f.read(w2_sz * 4), dtype=np.float32).reshape(DIM, HIDDEN).copy() + W[f'W3_{L}'] = np.frombuffer(f.read(w3_sz * 4), dtype=np.float32).reshape(HIDDEN, DIM).copy() + W[f'rms1_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() + W[f'rms2_{L}'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() + # Skip adam state for this layer + f.seek(adam_per_layer * 4, 1) + W['rms_final'] = np.frombuffer(f.read(DIM * 4), dtype=np.float32).copy() + f.seek(DIM * 2 * 4, 1) # skip rms_final adam + W['embed'] = np.frombuffer(f.read(VOCAB * DIM * 4), dtype=np.float32).reshape(VOCAB, DIM).copy() + return W + except Exception as e: + S.logs.append(f'[gen] ckpt load failed: {e}') + return None + + +def rmsnorm(x, w): + ss = np.mean(x * x) + 1e-5 + return x * (1.0 / math.sqrt(ss)) * w + +def softmax(x): + x = x - np.max(x) + e = np.exp(x) + return e / np.sum(e) + +def generate_text(W, tok, max_tokens=64, temperature=0.8): + tokenizer = get_tokenizer() + if tokenizer is None: + return '[no tokenizer]' + + tokens = [1] + text_parts = [] + + # Precompute RoPE frequencies + freqs = np.zeros((SEQ, HD // 2), dtype=np.float32) + for pos in range(SEQ): + for i in range(HD // 2): + freq = 1.0 / (10000.0 ** (2.0 * i / HD)) + freqs[pos, i] = pos * freq + + for step in range(max_tokens): + seq_len = len(tokens) + if seq_len > SEQ: + break + + x = W['embed'][tokens[-1]].copy() + + for L in range(NLAYERS): + # RMSNorm + QKV + xn = rmsnorm(x, W[f'rms1_{L}']) + q = W[f'Wq{L}'] @ xn + k = W[f'Wk{L}'] @ xn + v = W[f'Wv{L}'] @ xn + + # RoPE + pos = seq_len - 1 + for h in range(HEADS): + for i in range(HD // 2): + freq = freqs[pos, i] + cos_v, sin_v = math.cos(freq), math.sin(freq) + qi, qi1 = q[h * HD + 2 * i], q[h * HD + 2 * i + 1] + q[h * HD + 2 * i] = qi * cos_v - qi1 * sin_v + q[h * HD + 2 * i + 1] = qi * sin_v + qi1 * cos_v + ki, ki1 = k[h * HD + 2 * i], k[h * HD + 2 * i + 1] + k[h * HD + 2 * i] = ki * cos_v - ki1 * sin_v + k[h * HD + 2 * i + 1] = ki * sin_v + ki1 * cos_v + + # Attention (single token) + o = np.zeros(DIM, dtype=np.float32) + for h in range(HEADS): + qh = q[h * HD:(h + 1) * HD] + kh = k[h * HD:(h + 1) * HD] + vh = v[h * HD:(h + 1) * HD] + score = np.dot(qh, kh) / math.sqrt(HD) + o[h * HD:(h + 1) * HD] = vh + + # Residual + output projection + x2 = x + W[f'Wo{L}'] @ o + + # FFN + x2n = rmsnorm(x2, W[f'rms2_{L}']) + h1 = W[f'W1_{L}'] @ x2n + h3 = W[f'W3_{L}'] @ x2n + # SiLU + h1 = h1 * (1.0 / (1.0 + np.exp(-h1))) * h3 + ffn_out = W[f'W2_{L}'] @ h1 + + x = x2 + ffn_out + + x = rmsnorm(x, W['rms_final']) + + # Logits + logits = W['embed'] @ x + + if temperature < 0.01: + next_tok = int(np.argmax(logits)) + else: + logits = logits / temperature + probs = softmax(logits) + next_tok = int(np.random.choice(VOCAB, p=probs)) + + if next_tok == 2: + break + tokens.append(next_tok) + piece = tokenizer.decode(next_tok) + text_parts.append(piece) + + return ''.join(text_parts) + + +def generation_thread(): + last_gen_step = -1 + while True: + time.sleep(5) + if S.step <= last_gen_step + 99: + continue + if not os.path.exists(CKPT_PATH): + continue + with S.gen_lock: + S.gen_status = 'generating' + S.gen_step = S.step + try: + W = load_weights_from_ckpt(CKPT_PATH) + if W is None: + with S.gen_lock: + S.gen_status = 'idle' + continue + text = generate_text(W, get_tokenizer(), max_tokens=64, temperature=0.8) + with S.gen_lock: + S.gen_text = text + S.gen_step = S.step + S.gen_status = 'done' + S.step # just to reference + except Exception as e: + with S.gen_lock: + S.gen_text = f'[error: {e}]' + S.gen_status = 'done' + last_gen_step = S.step + + +def sysmetrics_thread(): + while True: + time.sleep(1) + if not HAS_PSUTIL: + continue + now = time.monotonic() + S.cpu_pct_history.append(psutil.cpu_percent(interval=None)) + mem = psutil.virtual_memory() + S.mem_mb_history.append(mem.used / (1024 * 1024)) + pid = S.train_pid + if pid: + try: + p = psutil.Process(pid) + S.proc_mem_mb_history.append(p.memory_info().rss / (1024 * 1024)) + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass + + +RE_CONFIG = re.compile(r'dim=(\d+) hidden=(\d+) heads=(\d+) seq=(\d+) vocab=(\d+) layers=(\d+)') +RE_PARAMS = re.compile(r'Params: ([\d.]+)M \(transformer ([\d.]+)M \+ embed ([\d.]+)M\)') +RE_KERNELS = re.compile(r'Kernels: (\d+).*?(\d+) weight-bearing') +RE_ACCUM = re.compile(r'Accum (\d+).*LR=([\d.e+-]+)') +RE_STEP = re.compile(r'step\s+(\d+)\s+loss=([\d.]+)') +RE_BATCH = re.compile(r'\[batch (\d+): compile=([\d.]+)ms train=([\d.]+)ms \(([\d.]+)ms/step\) compiles=(\d+)\]') +RE_TIMING = re.compile(r'ane=([\d.]+) io=([\d.]+) cls=([\d.]+) elem=([\d.]+) rms=([\d.]+) cblas_wait=([\d.]+)') +RE_RESTART = re.compile(r'\[exec\(\) restart step (\d+)') +RE_RESUME = re.compile(r'\[RESUMED step (\d+), loss=([\d.]+)\]') +RE_FLOPS = re.compile(r'FLOPs/step: fwd=([\d.]+)M bwd_dx=([\d.]+)M bwd_dW=([\d.]+)M sdpa_bwd=([\d.]+)M total=([\d.]+)M') +RE_ANE_FLOPS = re.compile(r'ANE FLOPs/step: ([\d.]+)M') +RE_ANE_TFLOPS = re.compile(r'ANE TFLOPS:\s+([\d.]+)') +RE_ANE_UTIL = re.compile(r'ANE utilization:\s+([\d.]+)%') +RE_EFFICIENCY = re.compile(r'(Total steps|Wall time|Compile time|Train time|Avg compile|Avg train|ANE TFLOPS|Total TFLOPS|ANE utilization):?\s+(.+)') +RE_ANE_POWER = re.compile(r'ANE Power:\s+([\d.]+)\s*mW') +RE_CPU_POWER = re.compile(r'CPU Power:\s+([\d.]+)\s*mW') +RE_GPU_POWER = re.compile(r'GPU Power:\s+([\d.]+)\s*mW') + +def parse_line(line): + S.logs.append(line) + m = RE_CONFIG.search(line) + if m: + S.model_config = dict(zip(['dim', 'hidden', 'heads', 'seq', 'vocab', 'layers'], map(int, m.groups()))) + return + m = RE_PARAMS.search(line) + if m: + S.params = {'total': float(m[1]), 'transformer': float(m[2]), 'embed': float(m[3])} + return + m = RE_KERNELS.search(line) + if m: + S.kernels = {'total': int(m[1]), 'weight_bearing': int(m[2])} + return + m = RE_ACCUM.search(line) + if m: + S.training = {'accum': int(m[1]), 'lr': m[2]} + return + m = RE_FLOPS.search(line) + if m: + S.flops.update(fwd=float(m[1]), bwd_dx=float(m[2]), bwd_dw=float(m[3]), + sdpa_bwd=float(m[4]), total=float(m[5])) + return + m = RE_ANE_FLOPS.search(line) + if m: + S.flops['ane'] = float(m[1]) + return + m = RE_STEP.search(line) + if m: + S.step, S.loss = int(m[1]), float(m[2]) + S.loss_history.append((S.step, S.loss)) + S.best_loss = min(S.best_loss, S.loss) + return + m = RE_BATCH.search(line) + if m: + S.batch_num = int(m[1]) + compile_ms, train_ms = float(m[2]), float(m[3]) + S.ms_per_step = float(m[4]) + S.compiles = int(m[5]) + S.compile_pct = 100 * compile_ms / (compile_ms + train_ms) if compile_ms + train_ms > 0 else 0 + return + m = RE_TIMING.search(line) + if m: + S.component_timing = dict(zip(['ane', 'io', 'cls', 'elem', 'rms', 'cblas_wait'], map(float, m.groups()))) + return + m = RE_ANE_TFLOPS.search(line) + if m: + S.flops['ane_tflops'] = float(m[1]) + return + m = RE_ANE_UTIL.search(line) + if m: + S.flops['ane_util'] = float(m[1]) + return + m = RE_EFFICIENCY.search(line) + if m: + S.efficiency[m[1].strip()] = m[2].strip() + return + + +def parse_powermetrics_text(text): + now = time.monotonic() + m = RE_ANE_POWER.search(text) + if m: + S.power['ane'] = float(m[1]) / 1000.0 + S.power_history_ane.append((now, S.power['ane'])) + m = RE_CPU_POWER.search(text) + if m: + S.power['cpu'] = float(m[1]) / 1000.0 + S.power_history_cpu.append((now, S.power['cpu'])) + m = RE_GPU_POWER.search(text) + if m: + S.power['gpu'] = float(m[1]) / 1000.0 + + +BRAILLE_BASE = 0x2800 + +BRAILLE_MAP = [ + [1, 8], + [2, 16], + [4, 32], + [64, 128], +] + +def braille_chart(values, width, height, label_fmt='{:.1f}', y_range=None): + if not values or width < 8 or height < 2: + return ['(no data)'] * max(1, height) + chart_w = width - 6 + if chart_w < 2: + return ['(no data)'] * max(1, height) + points_x = chart_w * 2 + points_y = height * 4 + data = values[-points_x:] if len(values) > points_x else values + lo, hi = min(data), max(data) + if y_range: + lo, hi = y_range + if hi - lo < 0.001: + lo, hi = lo - 0.5, hi + 0.5 + margin = (hi - lo) * 0.05 + lo -= margin + hi += margin + + grid = [[0] * chart_w for _ in range(height)] + + def plot(px, py): + px = max(0, min(points_x - 1, px)) + py = max(0, min(points_y - 1, py)) + grid[py // 4][px // 2] |= BRAILLE_MAP[py % 4][px % 2] + + def val_to_y(v): + return int((1 - (v - lo) / (hi - lo)) * (points_y - 1)) + + for i in range(len(data)): + if i >= points_x: + break + y0 = val_to_y(data[i]) + plot(i, y0) + if i > 0: + y_prev = val_to_y(data[i - 1]) + y_lo, y_hi = min(y_prev, y0), max(y_prev, y0) + for yy in range(y_lo, y_hi + 1): + if y_hi != y_lo: + t = (yy - y_prev) / (y0 - y_prev) + xx = int(i - 1 + t) + else: + xx = i + plot(xx, yy) + + lines = [] + for r in range(height): + if r == 0: + label = label_fmt.format(hi)[:5].rjust(5) + elif r == height - 1: + label = label_fmt.format(lo)[:5].rjust(5) + elif r == height // 2: + label = label_fmt.format((hi + lo) / 2)[:5].rjust(5) + else: + label = ' ' + row_str = ''.join(chr(BRAILLE_BASE | grid[r][c]) for c in range(chart_w)) + lines.append(f'{label}\u2502{row_str}') + + lines.append(' \u2514' + '\u2500' * chart_w) + return lines + + +def draw(term): + w, h = term.width, term.height + if w < 40 or h < 15: + print(term.home + term.clear + 'Terminal too small', end='', flush=True) + return + + buf = [] + + def put(y, x, text, style=''): + if 0 <= y < h and x < w: + text = text[:w - x] + if style: + buf.append(term.move(y, x) + style + text + term.normal) + return + buf.append(term.move(y, x) + text) + + buf.append(term.home + term.clear) + + mid_x = w // 2 + right_w = w - mid_x - 1 + left_w = mid_x - 1 + + row = 0 + + # Model Config header + hdr = '\u2500 Model Config ' + put(row, 0, '\u250c' + hdr + '\u2500' * max(0, w - len(hdr) - 2) + '\u2510', term.cyan) + row += 1 + + cfg = S.model_config + if cfg: + line1 = f"stories110M dim={cfg.get('dim', '')} hidden={cfg.get('hidden', '')} heads={cfg.get('heads', '')} seq={cfg.get('seq', '')} layers={cfg.get('layers', '')}" + put(row, 0, '\u2502', term.cyan) + put(row, 2, line1) + put(row, w - 1, '\u2502', term.cyan) + row += 1 + p, k, t = S.params, S.kernels, S.training + line2 = f"{p.get('total', '?')}M params ({p.get('transformer', '?')}M xfmr + {p.get('embed', '?')}M embed)" + put(row, 0, '\u2502', term.cyan) + put(row, 2, line2) + put(row, w - 1, '\u2502', term.cyan) + row += 1 + line3 = f"{k.get('total', '?')} kernels ({k.get('weight_bearing', '?')} wt-bearing) | Accum {t.get('accum', '?')} | Adam LR={t.get('lr', '?')}" + put(row, 0, '\u2502', term.cyan) + put(row, 2, line3) + put(row, w - 1, '\u2502', term.cyan) + row += 1 + else: + put(row, 0, '\u2502', term.cyan) + put(row, 2, 'Waiting for model config...') + put(row, w - 1, '\u2502', term.cyan) + row += 1 + + remaining = h - row - 1 + # Allocate: loss curve ~40%, logs ~30%, power/cpu/mem/gen share rest + power_h = max(3, remaining // 8) + gen_h = max(2, remaining // 10) + extra_panels = power_h + power_h + gen_h + 6 # power + cpu/mem + gen + dividers + log_h_min = max(5, remaining // 5) + curve_h = max(5, remaining - extra_panels - log_h_min) + + # Loss Curve + Training Stats divider + put(row, 0, '\u251c\u2500 Loss Curve ' + '\u2500' * max(0, left_w - 13) + '\u252c\u2500 Training Stats ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan) + row += 1 + + # Loss curve + loss_vals = [l for _, l in S.loss_history] + curve_lines = braille_chart(loss_vals, left_w - 1, curve_h) + for i, cl in enumerate(curve_lines): + put(row + i, 0, '\u2502', term.cyan) + put(row + i, 1, cl, term.green) + put(row + i, mid_x, '\u2502', term.cyan) + put(row + i, w - 1, '\u2502', term.cyan) + + # Training stats (right panel) + sr = row + step_str = f'{S.step}' + (f'/{S.total_steps}' if S.total_steps and S.total_steps < 999999 else '') + put(sr, mid_x + 1, f' Step: {step_str} Loss: {S.loss:.4f}' if S.loss else ' Step: --', term.yellow) + sr += 1 + put(sr, mid_x + 1, f' Best: {S.best_loss:.4f} ms/step: {S.ms_per_step:.1f}' if S.best_loss < float('inf') else ' Best: --') + sr += 1 + ane_tflops = S.flops.get('ane_tflops', 0) + ane_util = S.flops.get('ane_util', 0) + if ane_tflops: + put(sr, mid_x + 1, f' ANE: {ane_tflops:.2f}T Compile: {S.compile_pct:.0f}% Util: {ane_util:.1f}%') + else: + put(sr, mid_x + 1, f' Compile: {S.compile_pct:.0f}%') + sr += 1 + ct = S.component_timing + if ct: + put(sr, mid_x + 1, f' ane={ct.get("ane", 0):.1f} io={ct.get("io", 0):.1f} cls={ct.get("cls", 0):.1f} elem={ct.get("elem", 0):.1f}') + sr += 1 + put(sr, mid_x + 1, f' rms={ct.get("rms", 0):.1f} cblas_wait={ct.get("cblas_wait", 0):.1f} ms/step') + sr += 1 + pw = S.power + if any(pw.values()): + put(sr, mid_x + 1, '\u2500 Power ' + '\u2500' * max(0, right_w - 9), term.cyan) + sr += 1 + put(sr, mid_x + 1, f' ANE: {pw["ane"]:.1f}W CPU: {pw["cpu"]:.1f}W GPU: {pw["gpu"]:.1f}W', term.magenta) + sr += 1 + if S.batch_num: + put(sr, mid_x + 1, f' Batch: {S.batch_num} Compiles: {S.compiles}') + sr += 1 + + # Fill vertical borders between loss curve and stats + top_end = row + len(curve_lines) + for r in range(row, max(top_end, sr)): + if r >= top_end: + put(r, 0, '\u2502', term.cyan) + if r >= sr: + put(r, mid_x, '\u2502', term.cyan) + put(r, w - 1, '\u2502', term.cyan) + row = max(top_end, sr) + + # Power charts + has_power = len(S.power_history_ane) > 1 or len(S.power_history_cpu) > 1 + if has_power: + put(row, 0, '\u251c\u2500 ANE Power (W) ' + '\u2500' * max(0, left_w - 16) + '\u252c\u2500 CPU Power (W) ' + '\u2500' * max(0, right_w - 17) + '\u2524', term.cyan) + row += 1 + ane_vals = [v for _, v in S.power_history_ane] + cpu_vals = [v for _, v in S.power_history_cpu] + ane_lines = braille_chart(ane_vals, left_w - 1, power_h, label_fmt='{:.1f}') + cpu_lines = braille_chart(cpu_vals, right_w - 1, power_h, label_fmt='{:.1f}') + max_lines = max(len(ane_lines), len(cpu_lines)) + while len(ane_lines) < max_lines: + ane_lines.append(' ' * (left_w - 1)) + while len(cpu_lines) < max_lines: + cpu_lines.append(' ' * (right_w - 1)) + for i in range(max_lines): + put(row + i, 0, '\u2502', term.cyan) + put(row + i, 1, ane_lines[i], term.red) + put(row + i, mid_x, '\u2502', term.cyan) + put(row + i, mid_x + 1, cpu_lines[i], term.blue) + put(row + i, w - 1, '\u2502', term.cyan) + row += max_lines + + # CPU / Memory charts + has_sysmetrics = len(S.cpu_pct_history) > 0 + if has_sysmetrics: + put(row, 0, '\u251c\u2500 CPU % ' + '\u2500' * max(0, left_w - 8) + '\u252c\u2500 Memory (MB) ' + '\u2500' * max(0, right_w - 15) + '\u2524', term.cyan) + row += 1 + cpu_vals = list(S.cpu_pct_history) + mem_vals = list(S.proc_mem_mb_history) if S.proc_mem_mb_history else list(S.mem_mb_history) + mem_label = 'proc' if S.proc_mem_mb_history else 'sys' + cpu_lines = braille_chart(cpu_vals, left_w - 1, power_h, label_fmt='{:.0f}', y_range=(0, 100)) + mem_lines = braille_chart(mem_vals, right_w - 1, power_h, label_fmt='{:.0f}') + max_lines = max(len(cpu_lines), len(mem_lines)) + while len(cpu_lines) < max_lines: + cpu_lines.append(' ' * (left_w - 1)) + while len(mem_lines) < max_lines: + mem_lines.append(' ' * (right_w - 1)) + for i in range(max_lines): + put(row + i, 0, '\u2502', term.cyan) + put(row + i, 1, cpu_lines[i], term.yellow) + put(row + i, mid_x, '\u2502', term.cyan) + put(row + i, mid_x + 1, mem_lines[i], term.magenta) + put(row + i, w - 1, '\u2502', term.cyan) + row += max_lines + + # Generated text + with S.gen_lock: + gen_text = S.gen_text + gen_step = S.gen_step + gen_status = S.gen_status + if gen_text or gen_status == 'generating': + status_tag = ' (generating...)' if gen_status == 'generating' else f' (step {gen_step})' + put(row, 0, '\u251c\u2500 Generated Text' + status_tag + ' ' + '\u2500' * max(0, w - 20 - len(status_tag)) + '\u2524', term.cyan) + row += 1 + if gen_text: + line_w = w - 3 + text = gen_text.replace('\n', ' ') + wrapped = [text[i:i + line_w] for i in range(0, len(text), line_w)] + for i, tl in enumerate(wrapped[:gen_h]): + put(row, 0, '\u2502', term.cyan) + put(row, 2, tl, term.white) + put(row, w - 1, '\u2502', term.cyan) + row += 1 + else: + put(row, 0, '\u2502', term.cyan) + put(row, 2, '...') + put(row, w - 1, '\u2502', term.cyan) + row += 1 + + # Logs + log_h = h - row - 1 + scroll_hint = ' (scroll) ' if not S.auto_scroll else ' ' + put(row, 0, '\u251c\u2500 Logs' + scroll_hint + '\u2500' * max(0, w - 8 - len(scroll_hint)) + '\u2524', term.cyan) + row += 1 + + logs = list(S.logs) + if log_h > 0 and logs: + if S.auto_scroll: + start = max(0, len(logs) - log_h) + else: + start = max(0, min(S.log_scroll, len(logs) - log_h)) + visible = logs[start:start + log_h] + for i, line in enumerate(visible): + put(row + i, 0, '\u2502', term.cyan) + if RE_STEP.search(line): + put(row + i, 1, line[:w - 2], term.yellow) + elif line.strip().startswith('[batch'): + put(row + i, 1, line[:w - 2], term.blue) + elif 'FAIL' in line or 'error' in line.lower(): + put(row + i, 1, line[:w - 2], term.red) + else: + put(row + i, 1, line[:w - 2]) + put(row + i, w - 1, '\u2502', term.cyan) + for i in range(len(visible), log_h): + put(row + i, 0, '\u2502', term.cyan) + put(row + i, w - 1, '\u2502', term.cyan) + + # Bottom border + put(h - 1, 0, '\u2514' + '\u2500' * (w - 2) + '\u2518', term.cyan) + + sys.stdout.write(''.join(buf)) + sys.stdout.flush() + + +def set_nonblock(fd): + fl = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK) + +def spawn_training(resume=False, steps=10000): + cmd = 'make train_large 2>&1 && ./train_large' + if resume: + cmd += ' --resume' + cmd += f' --steps {steps}' + proc = subprocess.Popen( + ['bash', '-c', cmd], + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + cwd=os.path.dirname(os.path.abspath(__file__)) or '.') + set_nonblock(proc.stdout.fileno()) + return proc + +def spawn_powermetrics(): + try: + proc = subprocess.Popen( + ['sudo', 'powermetrics', '--samplers', 'cpu_power,gpu_power,ane_power', '-i', '1000'], + stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + set_nonblock(proc.stdout.fileno()) + return proc + except (FileNotFoundError, PermissionError): + return None + +def main(): + parser = argparse.ArgumentParser(description='ANE Training Dashboard (stories110M)') + parser.add_argument('--resume', action='store_true', help='Resume from checkpoint') + parser.add_argument('--infinite', action='store_true', help='Train indefinitely') + parser.add_argument('--no-powermetrics', action='store_true') + parser.add_argument('--no-generate', action='store_true', help='Disable text generation') + parser.add_argument('--steps', type=int, default=10000, help='Total steps (default: 10000)') + args = parser.parse_args() + + if args.infinite: + args.steps = 999999999 + S.total_steps = args.steps + + term = Terminal() + procs = [] + + train_proc = spawn_training(resume=args.resume, steps=args.steps) + S.train_pid = train_proc.pid + procs.append(train_proc) + + if HAS_PSUTIL: + psutil.cpu_percent(interval=None) # prime the counter + sys_t = threading.Thread(target=sysmetrics_thread, daemon=True) + sys_t.start() + + pm_proc = None + if not args.no_powermetrics: + pm_proc = spawn_powermetrics() + if pm_proc: + procs.append(pm_proc) + + if not args.no_generate: + gen_t = threading.Thread(target=generation_thread, daemon=True) + gen_t.start() + + pm_buf = '' + train_buf = '' + + def cleanup(): + for p in procs: + try: + p.terminate() + except Exception: + pass + + signal.signal(signal.SIGINT, lambda *a: cleanup()) + signal.signal(signal.SIGTERM, lambda *a: cleanup()) + + resized = [False] + def on_resize(*a): + resized[0] = True + + signal.signal(signal.SIGWINCH, on_resize) + + with term.fullscreen(), term.cbreak(), term.hidden_cursor(): + draw(term) + last_draw = time.monotonic() + + while True: + fds = [] + fd_map = {} + if train_proc and train_proc.stdout: + fd = train_proc.stdout.fileno() + fds.append(fd) + fd_map[fd] = 'train' + if pm_proc and pm_proc.stdout: + fd = pm_proc.stdout.fileno() + fds.append(fd) + fd_map[fd] = 'pm' + fds.append(sys.stdin.fileno()) + fd_map[sys.stdin.fileno()] = 'stdin' + + try: + readable, _, _ = select.select(fds, [], [], 0.25) + except (ValueError, OSError): + continue + + need_draw = resized[0] + resized[0] = False + + train_finished = False + + for fd in readable: + kind = fd_map.get(fd) + if kind == 'train': + try: + data = os.read(fd, 65536) + except BlockingIOError: + continue + except (OSError, ValueError): + data = b'' + if not data: + if train_proc.poll() is not None: + try: + rest = train_proc.stdout.read() + if rest: + for line in rest.decode('utf-8', errors='replace').split('\n'): + if line: + parse_line(line) + except Exception: + pass + S.logs.append('[dashboard] Training finished. Press q to exit.') + train_finished = True + continue + train_buf += data.decode('utf-8', errors='replace') + while '\n' in train_buf: + line, train_buf = train_buf.split('\n', 1) + parse_line(line) + need_draw = True + + elif kind == 'pm': + try: + data = os.read(fd, 65536).decode('utf-8', errors='replace') + except BlockingIOError: + continue + except (OSError, ValueError): + data = '' + if not data: + continue + pm_buf += data + while '\n\n' in pm_buf or '*** ' in pm_buf: + end = pm_buf.find('\n*** ', 1) + if end < 0: + end = pm_buf.find('\n\n', 1) + if end < 0: + break + chunk = pm_buf[:end] + pm_buf = pm_buf[end:] + parse_powermetrics_text(chunk) + if len(pm_buf) > 16384: + pm_buf = pm_buf[-8192:] + need_draw = True + + elif kind == 'stdin': + key = term.inkey(timeout=0) + if not key: + continue + if key == 'q': + cleanup() + return + elif key.name == 'KEY_UP': + S.auto_scroll = False + S.log_scroll = max(0, S.log_scroll - 1) + need_draw = True + elif key.name == 'KEY_DOWN': + S.log_scroll += 1 + need_draw = True + elif key == 'p': + S.auto_scroll = not S.auto_scroll + if S.auto_scroll: + S.log_scroll = max(0, len(S.logs) - 10) + need_draw = True + elif key == 'r': + if train_proc: + train_proc.terminate() + train_proc.wait() + train_proc = spawn_training(resume=True, steps=args.steps) + S.train_pid = train_proc.pid + procs = [p for p in procs if p.poll() is None] + procs.append(train_proc) + S.logs.append('[dashboard] Restarted with --resume') + need_draw = True + elif key == 'g': + with S.gen_lock: + S.gen_status = 'generating' + S.gen_step = S.step + def force_gen(): + try: + W = load_weights_from_ckpt(CKPT_PATH) + if W: + text = generate_text(W, get_tokenizer(), max_tokens=64, temperature=0.8) + with S.gen_lock: + S.gen_text = text + S.gen_step = S.step + S.gen_status = 'done' + except Exception as e: + with S.gen_lock: + S.gen_text = f'[error: {e}]' + S.gen_status = 'done' + threading.Thread(target=force_gen, daemon=True).start() + need_draw = True + + now = time.monotonic() + if not need_draw and now - last_draw > 1.0: + need_draw = True + if need_draw and now - last_draw > 0.066: + draw(term) + last_draw = now + + if train_finished: + draw(term) + while True: + key = term.inkey(timeout=1) + if key == 'q': + cleanup() + return + +if __name__ == '__main__': + main() diff --git a/training/stories_config.h b/training/stories_config.h new file mode 100644 index 0000000..f967974 --- /dev/null +++ b/training/stories_config.h @@ -0,0 +1,189 @@ +// stories_config.h — Stories110M model config and structures +#pragma once +#import +#import +#import +#import +#import +#import +#import +#include +#include +#include +#include +#include +#include + +// Stories110M config +#define DIM 768 +#define HIDDEN 2048 +#define HEADS 12 +#define HD (DIM/HEADS) +#define SEQ 256 +#define NLAYERS 12 +#define VOCAB 32000 +#define ACCUM_STEPS 10 +#define MAX_COMPILES 100 + +// Per compile: 5 weight-bearing kernels per layer + 1 classifier = 5*12+1 = 61 +// Plus 1 static (sdpaBwd2 per layer, no weights) = 12 more but those are weight-free +// Actually sdpaBwd2 has no weights, compile once per layer +// Weight-bearing: fwdAttn(1) + fwdFFN(1) + ffnBwd(1) + sdpaBwd1(1) + qkvBwd(1) = 5 per layer +// 5 * 12 = 60 weight-bearing compiles per batch +// With MAX_COMPILES=100, we get 1 batch of ACCUM_STEPS before restart +#define KERNELS_PER_LAYER 5 +#define TOTAL_WEIGHT_KERNELS (KERNELS_PER_LAYER * NLAYERS) + +// Attention score channels for SDPA backward +#define SCORE_CH (HEADS*SEQ) + +// Weight sizes per layer +#define WQ_SZ (DIM*DIM) +#define WO_SZ (DIM*DIM) +#define W1_SZ (HIDDEN*DIM) +#define W2_SZ (DIM*HIDDEN) +#define W3_SZ (HIDDEN*DIM) +#define LAYER_PARAMS (4*WQ_SZ + W1_SZ + W2_SZ + W3_SZ + 2*DIM) +#define TOTAL_PARAMS (NLAYERS * LAYER_PARAMS + DIM + VOCAB*DIM) // +rms_final+embed + +// Per-layer weight and optimizer state +typedef struct { + float *Wq, *Wk, *Wv, *Wo; + float *W1, *W2, *W3; + float *rms_att, *rms_ffn; +} LayerWeights; + +typedef struct { + float *m, *v; + size_t n; +} AdamState; + +typedef struct { + AdamState Wq, Wk, Wv, Wo; + AdamState W1, W2, W3; + AdamState rms_att, rms_ffn; +} LayerAdam; + +// Per-layer activation buffers (saved for backward) +typedef struct { + float *layer_in; // [DIM, SEQ] input to this layer (for rmsnorm1 bwd) + float *xnorm; // [DIM, SEQ] rmsnorm1 output + float *Q, *K, *V; // [DIM, SEQ] QKV projections + float *attn_out; // [DIM, SEQ] attention output (before Wo) + float *o_out; // [DIM, SEQ] Wo output + float *x2; // [DIM, SEQ] residual after attn + float *x2norm; // [DIM, SEQ] rmsnorm2 output + float *h1, *h3; // [HIDDEN, SEQ] FFN intermediates + float *silu_out; // [HIDDEN, SEQ] SiLU(h1)*h3 + float *ffn_out; // [DIM, SEQ] FFN output +} LayerActs; + +// Per-layer gradient accumulators +typedef struct { + float *Wq, *Wk, *Wv, *Wo; + float *W1, *W2, *W3; + float *rms_att, *rms_ffn; +} LayerGrads; + +// ANE kernels per layer +typedef struct { void *model; IOSurfaceRef ioIn, ioOut; void *request; void *tmpDir; } Kern; +typedef struct { + Kern *fwdAttn, *fwdFFN, *ffnBwd, *sdpaBwd1, *sdpaBwd2, *qkvBwd; +} LayerKernels; + +// Checkpoint header +typedef struct { + int magic; // 0x424C5A54 "BLZT" + int version; // 2 + int step, total_steps; + int n_layers, vocab_size, dim, hidden_dim, n_heads, seq_len; + float lr, loss; + double cum_compile, cum_train, cum_wall; + int cum_steps, cum_batches; + int adam_t; + int pad[3]; // alignment +} CkptHdr; + +// llama2.c model file header +typedef struct { + int dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len; +} Llama2Config; + +// Globals +static Class g_D, g_I, g_AR, g_AIO; +static mach_timebase_info_data_t g_tb; +static int g_compile_count = 0; + +static void ane_init(void) { + dlopen("/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", RTLD_NOW); + g_D = NSClassFromString(@"_ANEInMemoryModelDescriptor"); + g_I = NSClassFromString(@"_ANEInMemoryModel"); + g_AR = NSClassFromString(@"_ANERequest"); + g_AIO= NSClassFromString(@"_ANEIOSurfaceObject"); +} +static double tb_ms(uint64_t t) { return (double)t * g_tb.numer / g_tb.denom / 1e6; } + +// Alloc helpers +static AdamState adam_alloc(size_t n) { AdamState s; s.m=(float*)calloc(n,4); s.v=(float*)calloc(n,4); s.n=n; return s; } +static void adam_free(AdamState *s) { free(s->m); free(s->v); } + +static LayerWeights layer_weights_alloc(void) { + LayerWeights w; + w.Wq=(float*)malloc(WQ_SZ*4); w.Wk=(float*)malloc(WQ_SZ*4); + w.Wv=(float*)malloc(WQ_SZ*4); w.Wo=(float*)malloc(WO_SZ*4); + w.W1=(float*)malloc(W1_SZ*4); w.W2=(float*)malloc(W2_SZ*4); w.W3=(float*)malloc(W3_SZ*4); + w.rms_att=(float*)malloc(DIM*4); w.rms_ffn=(float*)malloc(DIM*4); + return w; +} +static void layer_weights_free(LayerWeights *w) { + free(w->Wq);free(w->Wk);free(w->Wv);free(w->Wo); + free(w->W1);free(w->W2);free(w->W3); + free(w->rms_att);free(w->rms_ffn); +} +static LayerAdam layer_adam_alloc(void) { + LayerAdam a; + a.Wq=adam_alloc(WQ_SZ); a.Wk=adam_alloc(WQ_SZ); a.Wv=adam_alloc(WQ_SZ); a.Wo=adam_alloc(WO_SZ); + a.W1=adam_alloc(W1_SZ); a.W2=adam_alloc(W2_SZ); a.W3=adam_alloc(W3_SZ); + a.rms_att=adam_alloc(DIM); a.rms_ffn=adam_alloc(DIM); + return a; +} +static void layer_adam_free(LayerAdam *a) { + adam_free(&a->Wq);adam_free(&a->Wk);adam_free(&a->Wv);adam_free(&a->Wo); + adam_free(&a->W1);adam_free(&a->W2);adam_free(&a->W3); + adam_free(&a->rms_att);adam_free(&a->rms_ffn); +} +static LayerActs layer_acts_alloc(void) { + LayerActs a; + a.layer_in=(float*)malloc(SEQ*DIM*4); + a.xnorm=(float*)malloc(SEQ*DIM*4); a.Q=(float*)malloc(SEQ*DIM*4); + a.K=(float*)malloc(SEQ*DIM*4); a.V=(float*)malloc(SEQ*DIM*4); + a.attn_out=(float*)malloc(SEQ*DIM*4); a.o_out=(float*)malloc(SEQ*DIM*4); + a.x2=(float*)malloc(SEQ*DIM*4); a.x2norm=(float*)malloc(SEQ*DIM*4); + a.h1=(float*)malloc(SEQ*HIDDEN*4); a.h3=(float*)malloc(SEQ*HIDDEN*4); + a.silu_out=(float*)malloc(SEQ*HIDDEN*4); a.ffn_out=(float*)malloc(SEQ*DIM*4); + return a; +} +static void layer_acts_free(LayerActs *a) { + free(a->layer_in);free(a->xnorm);free(a->Q);free(a->K);free(a->V); + free(a->attn_out);free(a->o_out);free(a->x2);free(a->x2norm); + free(a->h1);free(a->h3);free(a->silu_out);free(a->ffn_out); +} +static LayerGrads layer_grads_alloc(void) { + LayerGrads g; + g.Wq=(float*)calloc(WQ_SZ,4); g.Wk=(float*)calloc(WQ_SZ,4); + g.Wv=(float*)calloc(WQ_SZ,4); g.Wo=(float*)calloc(WO_SZ,4); + g.W1=(float*)calloc(W1_SZ,4); g.W2=(float*)calloc(W2_SZ,4); g.W3=(float*)calloc(W3_SZ,4); + g.rms_att=(float*)calloc(DIM,4); g.rms_ffn=(float*)calloc(DIM,4); + return g; +} +static void layer_grads_zero(LayerGrads *g) { + memset(g->Wq,0,WQ_SZ*4);memset(g->Wk,0,WQ_SZ*4); + memset(g->Wv,0,WQ_SZ*4);memset(g->Wo,0,WO_SZ*4); + memset(g->W1,0,W1_SZ*4);memset(g->W2,0,W2_SZ*4);memset(g->W3,0,W3_SZ*4); + memset(g->rms_att,0,DIM*4);memset(g->rms_ffn,0,DIM*4); +} +static void layer_grads_free(LayerGrads *g) { + free(g->Wq);free(g->Wk);free(g->Wv);free(g->Wo); + free(g->W1);free(g->W2);free(g->W3); + free(g->rms_att);free(g->rms_ffn); +} diff --git a/training/stories_cpu_ops.h b/training/stories_cpu_ops.h new file mode 100644 index 0000000..c9f2cfa --- /dev/null +++ b/training/stories_cpu_ops.h @@ -0,0 +1,129 @@ +// stories_cpu_ops.h — CPU operations: RMSNorm, cross-entropy, Adam, softmax +#pragma once +#include "stories_config.h" + +static float *g_rms_tmp = NULL; + +static void rmsnorm(float *out, const float *x, const float *w, int d, int S) { + if (!g_rms_tmp) g_rms_tmp = (float*)malloc(S*4); + float *ss = (float*)calloc(S, sizeof(float)); + for (int i=0; in; i++) { + s->m[i] = b1*s->m[i] + (1-b1)*g[i]; + s->v[i] = b2*s->v[i] + (1-b2)*g[i]*g[i]; + float mh = s->m[i]/bc1, vh = s->v[i]/bc2; + w[i] -= lr * mh / (sqrtf(vh) + eps); + } +} + +// Cross-entropy loss + gradient for logits (column-major: [VOCAB, SEQ]) +// logits[v*SEQ+t] = logit for vocab v, position t +// targets[t] = target token id for position t +// Returns mean CE loss, writes dlogits = softmax(logits) - one_hot(targets) +// Data is column-major [V, S], but we process per-column (stride=1 within col is v*S+t, stride between v's is S) +// For vDSP: transpose to row-major scratch [S, V] to vectorize softmax per position +static float cross_entropy_loss(float *dlogits, const float *logits, const uint16_t *targets, int V, int S) { + // Work in transposed layout [S, V] where each row is one position's logits (contiguous) + float *buf = (float*)malloc(S * V * 4); + // Transpose [V,S] → [S,V]: buf[t*V+v] = logits[v*S+t] + vDSP_mtrans(logits, 1, buf, 1, (vDSP_Length)S, (vDSP_Length)V); + + float total_loss = 0; + float invS = 1.0f / S; + for (int t = 0; t < S; t++) { + float *row = buf + t * V; + // max + float maxv; + vDSP_maxv(row, 1, &maxv, (vDSP_Length)V); + // row -= maxv + float neg_max = -maxv; + vDSP_vsadd(row, 1, &neg_max, row, 1, (vDSP_Length)V); + // exp in-place + int n = V; + vvexpf(row, row, &n); + // sum + float sum; + vDSP_sve(row, 1, &sum, (vDSP_Length)V); + // normalize + float inv_sum = 1.0f / sum; + vDSP_vsmul(row, 1, &inv_sum, row, 1, (vDSP_Length)V); + // loss + int tgt = targets[t]; + total_loss -= logf(row[tgt] + 1e-10f); + // gradient: softmax - one_hot, then /S + row[tgt] -= 1.0f; + vDSP_vsmul(row, 1, &invS, row, 1, (vDSP_Length)V); + } + // Transpose back [S,V] → [V,S] + vDSP_mtrans(buf, 1, dlogits, 1, (vDSP_Length)V, (vDSP_Length)S); + free(buf); + return total_loss / S; +} + +// Embedding lookup: token_ids → x [DIM, SEQ] (channel-first) +// embed is [VOCAB, DIM] row-major (vocab_size rows, dim cols) +static void embed_lookup(float *x, const float *embed, const uint16_t *tokens, int dim, int seq) { + for (int t = 0; t < seq; t++) { + int tok = tokens[t]; + for (int d = 0; d < dim; d++) { + x[d*seq + t] = embed[tok*dim + d]; + } + } +} + +// Embedding backward: accumulate dE[tok] += dx[:,t] for each position +static void embed_backward(float *d_embed, const float *dx, const uint16_t *tokens, int dim, int seq) { + for (int t = 0; t < seq; t++) { + int tok = tokens[t]; + for (int d = 0; d < dim; d++) { + d_embed[tok*dim + d] += dx[d*seq + t]; + } + } +} diff --git a/training/stories_io.h b/training/stories_io.h new file mode 100644 index 0000000..017d8a8 --- /dev/null +++ b/training/stories_io.h @@ -0,0 +1,134 @@ +// stories_io.h — IOSurface helpers, blob builders, NEON conversion +#pragma once +#include "stories_config.h" +#include + +static IOSurfaceRef make_surface(size_t bytes) { + return IOSurfaceCreate((__bridge CFDictionaryRef)@{ + (id)kIOSurfaceWidth:@(bytes), (id)kIOSurfaceHeight:@1, + (id)kIOSurfaceBytesPerElement:@1, (id)kIOSurfaceBytesPerRow:@(bytes), + (id)kIOSurfaceAllocSize:@(bytes), (id)kIOSurfacePixelFormat:@0}); +} + +static NSData *build_blob(const float *w, int rows, int cols) { + int ws=rows*cols*2, tot=128+ws; + uint8_t *b=(uint8_t*)calloc(tot,1); + b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1; + *(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128; + _Float16 *fp16=(_Float16*)(b+128); + for(int i=0;imodel = (void*)CFBridgingRetain(mdl); + k->ioIn = make_surface(ic_bytes); + k->ioOut = make_surface(oc_bytes); + id wI = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioIn); + id wO = ((id(*)(Class,SEL,IOSurfaceRef))objc_msgSend)(g_AIO, @selector(objectWithIOSurface:), k->ioOut); + k->request = (void*)CFBridgingRetain(((id(*)(Class,SEL,id,id,id,id,id,id,id))objc_msgSend)(g_AR, + @selector(requestWithInputs:inputIndices:outputs:outputIndices:weightsBuffer:perfStats:procedureIndex:), + @[wI], @[@0], @[wO], @[@0], nil, nil, @0)); + k->tmpDir = (void*)CFBridgingRetain(td); + return k; + } +} +static void free_kern(Kern *k) { + if (!k) return; + id mdl = (__bridge id)k->model; NSError *e = nil; + ((BOOL(*)(id,SEL,unsigned int,NSError**))objc_msgSend)(mdl, @selector(unloadWithQoS:error:), 21, &e); + CFRelease(k->ioIn); CFRelease(k->ioOut); + [[NSFileManager defaultManager] removeItemAtPath:(__bridge id)k->tmpDir error:nil]; + CFRelease(k->model); CFRelease(k->request); CFRelease(k->tmpDir); + free(k); +} +static void ane_eval(Kern *k) { + id mdl = (__bridge id)k->model; id req = (__bridge id)k->request; NSError *e = nil; + ((BOOL(*)(id,SEL,unsigned int,id,id,NSError**))objc_msgSend)(mdl, @selector(evaluateWithQoS:options:request:error:), 21, @{}, req, &e); +} diff --git a/training/stories_mil.h b/training/stories_mil.h new file mode 100644 index 0000000..dccca44 --- /dev/null +++ b/training/stories_mil.h @@ -0,0 +1,286 @@ +// stories_mil.h — MIL program generators for ANE kernels +// Same architecture as single-layer train_large.m but parameterized +#pragma once +#include "stories_io.h" + +#define MIL_HDR \ + @"program(1.3)\n[buildInfo = dict({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \ + "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \ + "{\"coremltools-version\", \"9.0\"}})]\n{\n" +#define CONV_CONST \ + " string pt = const()[name=string(\"pt\"), val=string(\"valid\")];\n" \ + " tensor st = const()[name=string(\"st\"), val=tensor([1,1])];\n" \ + " tensor pd = const()[name=string(\"pd\"), val=tensor([0,0,0,0])];\n" \ + " tensor dl = const()[name=string(\"dl\"), val=tensor([1,1])];\n" \ + " int32 gr = const()[name=string(\"gr\"), val=int32(1)];\n" + +// SDPA forward + taps: x_in → rmsnorm → QKV+SDPA+Wo → concat(o_out, Q, K, V, attn_out, xnorm) +static NSString *gen_sdpa_fwd_taps(void) { + float sc = 1.0f/sqrtf((float)HD); + float invd = 1.0f/(float)DIM; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ]; + [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"]; + [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ]; + [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd]; + [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ]; + [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"]; + [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ]; + [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"]; + [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ]; + [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms1.bin\"), offset=uint64(64)))];\n", DIM, DIM]; + [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ]; + [m appendString:@CONV_CONST]; + [m appendFormat:@" tensor Wq = const()[name=string(\"Wq\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wq.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor Wk = const()[name=string(\"Wk\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wk.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor Wv = const()[name=string(\"Wv\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wv.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor Wo = const()[name=string(\"Wo\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wo.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor qf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wq,x=xn)[name=string(\"cq\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor kf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wk,x=xn)[name=string(\"ck\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor vf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wv,x=xn)[name=string(\"cv\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor qsh = const()[name=string(\"qsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS,SEQ,HD]; + [m appendString:@" bool tx = const()[name=string(\"tx\"), val=bool(false)];\n"]; + [m appendString:@" bool ty = const()[name=string(\"ty\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; + [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ]; + [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ]; + [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; + [m appendFormat:@" tensor aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor os = const()[name=string(\"os\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; + [m appendFormat:@" tensor af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af)[name=string(\"co\")];\n", DIM,SEQ]; + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM,SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// FFN forward + taps: x2 → rmsnorm → FFN → concat(ffn_out, h1, h3, silu_out, x2norm) +static NSString *gen_ffn_fwd_taps(void) { + float invd = 1.0f/(float)DIM; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ]; + [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"]; + [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ]; + [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd]; + [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ]; + [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"]; + [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ]; + [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"]; + [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ]; + [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms2.bin\"), offset=uint64(64)))];\n", DIM, DIM]; + [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ]; + [m appendString:@CONV_CONST]; + [m appendFormat:@" tensor W1 = const()[name=string(\"W1\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w1.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM]; + [m appendFormat:@" tensor W3 = const()[name=string(\"W3\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w3.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM]; + [m appendFormat:@" tensor W2 = const()[name=string(\"W2\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w2.bin\"), offset=uint64(64)))];\n", DIM,HIDDEN,DIM,HIDDEN]; + [m appendFormat:@" tensor h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn)[name=string(\"c1\")];\n", HIDDEN,SEQ]; + [m appendFormat:@" tensor h3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn)[name=string(\"c3\")];\n", HIDDEN,SEQ]; + [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN,SEQ]; + [m appendFormat:@" tensor silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN,SEQ]; + [m appendFormat:@" tensor gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN,SEQ]; + [m appendFormat:@" tensor y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate)[name=string(\"c2\")];\n", DIM,SEQ]; + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn))[name=string(\"cat\")];\n", 2*DIM+3*HIDDEN,SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// FFN backward: concat(dffn,h1,h3) → concat(dx,dh1,dh3) +static NSString *gen_ffn_bwd(void) { + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", DIM+2*HIDDEN, SEQ]; + [m appendString:@CONV_CONST]; + [m appendString:@" tensor bd = const()[name=string(\"bd\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor sd = const()[name=string(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor dffn = slice_by_size(x=x,begin=bd,size=sd)[name=string(\"s0\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; + [m appendFormat:@" tensor s1 = const()[name=string(\"s1\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor h1 = slice_by_size(x=x,begin=b1,size=s1)[name=string(\"s1x\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", DIM+HIDDEN]; + [m appendFormat:@" tensor h3 = slice_by_size(x=x,begin=b3,size=s1)[name=string(\"s3x\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor W2t = const()[name=string(\"W2t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w2t.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM]; + [m appendFormat:@" tensor dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn)[name=string(\"cw2\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ]; + [m appendString:@" fp16 one = const()[name=string(\"one\"), val=fp16(1.0)];\n"]; + [m appendFormat:@" tensor oms = sub(x=one,y=sig)[name=string(\"oms\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor homs = mul(x=h1,y=oms)[name=string(\"homs\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor brk = add(x=one,y=homs)[name=string(\"brk\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dsd = mul(x=sig,y=brk)[name=string(\"dsd\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor t1 = mul(x=dsilu,y=h3)[name=string(\"t1\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh1 = mul(x=t1,y=dsd)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor slh = mul(x=h1,y=sig)[name=string(\"slh\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor dh3 = mul(x=dsilu,y=slh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; + [m appendFormat:@" tensor W1t = const()[name=string(\"W1t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w1t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN]; + [m appendFormat:@" tensor W3t = const()[name=string(\"W3t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w3t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN]; + [m appendFormat:@" tensor dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1)[name=string(\"cw1\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3)[name=string(\"cw3\")];\n", DIM, SEQ]; + [m appendFormat:@" tensor dx = add(x=dx1,y=dx3)[name=string(\"adx\")];\n", DIM, SEQ]; + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3))[name=string(\"cat\")];\n", DIM+2*HIDDEN, SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// QKV backward: concat(dq,dk,dv) → dx +static NSString *gen_qkvb(void) { + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", 3*DIM, SEQ]; + [m appendString:@CONV_CONST]; + [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor dq = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; + [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; + [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor Wqt = const()[name=string(\"Wqt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wqt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor Wkt = const()[name=string(\"Wkt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wkt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor Wvt = const()[name=string(\"Wvt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wvt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq)[name=string(\"cq\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor dxk = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk)[name=string(\"ck\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv)[name=string(\"cv\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor out = add(x=dxqk,y=dxv)[name=string(\"out\")];\n", DIM,SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// SDPA backward part 1 + Wo^T +static NSString *gen_sdpa_bwd1(void) { + float sc = 1.0f/sqrtf((float)HD); + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", 4*DIM, SEQ]; + [m appendString:@CONV_CONST]; + [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; + [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; + [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*DIM]; + [m appendFormat:@" tensor dx2f = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor Wot = const()[name=string(\"Wot\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wot.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; + [m appendFormat:@" tensor df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f)[name=string(\"cwo\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor dr = reshape(shape=rsh,x=df)[name=string(\"rd\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor da = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS,SEQ,HD]; + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; + [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ]; + [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ]; + [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; + [m appendFormat:@" tensor probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da)[name=string(\"dv\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v)[name=string(\"dp\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor dvs = const()[name=string(\"dvs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; + [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor scs = const()[name=string(\"scs\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH,SEQ]; + [m appendFormat:@" tensor pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH,SEQ]; + [m appendFormat:@" tensor dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH,SEQ]; + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH,SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// SDPA backward part 2: concat(probs,dp,Q,K) → concat(dQ,dK) +static NSString *gen_sdpa_bwd2(void) { + float sc = 1.0f/sqrtf((float)HD); + int bwd2_in = 2*SCORE_CH + 2*DIM; + NSMutableString *m = [NSMutableString string]; + [m appendString:MIL_HDR]; + [m appendFormat:@" func main(tensor x) {\n", bwd2_in, SEQ]; + [m appendFormat:@" tensor sz_sc = const()[name=string(\"szsc\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ]; + [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; + [m appendFormat:@" tensor pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH,SEQ]; + [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", SCORE_CH]; + [m appendFormat:@" tensor dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH,SEQ]; + [m appendFormat:@" tensor sz_d = const()[name=string(\"szd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; + [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH]; + [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+DIM]; + [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor ssh = const()[name=string(\"ssh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; + [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; + [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS,SEQ,SEQ]; + [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([-1])];\n"]; + [m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS,SEQ]; + [m appendFormat:@" tensor dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" tensor ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS,SEQ,SEQ]; + [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; + [m appendFormat:@" tensor ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS,SEQ,SEQ]; + [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; + [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; + [m appendFormat:@" tensor dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS,SEQ,HD]; + [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS,HD,SEQ]; + [m appendFormat:@" tensor fs = const()[name=string(\"fs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; + [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM,SEQ]; + [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM,SEQ]; + [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; + [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; + [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM,SEQ]; + [m appendString:@" } -> (out);\n}\n"]; + return m; +} + +// Mask blob (causal mask [SEQ,SEQ]) +static NSData *g_mask_blob = nil; +static NSData *get_mask_blob(void) { + if (!g_mask_blob) { + _Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16)); + for(int t=0;t -#import -#import -#import -#import -#import -#import -#include -#include -#include +// train_large.m — Train stories110M (12 layers, 768dim, 3072hidden) on ANE +// Uses pretokenized TinyStories data with cross-entropy loss +// 5 weight-bearing ANE kernels per layer × 12 layers = 60 per compile batch +#include "stories_io.h" +#include "stories_mil.h" +#include "stories_cpu_ops.h" -#define DIM 768 -#define HIDDEN 2048 -#define HEADS 12 -#define HD (DIM/HEADS) -#define SEQ 512 -#define ACCUM_STEPS 100 -#define MAX_COMPILES 100 -#define NUM_KERNELS 6 -#define CKPT_PATH "/tmp/ane_large_ckpt.bin" +#define CKPT_PATH "ane_stories110M_ckpt.bin" +#define MODEL_PATH "../../assets/models/stories110M.bin" +#define DATA_PATH "tinystories_data00.bin" -static Class g_D, g_I, g_AR, g_AIO; -static mach_timebase_info_data_t g_tb; -static int g_compile_count = 0; - -static void ane_init(void) { - dlopen("/System/Library/PrivateFrameworks/AppleNeuralEngine.framework/AppleNeuralEngine", RTLD_NOW); - g_D = NSClassFromString(@"_ANEInMemoryModelDescriptor"); - g_I = NSClassFromString(@"_ANEInMemoryModel"); - g_AR = NSClassFromString(@"_ANERequest"); - g_AIO= NSClassFromString(@"_ANEIOSurfaceObject"); -} -static double tb_ms(uint64_t t) { return (double)t * g_tb.numer / g_tb.denom / 1e6; } -static IOSurfaceRef make_surface(size_t bytes) { - return IOSurfaceCreate((__bridge CFDictionaryRef)@{ - (id)kIOSurfaceWidth:@(bytes), (id)kIOSurfaceHeight:@1, - (id)kIOSurfaceBytesPerElement:@1, (id)kIOSurfaceBytesPerRow:@(bytes), - (id)kIOSurfaceAllocSize:@(bytes), (id)kIOSurfacePixelFormat:@0}); -} -static NSData *build_blob(const float *w, int rows, int cols) { - int ws=rows*cols*2, tot=128+ws; - uint8_t *b=(uint8_t*)calloc(tot,1); - b[0]=1;b[4]=2;b[64]=0xEF;b[65]=0xBE;b[66]=0xAD;b[67]=0xDE;b[68]=1; - *(uint32_t*)(b+72)=ws;*(uint32_t*)(b+80)=128; - _Float16 *fp16=(_Float16*)(b+128); - for(int i=0;i({{\"coremlc-component-MIL\", \"3510.2.1\"}, " \ - "{\"coremlc-version\", \"3505.4.1\"}, {\"coremltools-component-milinternal\", \"\"}, " \ - "{\"coremltools-version\", \"9.0\"}})]\n{\n" -#define CONV_CONST \ - " string pt = const()[name=string(\"pt\"), val=string(\"valid\")];\n" \ - " tensor st = const()[name=string(\"st\"), val=tensor([1,1])];\n" \ - " tensor pd = const()[name=string(\"pd\"), val=tensor([0,0,0,0])];\n" \ - " tensor dl = const()[name=string(\"dl\"), val=tensor([1,1])];\n" \ - " int32 gr = const()[name=string(\"gr\"), val=int32(1)];\n" - -// SDPA forward + taps: x_in → rmsnorm → QKV+SDPA+Wo → concat(o_out, Q, K, V, attn_out, xnorm) fp16 -static NSString *gen_sdpa_fwd_taps(void) { - float sc = 1.0f/sqrtf((float)HD); - float invd = 1.0f/(float)DIM; - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ]; - // --- RMSNorm: x → xn --- - [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"]; - [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; - [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ]; - [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd]; - [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ]; - [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"]; - [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ]; - [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"]; - [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ]; - [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms1.bin\"), offset=uint64(64)))];\n", DIM, DIM]; - [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ]; - // --- QKV + SDPA + Wo (operates on xn) --- - [m appendString:@CONV_CONST]; - [m appendFormat:@" tensor Wq = const()[name=string(\"Wq\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wq.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor Wk = const()[name=string(\"Wk\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wk.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor Wv = const()[name=string(\"Wv\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wv.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor Wo = const()[name=string(\"Wo\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wo.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor qf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wq,x=xn)[name=string(\"cq\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor kf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wk,x=xn)[name=string(\"ck\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor vf = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wv,x=xn)[name=string(\"cv\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor qsh = const()[name=string(\"qsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; - [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - [m appendFormat:@" tensor q4 = reshape(shape=qsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor q = transpose(perm=pm,x=q4)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor k4 = reshape(shape=qsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor k = transpose(perm=pm,x=k4)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor v4 = reshape(shape=qsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor v = transpose(perm=pm,x=v4)[name=string(\"tv\")];\n", HEADS,SEQ,HD]; - [m appendString:@" bool tx = const()[name=string(\"tx\"), val=bool(false)];\n"]; - [m appendString:@" bool ty = const()[name=string(\"ty\"), val=bool(true)];\n"]; - [m appendFormat:@" tensor sc1 = matmul(transpose_x=tx,transpose_y=ty,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; - [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ]; - [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ]; - [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; - [m appendFormat:@" tensor aw = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor a4 = matmul(transpose_x=tx,transpose_y=tx,x=aw,y=v)[name=string(\"mm2\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor at = transpose(perm=pm,x=a4)[name=string(\"ta\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor os = const()[name=string(\"os\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; - [m appendFormat:@" tensor af = reshape(shape=os,x=at)[name=string(\"ra\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor oo = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wo,x=af)[name=string(\"co\")];\n", DIM,SEQ]; - [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; - [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(oo,qf,kf,vf,af,xn))[name=string(\"cat\")];\n", 6*DIM,SEQ]; - [m appendString:@" } -> (out);\n}\n"]; - return m; -} - -// FFN forward + taps: x2 → rmsnorm → FFN → concat(ffn_out, h1, h3, silu_out, x2norm) fp16 -static NSString *gen_ffn_fwd_taps(void) { - float invd = 1.0f/(float)DIM; - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM, SEQ]; - // --- RMSNorm: x → xn --- - [m appendFormat:@" tensor sq = mul(x=x,y=x)[name=string(\"sq\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor rax = const()[name=string(\"rax\"), val=tensor([1])];\n"]; - [m appendFormat:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; - [m appendFormat:@" tensor ss = reduce_sum(x=sq,axes=rax,keep_dims=kd)[name=string(\"ss\")];\n", SEQ]; - [m appendFormat:@" fp16 invd = const()[name=string(\"invd\"), val=fp16(%f)];\n", invd]; - [m appendFormat:@" tensor ss2 = mul(x=ss,y=invd)[name=string(\"ss2\")];\n", SEQ]; - [m appendFormat:@" fp16 eps = const()[name=string(\"eps\"), val=fp16(0.00001)];\n"]; - [m appendFormat:@" tensor ss3 = add(x=ss2,y=eps)[name=string(\"ss3\")];\n", SEQ]; - [m appendFormat:@" fp16 nhalf = const()[name=string(\"nhalf\"), val=fp16(-0.5)];\n"]; - [m appendFormat:@" tensor rrms = pow(x=ss3,y=nhalf)[name=string(\"rrms\")];\n", SEQ]; - [m appendFormat:@" tensor xr = mul(x=x,y=rrms)[name=string(\"xr\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor rw = const()[name=string(\"rw\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/rms2.bin\"), offset=uint64(64)))];\n", DIM, DIM]; - [m appendFormat:@" tensor xn = mul(x=xr,y=rw)[name=string(\"xn\")];\n", DIM, SEQ]; - // --- FFN (operates on xn) --- - [m appendString:@CONV_CONST]; - [m appendFormat:@" tensor W1 = const()[name=string(\"W1\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w1.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM]; - [m appendFormat:@" tensor W3 = const()[name=string(\"W3\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w3.bin\"), offset=uint64(64)))];\n", HIDDEN,DIM,HIDDEN,DIM]; - [m appendFormat:@" tensor W2 = const()[name=string(\"W2\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w2.bin\"), offset=uint64(64)))];\n", DIM,HIDDEN,DIM,HIDDEN]; - [m appendFormat:@" tensor h1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1,x=xn)[name=string(\"c1\")];\n", HIDDEN,SEQ]; - [m appendFormat:@" tensor h3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3,x=xn)[name=string(\"c3\")];\n", HIDDEN,SEQ]; - [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN,SEQ]; - [m appendFormat:@" tensor silu = mul(x=h1,y=sig)[name=string(\"si\")];\n", HIDDEN,SEQ]; - [m appendFormat:@" tensor gate = mul(x=silu,y=h3)[name=string(\"gt\")];\n", HIDDEN,SEQ]; - [m appendFormat:@" tensor y = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2,x=gate)[name=string(\"c2\")];\n", DIM,SEQ]; - [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; - [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(y,h1,h3,gate,xn))[name=string(\"cat\")];\n", 2*DIM+3*HIDDEN,SEQ]; - [m appendString:@" } -> (out);\n}\n"]; - return m; -} - -// Fused FFN backward: concat(dffn,h1,h3) → concat(dx,dh1,dh3) fp16 -static NSString *gen_ffn_bwd(void) { - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", DIM+2*HIDDEN, SEQ]; - [m appendString:@CONV_CONST]; - [m appendString:@" tensor bd = const()[name=string(\"bd\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor sd = const()[name=string(\"sd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor dffn = slice_by_size(x=x,begin=bd,size=sd)[name=string(\"s0\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; - [m appendFormat:@" tensor s1 = const()[name=string(\"s1\"), val=tensor([1,%d,1,%d])];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor h1 = slice_by_size(x=x,begin=b1,size=s1)[name=string(\"s1x\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", DIM+HIDDEN]; - [m appendFormat:@" tensor h3 = slice_by_size(x=x,begin=b3,size=s1)[name=string(\"s3x\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor W2t = const()[name=string(\"W2t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w2t.bin\"), offset=uint64(64)))];\n", HIDDEN, DIM, HIDDEN, DIM]; - [m appendFormat:@" tensor dsilu = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W2t,x=dffn)[name=string(\"cw2\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor sig = sigmoid(x=h1)[name=string(\"sg\")];\n", HIDDEN, SEQ]; - [m appendString:@" fp16 one = const()[name=string(\"one\"), val=fp16(1.0)];\n"]; - [m appendFormat:@" tensor oms = sub(x=one,y=sig)[name=string(\"oms\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor homs = mul(x=h1,y=oms)[name=string(\"homs\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor brk = add(x=one,y=homs)[name=string(\"brk\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor dsd = mul(x=sig,y=brk)[name=string(\"dsd\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor t1 = mul(x=dsilu,y=h3)[name=string(\"t1\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor dh1 = mul(x=t1,y=dsd)[name=string(\"dh1\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor slh = mul(x=h1,y=sig)[name=string(\"slh\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor dh3 = mul(x=dsilu,y=slh)[name=string(\"dh3\")];\n", HIDDEN, SEQ]; - [m appendFormat:@" tensor W1t = const()[name=string(\"W1t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w1t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN]; - [m appendFormat:@" tensor W3t = const()[name=string(\"W3t\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/w3t.bin\"), offset=uint64(64)))];\n", DIM, HIDDEN, DIM, HIDDEN]; - [m appendFormat:@" tensor dx1 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W1t,x=dh1)[name=string(\"cw1\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor dx3 = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=W3t,x=dh3)[name=string(\"cw3\")];\n", DIM, SEQ]; - [m appendFormat:@" tensor dx = add(x=dx1,y=dx3)[name=string(\"adx\")];\n", DIM, SEQ]; - [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; - [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dx,dh1,dh3))[name=string(\"cat\")];\n", DIM+2*HIDDEN, SEQ]; - [m appendString:@" } -> (out);\n}\n"]; - return m; -} - -// Fused QKV backward: concat(dq,dk,dv) → dx fp16 -static NSString *gen_qkvb(void) { - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", 3*DIM, SEQ]; - [m appendString:@CONV_CONST]; - [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor dq = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; - [m appendFormat:@" tensor dk = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; - [m appendFormat:@" tensor dv = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor Wqt = const()[name=string(\"Wqt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wqt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor Wkt = const()[name=string(\"Wkt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wkt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor Wvt = const()[name=string(\"Wvt\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wvt.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor dxq = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wqt,x=dq)[name=string(\"cq\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor dxk = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wkt,x=dk)[name=string(\"ck\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor dxv = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wvt,x=dv)[name=string(\"cv\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor dxqk = add(x=dxq,y=dxk)[name=string(\"aqk\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor out = add(x=dxqk,y=dxv)[name=string(\"out\")];\n", DIM,SEQ]; - [m appendString:@" } -> (out);\n}\n"]; - return m; -} - -// SDPA backward part 1 + Wo^T: concat(Q,K,V,dx2) → Wo^T(dx2) → concat(dV, probs_flat, dp_flat) fp16 -// SCORE_CH: channels needed for flattened attention scores [HEADS,SEQ,SEQ] → [HEADS*SEQ, SEQ] -#define SCORE_CH (HEADS*SEQ) - -static NSString *gen_sdpa_bwd1(void) { - float sc = 1.0f/sqrtf((float)HD); - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", 4*DIM, SEQ]; - [m appendString:@CONV_CONST]; - [m appendFormat:@" tensor sz = const()[name=string(\"sz\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b0,size=sz)[name=string(\"s0\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", DIM]; - [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b1,size=sz)[name=string(\"s1\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*DIM]; - [m appendFormat:@" tensor vf = slice_by_size(x=x,begin=b2,size=sz)[name=string(\"s2\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 3*DIM]; - [m appendFormat:@" tensor dx2f = slice_by_size(x=x,begin=b3,size=sz)[name=string(\"s3\")];\n", DIM,SEQ]; - // Wo^T backward: dx2 → dattn - [m appendFormat:@" tensor Wot = const()[name=string(\"Wot\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/wot.bin\"), offset=uint64(64)))];\n", DIM,DIM,DIM,DIM]; - [m appendFormat:@" tensor df = conv(dilations=dl,groups=gr,pad=pd,pad_type=pt,strides=st,weight=Wot,x=dx2f)[name=string(\"cwo\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; - [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor vr = reshape(shape=rsh,x=vf)[name=string(\"rv\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor v = transpose(perm=pm,x=vr)[name=string(\"tv\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor dr = reshape(shape=rsh,x=df)[name=string(\"rd\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor da = transpose(perm=pm,x=dr)[name=string(\"td\")];\n", HEADS,SEQ,HD]; - [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; - [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; - [m appendFormat:@" tensor sc1 = matmul(transpose_x=bF,transpose_y=bT,x=q,y=k)[name=string(\"mm1\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; - [m appendFormat:@" tensor sc2 = mul(x=sc1,y=scv)[name=string(\"scl\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor cm = const()[name=string(\"cm\"), val=tensor(BLOBFILE(path=string(\"@model_path/weights/mask.bin\"), offset=uint64(64)))];\n", SEQ,SEQ,SEQ,SEQ]; - [m appendFormat:@" tensor ms = add(x=sc2,y=cm)[name=string(\"msk\")];\n", HEADS,SEQ,SEQ]; - [m appendString:@" int32 sax = const()[name=string(\"sax\"), val=int32(-1)];\n"]; - [m appendFormat:@" tensor probs = softmax(axis=sax,x=ms)[name=string(\"sm\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor dv4 = matmul(transpose_x=bT,transpose_y=bF,x=probs,y=da)[name=string(\"dv\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor dp4 = matmul(transpose_x=bF,transpose_y=bT,x=da,y=v)[name=string(\"dp\")];\n", HEADS,SEQ,SEQ]; - // Flatten dv back to [1,DIM,1,SEQ] - [m appendFormat:@" tensor dvt = transpose(perm=pm,x=dv4)[name=string(\"dvt\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor dvs = const()[name=string(\"dvs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; - [m appendFormat:@" tensor dvf = reshape(shape=dvs,x=dvt)[name=string(\"dvf\")];\n", DIM,SEQ]; - // Flatten probs [1,H,S,S] → [1,H*S,1,S] and dp [1,H,S,S] → [1,H*S,1,S] - [m appendFormat:@" tensor scs = const()[name=string(\"scs\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH,SEQ]; - [m appendFormat:@" tensor pf = reshape(shape=scs,x=probs)[name=string(\"pf\")];\n", SCORE_CH,SEQ]; - [m appendFormat:@" tensor dpf = reshape(shape=scs,x=dp4)[name=string(\"dpf\")];\n", SCORE_CH,SEQ]; - [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; - [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dvf,pf,dpf))[name=string(\"cat\")];\n", DIM+2*SCORE_CH,SEQ]; - [m appendString:@" } -> (out);\n}\n"]; - return m; -} - -// SDPA backward part 2: concat(probs[SCORE_CH],dp[SCORE_CH],Q[DIM],K[DIM]) → concat(dQ,dK) fp16 -static NSString *gen_sdpa_bwd2(void) { - float sc = 1.0f/sqrtf((float)HD); - int bwd2_in = 2*SCORE_CH + 2*DIM; - NSMutableString *m = [NSMutableString string]; - [m appendString:MIL_HDR]; - [m appendFormat:@" func main(tensor x) {\n", bwd2_in, SEQ]; - // Slice probs - [m appendFormat:@" tensor sz_sc = const()[name=string(\"szsc\"), val=tensor([1,%d,1,%d])];\n", SCORE_CH, SEQ]; - [m appendString:@" tensor b0 = const()[name=string(\"b0\"), val=tensor([0,0,0,0])];\n"]; - [m appendFormat:@" tensor pf = slice_by_size(x=x,begin=b0,size=sz_sc)[name=string(\"s0\")];\n", SCORE_CH,SEQ]; - // Slice dp - [m appendFormat:@" tensor b1 = const()[name=string(\"b1\"), val=tensor([0,%d,0,0])];\n", SCORE_CH]; - [m appendFormat:@" tensor dpf = slice_by_size(x=x,begin=b1,size=sz_sc)[name=string(\"s1\")];\n", SCORE_CH,SEQ]; - // Slice Q - [m appendFormat:@" tensor sz_d = const()[name=string(\"szd\"), val=tensor([1,%d,1,%d])];\n", DIM, SEQ]; - [m appendFormat:@" tensor b2 = const()[name=string(\"b2\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH]; - [m appendFormat:@" tensor qf = slice_by_size(x=x,begin=b2,size=sz_d)[name=string(\"s2\")];\n", DIM,SEQ]; - // Slice K - [m appendFormat:@" tensor b3 = const()[name=string(\"b3\"), val=tensor([0,%d,0,0])];\n", 2*SCORE_CH+DIM]; - [m appendFormat:@" tensor kf = slice_by_size(x=x,begin=b3,size=sz_d)[name=string(\"s3\")];\n", DIM,SEQ]; - // Reshape to multi-head - [m appendFormat:@" tensor ssh = const()[name=string(\"ssh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor probs = reshape(shape=ssh,x=pf)[name=string(\"rp\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor dp = reshape(shape=ssh,x=dpf)[name=string(\"rdp\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor rsh = const()[name=string(\"rsh\"), val=tensor([1,%d,%d,%d])];\n", HEADS,HD,SEQ]; - [m appendString:@" tensor pm = const()[name=string(\"pm\"), val=tensor([0,1,3,2])];\n"]; - [m appendFormat:@" tensor qr = reshape(shape=rsh,x=qf)[name=string(\"rq\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor q = transpose(perm=pm,x=qr)[name=string(\"tq\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor kr = reshape(shape=rsh,x=kf)[name=string(\"rk\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor k = transpose(perm=pm,x=kr)[name=string(\"tk\")];\n", HEADS,SEQ,HD]; - // Softmax grad: ds = probs * (dp - sum(probs*dp)) * scale - [m appendFormat:@" tensor pdp = mul(x=probs,y=dp)[name=string(\"pdp\")];\n", HEADS,SEQ,SEQ]; - [m appendString:@" tensor rax = const()[name=string(\"rax\"), val=tensor([-1])];\n"]; - [m appendString:@" bool kd = const()[name=string(\"kd\"), val=bool(true)];\n"]; - [m appendFormat:@" tensor spdp = reduce_sum(x=pdp,axes=rax,keep_dims=kd)[name=string(\"rs\")];\n", HEADS,SEQ]; - [m appendFormat:@" tensor dps = sub(x=dp,y=spdp)[name=string(\"dps\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" tensor ds0 = mul(x=probs,y=dps)[name=string(\"ds0\")];\n", HEADS,SEQ,SEQ]; - [m appendFormat:@" fp16 scv = const()[name=string(\"scv\"), val=fp16(%f)];\n", sc]; - [m appendFormat:@" tensor ds = mul(x=ds0,y=scv)[name=string(\"ds\")];\n", HEADS,SEQ,SEQ]; - [m appendString:@" bool bF = const()[name=string(\"bF\"), val=bool(false)];\n"]; - [m appendString:@" bool bT = const()[name=string(\"bT\"), val=bool(true)];\n"]; - [m appendFormat:@" tensor dq4 = matmul(transpose_x=bF,transpose_y=bF,x=ds,y=k)[name=string(\"dq\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor dk4 = matmul(transpose_x=bT,transpose_y=bF,x=ds,y=q)[name=string(\"dk\")];\n", HEADS,SEQ,HD]; - [m appendFormat:@" tensor dqt = transpose(perm=pm,x=dq4)[name=string(\"dqt\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor dkt = transpose(perm=pm,x=dk4)[name=string(\"dkt\")];\n", HEADS,HD,SEQ]; - [m appendFormat:@" tensor fs = const()[name=string(\"fs\"), val=tensor([1,%d,1,%d])];\n", DIM,SEQ]; - [m appendFormat:@" tensor dqf = reshape(shape=fs,x=dqt)[name=string(\"dqf\")];\n", DIM,SEQ]; - [m appendFormat:@" tensor dkf = reshape(shape=fs,x=dkt)[name=string(\"dkf\")];\n", DIM,SEQ]; - [m appendString:@" int32 cax = const()[name=string(\"cax\"), val=int32(1)];\n"]; - [m appendString:@" bool cid = const()[name=string(\"cid\"), val=bool(false)];\n"]; - [m appendFormat:@" tensor out = concat(axis=cax,interleave=cid,values=(dqf,dkf))[name=string(\"cat\")];\n", 2*DIM,SEQ]; - [m appendString:@" } -> (out);\n}\n"]; - return m; -} - -// ===== Weight builders ===== -static NSData *g_mask_blob = nil; -static NSData *get_mask_blob(void) { - if (!g_mask_blob) { - _Float16 *mask = (_Float16*)calloc(SEQ*SEQ, sizeof(_Float16)); - for(int t=0;t