[feat] Optimize inference: vectorize ops (NEON/vDSP), gate debug output, skip unused ANE compilation, add round-trip benchmark timing, pure C HTTP API with tokenizer

This commit is contained in:
Erik Bray 2026-03-03 19:41:54 +01:00
parent 6f16dbefca
commit 0e70f5bd71
9 changed files with 2191 additions and 193 deletions

9
.gitignore vendored
View File

@ -22,8 +22,11 @@ training/train_double_buffer
training/test_*
!training/test_*.m
# Inference binaries
# Inference binaries and runtime data
inference/qwen_ane
inference/qwen05b.bin
inference/.venv/
inference/benchmark_results.json
# Dynamic training binaries
training/training_dynamic/train
@ -58,6 +61,10 @@ training/ane_stories110M_ckpt.bin
*.bin
!training/download_data.sh
# Secrets / env
.env
inference/.env
# Internal / private
.cursor/
docs/launch/

View File

@ -1,161 +1,241 @@
# ANE Inference — Full LLM on Apple Neural Engine
First complete LLM inference running directly on Apple's Neural Engine via reverse-engineered `_ANEClient` APIs. No CoreML. No Xcode compiler dependency at runtime. Token-for-token match with PyTorch.
First complete LLM inference running directly on Apple's Neural Engine via reverse-engineered `_ANEClient` APIs. No CoreML. No Xcode compiler dependency at runtime.
Built on top of the [maderix/ANE](https://github.com/maderix/ANE) training runtime.
## What This Does
Runs **Qwen2.5-0.5B-Instruct** (24 transformer layers, 494M parameters) entirely on the ANE:
Runs **Qwen2.5-0.5B-Instruct** (24 transformer layers, 494M parameters) on ANE:
- **169 ANE kernels** compiled at startup via `_ANEInMemoryModel`
- **82 tokens/sec** decode on M4 Pro
- **Zero GPU usage** — runs on 16 dedicated neural cores
- **Correct output** — matches PyTorch reference token-for-token
- **~60 tokens/sec** decode on M4 Max
- **Pure C HTTP API** — no Python needed for serving
- **BPE tokenizer in C** — send plain text, get plain text back
- **~6s cold start**, then instant responses in server mode
All linear projections (Q, K, V, O, gate, up, down × 24 layers + chunked LM head) compile as baked-weight 1×1 convolution kernels on ANE. Element-wise ops (RMSNorm, RoPE, softmax, SiLU, attention scores) run on CPU via Accelerate BLAS.
## Architecture
```
Token → Embedding (CPU) → 24× Transformer Layer → LM Head (CPU) → Next Token
├── RMSNorm (CPU)
├── Q/K/V Projection (ANE conv kernel)
├── RoPE (CPU, rotate_half)
├── GQA Attention (CPU, 14 heads / 2 KV heads)
├── O Projection (ANE conv kernel)
├── Residual (CPU)
├── RMSNorm (CPU)
├── Gate/Up Projection (ANE conv kernel)
├── SiLU + elementwise mul (CPU)
├── Down Projection (ANE conv kernel)
└── Residual (CPU)
```
## Quick Start
## Quick Start (One Command)
```bash
# 1. Convert weights from HuggingFace safetensors to flat binary
pip install safetensors torch transformers
python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct qwen05b.bin
# 2. Build
xcrun clang -O2 -framework Foundation -framework IOSurface \
-framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
-o qwen_ane main.m
# 3. Run (single-shot, pass space-separated token IDs)
./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20
# 4. With tokenizer (requires transformers)
python3 run.py "Say hello in one word."
cd inference
./setup.sh
```
## Server Mode (Recommended)
This automatically:
1. Creates a Python venv and installs dependencies
2. Downloads [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) from HuggingFace (~953 MB)
3. Converts BF16 safetensors to f32 binary format (~1.9 GB)
4. Builds the `qwen_ane` binary
5. Runs a smoke test
The first invocation compiles 169 ANE kernels (~5.5s). Server mode keeps them loaded so subsequent prompts respond instantly.
After setup, you're ready to go.
### Socket server (best for `run.py` integration)
## HTTP API (Recommended)
The fastest way to use inference. Single process, zero Python overhead.
```bash
# Terminal 1: start the server (compiles once, stays running)
# Start server (compiles 169 ANE kernels on first launch, ~6s)
./qwen_ane qwen05b.bin --http 8000
# Query with plain text — tokenization happens in C
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{"prompt": "What is 2+2?", "max_tokens": 50}'
```
Response:
```json
{
"text": "2+2 equals 4.",
"prompt_tokens": 29,
"gen_tokens": 8,
"prefill_tps": 66.2,
"decode_tps": 57.3,
"elapsed_s": 0.608
}
```
### Endpoints
| Method | Path | Description |
|--------|------|-------------|
| POST | `/v1/completions` | Generate text from a prompt |
| GET | `/health` | Server status check |
### POST /v1/completions
```json
{
"prompt": "Your question here",
"max_tokens": 50,
"system": "You are a helpful assistant."
}
```
- `prompt` (required): The user message
- `max_tokens` (optional, default 50, max 512): Maximum tokens to generate
- `system` (optional): System prompt override
### Options
```bash
# Custom port
./qwen_ane qwen05b.bin --http 9000
# Custom model directory (for tokenizer files)
./qwen_ane qwen05b.bin --http 8000 --model-dir /path/to/Qwen2.5-0.5B-Instruct
```
Default model directory: `~/models/Qwen2.5-0.5B-Instruct`
## Other Modes
### Socket server (for programmatic access)
```bash
# Terminal 1: start server
./qwen_ane qwen05b.bin --server /tmp/qwen_ane.sock
# Terminal 2: queries are instant (~0.5s instead of ~6s)
# Terminal 2: query with run.py (auto-detects socket)
python3 run.py "What is 2+2?"
python3 run.py "Capital of France?"
python3 run.py "Count from 1 to 5"
```
`run.py` auto-detects the socket at `/tmp/qwen_ane.sock` and connects to it. If no server is running, it falls back to subprocess mode (slower).
You can also query the socket directly:
```bash
# Or query directly with nc
echo '{"tokens": [151644, 8948, 198], "max_tokens": 50}' | nc -U /tmp/qwen_ane.sock
```
Response format:
```json
{"output": [9707, 0, 151645], "prefill_tps": 68.4, "decode_tps": 67.8, "prompt_tokens": 28, "gen_tokens": 3}
```
### Stdin server (for piping/scripting)
```bash
./qwen_ane qwen05b.bin --server
# Waits for "READY", then send lines of space-separated token IDs:
# Send space-separated token IDs, pipe char separates max_tokens:
# 151644 8948 198 2610 525|20
# (pipe character separates max_tokens)
```
### Performance comparison
### Single-shot (no server)
```bash
# Raw token IDs
./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20
# With Python tokenizer
python3 run.py "Say hello in one word."
```
### Python API server (alternative)
If you prefer Python for the HTTP layer:
```bash
./qwen_ane qwen05b.bin --server /tmp/qwen_ane.sock
python3 api_server.py --port 8000
```
## Throughput Benchmark
Run the standardized benchmark to measure your hardware's performance:
```bash
./benchmark.sh
```
This runs 5 prompts of varying length, measures prefill and decode tokens/sec in server mode, tests cold start latency, and checks decode speed consistency.
Sample output (M4 Max, 128 GB):
```
Prompt Input Output Prefill(t/s) Decode(t/s) Latency(ms)
──────────────────────────────────────────────────────────────────
tiny 23 10 53.7 53.6 632
short 29 8 66.2 49.5 628
medium 33 84 63.4 55.3 2064
long 36 200 66.4 54.5 4235
stress 122 11 58.6 58.5 2303
──────────────────────────────────────────────────────────────────
Average 61.7 54.3
Cold start (single-shot): ~6.2s (includes ANE kernel compilation)
```
Results are saved to `benchmark_results.json` for programmatic use.
### Compare with LM Studio
The benchmark script prints instructions for running the same prompts in LM Studio:
1. Download [LM Studio](https://lmstudio.ai)
2. Search for and download **Qwen2.5-0.5B-Instruct** (GGUF Q4_K_M or Q8_0)
3. Load the model, start the server (Developer tab, port 1234)
4. Run the same prompts and compare tokens/sec:
```bash
curl http://localhost:1234/api/v1/chat \
-H "Content-Type: application/json" \
-d '{"model":"qwen2.5-0.5b-instruct","system_prompt":"You are a helpful assistant.","input":"What is 2+2?"}'
```
Note: LM Studio uses quantized GGUF weights (CPU/GPU) while we use full BF16 precision on the Neural Engine.
## Performance
| Mode | First prompt | Subsequent prompts |
|------|-------------|-------------------|
| Single-shot | ~6s | ~6s (recompiles) |
| Server | ~6s (startup) | ~0.5s |
| Single-shot | ~6s | ~6s (recompiles each time) |
| Server (socket/HTTP) | ~6s (startup) | ~0.5s |
## Output
## Architecture
```
=== Qwen2.5-0.5B ANE Inference ===
Loading weights...
Config: dim=896 hidden=4864 layers=24 heads=14 kv_heads=2 vocab=151936
Compiling ANE kernels (169 total)...
Compile time: 5.1s
Prompt: 28 tokens, generating up to 10
Prefill: 64.2 t/s (28 tokens)
OUT: 9707 13 151645
Decode: 82.4 t/s (2 tokens)
→ "Hello." (matches PyTorch exactly)
Token -> Embedding (CPU) -> 24x Transformer Layer -> LM Head (CPU) -> Next Token
|
+-- RMSNorm (CPU)
+-- Q/K/V Projection (ANE conv kernel)
+-- RoPE (CPU, rotate_half)
+-- GQA Attention (CPU, 14 heads / 2 KV heads)
+-- O Projection (ANE conv kernel)
+-- Residual (CPU)
+-- RMSNorm (CPU)
+-- Gate/Up Projection (ANE conv kernel)
+-- SiLU + elementwise mul (CPU)
+-- Down Projection (ANE conv kernel)
+-- Residual (CPU)
```
## Files
| File | What |
|------|------|
| `setup.sh` | One-command setup: downloads model, converts weights, builds binary |
| `benchmark.sh` | Throughput benchmark with LM Studio comparison |
| `main.m` | Entry point: weight loader, server modes, HTTP API |
| `qwen_ane_infer.h` | Full 24-layer transformer forward pass, ANE kernel compilation, KV cache |
| `main.m` | Weight loader, token I/O, main generation loop |
| `convert_weights.py` | HuggingFace safetensors → flat f32 binary (includes Q/K/V biases) |
| `run.py` | Python wrapper with HuggingFace tokenizer |
| `tokenizer.h` | BPE tokenizer in C: vocab/merge loading, encode/decode, chat template |
| `http_server.h` | Minimal HTTP/1.1 server: TCP, request parsing, JSON responses |
| `convert_weights.py` | HuggingFace safetensors to flat f32 binary |
| `run.py` | Python wrapper with HuggingFace tokenizer (auto-connects to socket server) |
| `api_server.py` | Python HTTP API bridge to socket server (alternative to C HTTP) |
## Model Support
## Model
Currently implements **Qwen2.5** architecture:
- GQA attention (grouped-query, `n_heads``n_kv_heads`)
- `rotate_half` RoPE (not interleaved pairs)
- SwiGLU FFN (gate + up + silu + down)
- Q/K/V bias (Qwen-specific)
- Tied word embeddings (lm_head = embed)
- Chunked LM head (vocab > 65536 exceeds ANE max dim)
**[Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct)**
Adapting to other architectures (LLaMA, Gemma, Mistral) requires:
1. Adjusting the config constants in `qwen_ane_infer.h`
2. Updating `convert_weights.py` for the weight naming scheme
3. Removing Q/K/V bias handling if the model doesn't have them
4. Switching RoPE to interleaved pairs if needed
- 494M parameters, BFloat16
- 24 layers, 896 dim, 4864 hidden
- 14 attention heads, 2 KV heads (GQA)
- 151,936 vocab size
- Download: `setup.sh` handles this automatically
## Requirements
- macOS 15+ on Apple Silicon (M1/M2/M3/M4)
- Xcode Command Line Tools (for `xcrun clang`)
- Python 3.9+ with `safetensors`, `torch`, `transformers` (for weight conversion)
- Xcode Command Line Tools (`xcode-select --install`)
- Python 3.11+ (for weight conversion only, not needed for serving)
## Known Limitations
- **CPU projections only** — ANE baked-weight conv kernels compile successfully but produce incorrect output (FP16 weight blob format mismatch). The `USE_ANE_PROJECTIONS` toggle exists but defaults to 0 (CPU via Accelerate BLAS). Fixing this would push decode speed from 82 t/s to 120+ t/s.
- **Single model** — hardcoded for Qwen2.5-0.5B. Needs parameterization for other sizes.
- **f32 weights** — 1.9GB on disk. FP16 or quantized weight support would halve this.
## How It Works
The key insight from maderix's reverse engineering: the ANE executes compiled MIL (Machine Learning Intermediate Language) programs as atomic graph operations. Each linear projection becomes a MIL program with baked FP16 weights, compiled in-memory via `_ANEInMemoryModel`, and executed through IOSurface-based zero-copy I/O.
We chain 169 of these atomic operations (7 per transformer layer + 16 LM head chunks) with CPU-side element-wise ops in between. The ANE handles the compute-heavy matmuls; the CPU handles the memory-bound operations (attention scores, softmax, RoPE).
- **CPU projections only** — ANE baked-weight conv kernels compile but produce incorrect output (FP16 weight blob format mismatch). `USE_ANE_PROJECTIONS` defaults to 0 (CPU via Accelerate BLAS). Fixing this would increase decode speed significantly.
- **Single model** — hardcoded for Qwen2.5-0.5B. Other sizes need config changes.
- **f32 weights** — 1.9GB on disk. FP16 weight support would halve this.
- **Single-threaded HTTP** — handles one request at a time. Sufficient for local use.
## License

172
inference/api_server.py Normal file
View File

@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""HTTP API server for ANE inference.
Bridges HTTP requests to the qwen_ane Unix socket server. Handles tokenization
so clients can send plain text prompts and receive decoded responses.
Prerequisites:
1. Start the ANE server: ./qwen_ane qwen05b.bin --server /tmp/qwen_ane.sock
2. Start this API: python3 api_server.py [--port 8000]
Usage:
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{"prompt": "What is 2+2?", "max_tokens": 50}'
curl http://localhost:8000/health
"""
import argparse
import json
import os
import socket
import sys
import time
from http.server import HTTPServer, BaseHTTPRequestHandler
from pathlib import Path
DEFAULT_SOCK = "/tmp/qwen_ane.sock"
MODEL_DIR = Path.home() / "models" / "Qwen2.5-0.5B-Instruct"
tokenizer = None
def get_tokenizer():
global tokenizer
if tokenizer is None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR), trust_remote_code=True)
return tokenizer
def query_ane(token_ids: list[int], max_tokens: int, sock_path: str) -> dict:
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.settimeout(120)
s.connect(sock_path)
req = json.dumps({"tokens": token_ids, "max_tokens": max_tokens}) + "\n"
s.sendall(req.encode())
data = b""
while True:
chunk = s.recv(131072)
if not chunk:
break
data += chunk
if b"\n" in data:
break
s.close()
return json.loads(data.decode().strip())
class ANEHandler(BaseHTTPRequestHandler):
sock_path = DEFAULT_SOCK
def log_message(self, format, *args):
sys.stderr.write(f"[{time.strftime('%H:%M:%S')}] {format % args}\n")
def _send_json(self, code, obj):
body = json.dumps(obj).encode()
self.send_response(code)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(body)))
self.send_header("Access-Control-Allow-Origin", "*")
self.end_headers()
self.wfile.write(body)
def do_OPTIONS(self):
self.send_response(204)
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
self.send_header("Access-Control-Allow-Headers", "Content-Type")
self.end_headers()
def do_GET(self):
if self.path == "/health":
alive = os.path.exists(self.sock_path)
self._send_json(200, {"status": "ok" if alive else "no_backend", "socket": self.sock_path})
return
self._send_json(404, {"error": "not found"})
def do_POST(self):
if self.path != "/v1/completions":
self._send_json(404, {"error": "not found, use POST /v1/completions"})
return
length = int(self.headers.get("Content-Length", 0))
if length == 0 or length > 65536:
self._send_json(400, {"error": "invalid content length"})
return
try:
body = json.loads(self.rfile.read(length))
except json.JSONDecodeError:
self._send_json(400, {"error": "invalid JSON"})
return
prompt = body.get("prompt", "")
max_tokens = min(body.get("max_tokens", 50), 512)
system_prompt = body.get("system", "You are a helpful assistant. Be concise.")
if not prompt:
self._send_json(400, {"error": "missing 'prompt' field"})
return
tok = get_tokenizer()
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tok.encode(text)
t0 = time.time()
try:
result = query_ane(input_ids, max_tokens, self.sock_path)
except (ConnectionRefusedError, FileNotFoundError, OSError) as e:
self._send_json(503, {"error": f"ANE backend unavailable: {e}"})
return
elapsed = time.time() - t0
output_ids = result.get("output", [])
decoded = tok.decode(output_ids, skip_special_tokens=True) if output_ids else ""
self._send_json(200, {
"text": decoded,
"output_tokens": output_ids,
"prompt_tokens": len(input_ids),
"gen_tokens": len(output_ids),
"prefill_tps": result.get("prefill_tps", 0),
"decode_tps": result.get("decode_tps", 0),
"elapsed_s": round(elapsed, 3),
})
def main():
parser = argparse.ArgumentParser(description="HTTP API for ANE inference")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--sock", type=str, default=DEFAULT_SOCK)
args = parser.parse_args()
ANEHandler.sock_path = args.sock
print(f"Loading tokenizer from {MODEL_DIR}...")
get_tokenizer()
print("Tokenizer ready.")
if not os.path.exists(args.sock):
print(f"WARNING: Socket {args.sock} not found. Start the ANE server first:")
print(f" ./qwen_ane qwen05b.bin --server {args.sock}")
server = HTTPServer((args.host, args.port), ANEHandler)
print(f"API server listening on http://{args.host}:{args.port}")
print(f" POST /v1/completions {{\"prompt\": \"...\", \"max_tokens\": 50}}")
print(f" GET /health")
try:
server.serve_forever()
except KeyboardInterrupt:
print("\nShutting down.")
server.server_close()
if __name__ == "__main__":
main()

393
inference/benchmark.sh Executable file
View File

@ -0,0 +1,393 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
# Load .env if present (LMS_API_KEY, LMS_PORT, LMS_MODEL)
if [ -f "$SCRIPT_DIR/.env" ]; then
set -a
source "$SCRIPT_DIR/.env"
set +a
fi
BINARY="$SCRIPT_DIR/qwen_ane"
WEIGHTS="$SCRIPT_DIR/qwen05b.bin"
MODEL_DIR="${MODEL_DIR:-$HOME/models/Qwen2.5-0.5B-Instruct}"
SOCK="/tmp/qwen_ane_bench.sock"
HTTP_PORT=8877
RESULTS_JSON="$SCRIPT_DIR/benchmark_results.json"
# --- Prompt suite ---
PROMPT_NAMES=( "tiny" "short" "medium" "long" "stress")
PROMPTS=( "Hi" "What is 2+2?" "Explain how neural networks work in 3 sentences." "Write a short story about a robot learning to paint. Include dialogue." "The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog.")
MAX_TOKENS=( 10 20 100 200 50)
info() { printf "\033[1;34m%s\033[0m\n" "$1"; }
dim() { printf "\033[2m%s\033[0m\n" "$1"; }
# Extract a numeric or string value from flat JSON. No python needed.
# Usage: json_val '{"key":123}' "key" → 123
json_val() {
local json="$1" key="$2"
echo "$json" | sed -n "s/.*\"$key\"[[:space:]]*:[[:space:]]*\"\{0,1\}\([^,\"}\]*\)\"\{0,1\}.*/\1/p" | head -1
}
# Extract the "text" field which may contain escaped chars and commas.
# Grabs everything between "text":" and the next unescaped quote.
json_text() {
local json="$1"
echo "$json" | sed -n 's/.*"text":"\(.*\)","prompt_tokens".*/\1/p' | sed 's/\\n/ /g; s/\\"//g'
}
# Truncate a float string to integer: "317.2" → "317"
trunc() { echo "${1%%.*}"; }
# Average an array of numbers using awk. Handles both ints and floats.
# Usage: shell_avg "1.5" "2.3" "3.1" → 2.3
shell_avg() { printf '%s\n' "$@" | awk '{s+=$1; n++} END {if(n>0) printf "%.1f", s/n; else print "0"}'; }
shell_avg_int() { printf '%s\n' "$@" | awk '{s+=$1; n++} END {if(n>0) printf "%.0f", s/n; else print "0"}'; }
# --- Preflight ---
if [ ! -f "$BINARY" ]; then
echo "Binary not found: $BINARY"
echo "Run setup.sh first: $SCRIPT_DIR/setup.sh"
exit 1
fi
if [ ! -f "$WEIGHTS" ]; then
echo "Weights not found: $WEIGHTS"
echo "Run setup.sh first: $SCRIPT_DIR/setup.sh"
exit 1
fi
# Detect hardware
CHIP=$(sysctl -n machdep.cpu.brand_string 2>/dev/null || echo "Unknown")
MACOS=$(sw_vers -productVersion 2>/dev/null || echo "Unknown")
MEM_BYTES=$(sysctl -n hw.memsize 2>/dev/null || echo "0")
MEM_GB=$((MEM_BYTES / 1073741824))
echo ""
info "=== ANE Inference Benchmark (qwen_ane) ==="
echo "Hardware: $CHIP"
echo "macOS: $MACOS"
echo "Memory: ${MEM_GB} GB"
echo "Model: Qwen2.5-0.5B-Instruct (BF16, 494M params)"
echo ""
# --- Phase 1: Server mode benchmark (HTTP API) ---
info "Phase 1: Server mode (persistent ANE kernels via HTTP API)"
dim "Starting server on port $HTTP_PORT..."
# Start HTTP server in background
"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 &
SERVER_PID=$!
cleanup() {
kill "$SERVER_PID" 2>/dev/null || true
rm -f "$SOCK" /tmp/qwen_bench_server.log
}
trap cleanup EXIT
# Wait for READY
for i in $(seq 1 30); do
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then
break
fi
sleep 1
done
if ! grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then
echo "Server failed to start. Log:"
cat /tmp/qwen_bench_server.log
exit 1
fi
dim "Server ready (PID $SERVER_PID)"
echo ""
# Warmup: first request primes any remaining caches
dim "Warmup run (discarded)..."
curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
-H "Content-Type: application/json" \
-d '{"prompt":"warmup","max_tokens":5}' > /dev/null 2>&1
echo ""
# Print table header
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"Prompt" "In" "Out" "Prefill" "Decode" "TTFT" "Infer" "Rndtrip" "Overhead"
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"" "tok" "tok" "(t/s)" "(t/s)" "(ms)" "(ms)" "(ms)" "(ms)"
printf '%.0s─' {1..85}; echo ""
# Arrays for averages
declare -a P_TPS_ARR D_TPS_ARR INF_MS_ARR TTFT_MS_ARR RT_MS_ARR
JSON_ENTRIES=""
NUM_PROMPTS=${#PROMPTS[@]}
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
NAME="${PROMPT_NAMES[$i]}"
PROMPT="${PROMPTS[$i]}"
MAXTOK="${MAX_TOKENS[$i]}"
RT_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
-H "Content-Type: application/json" \
-d "{\"prompt\": \"$PROMPT\", \"max_tokens\": $MAXTOK}" 2>&1)
RT_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
RT_MS=$(echo "$RT_T0 $RT_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
# Parse server JSON with pure shell -- no python
P_TOKENS=$(json_val "$RESP" "prompt_tokens")
G_TOKENS=$(json_val "$RESP" "gen_tokens")
P_TPS=$(json_val "$RESP" "prefill_tps")
D_TPS=$(json_val "$RESP" "decode_tps")
TTFT_MS=$(trunc "$(json_val "$RESP" "ttft_ms")")
INF_MS=$(trunc "$(json_val "$RESP" "inference_ms")")
TOTAL_MS=$(trunc "$(json_val "$RESP" "total_ms")")
TEXT=$(json_text "$RESP")
OVERHEAD=$((RT_MS - TOTAL_MS))
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"$NAME" "$P_TOKENS" "$G_TOKENS" "$P_TPS" "$D_TPS" "$TTFT_MS" "$INF_MS" "$RT_MS" "$OVERHEAD"
P_TPS_ARR+=("$P_TPS")
D_TPS_ARR+=("$D_TPS")
INF_MS_ARR+=("$INF_MS")
TTFT_MS_ARR+=("$TTFT_MS")
RT_MS_ARR+=("$RT_MS")
# Build JSON entry
JSON_ENTRIES="$JSON_ENTRIES{\"name\":\"$NAME\",\"prompt_tokens\":$P_TOKENS,\"gen_tokens\":$G_TOKENS,\"prefill_tps\":$P_TPS,\"decode_tps\":$D_TPS,\"ttft_ms\":$TTFT_MS,\"inference_ms\":$INF_MS,\"roundtrip_ms\":$RT_MS},"
# Print response text indented below
echo "$TEXT"
echo ""
done
printf '%.0s─' {1..85}; echo ""
# Averages (pure shell, no python)
AVG_P=$(shell_avg "${P_TPS_ARR[@]}")
AVG_D=$(shell_avg "${D_TPS_ARR[@]}")
AVG_INF=$(shell_avg_int "${INF_MS_ARR[@]}")
AVG_TTFT=$(shell_avg_int "${TTFT_MS_ARR[@]}")
AVG_RT=$(shell_avg_int "${RT_MS_ARR[@]}")
AVG_OVERHEAD=$((AVG_RT - AVG_INF))
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" "Average" "" "" "$AVG_P" "$AVG_D" "$AVG_TTFT" "$AVG_INF" "$AVG_RT" "$AVG_OVERHEAD"
echo ""
info "Infer = server-reported (pure processing). Rndtrip = wall-clock (what clients see)."
echo ""
# --- Phase 2: Cold start measurement ---
info "Phase 2: Cold start (single-shot, recompiles ANE kernels)"
# Kill server, run single-shot
kill "$SERVER_PID" 2>/dev/null || true
sleep 1
# Use perl for sub-second timing (available on all macOS, no python)
COLD_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
COLD_OUT=$("$BINARY" "$WEIGHTS" "151644 8948 198 2610 525 264 10950 17847 13 151645 198 151644 872 198 13048 151645 198 151644 77091 198" 10 2>&1 || true)
COLD_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
COLD_MS=$(echo "$COLD_T0 $COLD_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
echo "Cold start latency: ${COLD_MS}ms (includes ANE kernel compilation)"
echo ""
# Re-start server for any additional tests
"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 &
SERVER_PID=$!
# --- Phase 3: Repeated prompt (consistency check) ---
info "Phase 3: Decode speed consistency (5x same prompt)"
for retry in $(seq 1 15); do
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then break; fi
sleep 1
done
printf "%-6s %10s %10s %10s\n" "Run" "Prefill" "Decode" "Infer(ms)"
printf '%.0s─' {1..40}; echo ""
for run in $(seq 1 5); do
RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
-H "Content-Type: application/json" \
-d '{"prompt": "Count from 1 to 10", "max_tokens": 50}' 2>&1)
P=$(json_val "$RESP" "prefill_tps")
D=$(json_val "$RESP" "decode_tps")
IM=$(trunc "$(json_val "$RESP" "inference_ms")")
printf "%-6s %10s %10s %10s\n" "#$run" "$P" "$D" "$IM"
done
echo ""
# --- Save JSON results ---
JSON="{
\"hardware\": \"$CHIP\",
\"macos\": \"$MACOS\",
\"memory_gb\": $MEM_GB,
\"model\": \"Qwen2.5-0.5B-Instruct\",
\"mode\": \"http_server\",
\"cold_start_ms\": $COLD_MS,
\"avg_prefill_tps\": $AVG_P,
\"avg_decode_tps\": $AVG_D,
\"avg_inference_ms\": $AVG_INF,
\"avg_roundtrip_ms\": $AVG_RT,
\"avg_ttft_ms\": $AVG_TTFT,
\"results\": [${JSON_ENTRIES%,}]
}"
echo "$JSON" > "$RESULTS_JSON"
dim "Results saved to $RESULTS_JSON"
echo ""
# --- Phase 4: LM Studio comparison (if running) ---
LMS_PORT="${LMS_PORT:-1234}"
LMS_MODEL="${LMS_MODEL:-qwen2.5-0.5b-instruct}"
LMS_API_KEY="${LMS_API_KEY:-}"
# Check if LM Studio is running
LMS_REACHABLE=0
if curl -s --max-time 2 "http://localhost:$LMS_PORT/api/v1/chat" -H "Content-Type: application/json" -d '{}' >/dev/null 2>&1; then
LMS_REACHABLE=1
fi
if [ "$LMS_REACHABLE" -eq 1 ]; then
info "Phase 4: LM Studio comparison (localhost:$LMS_PORT)"
# If no API key, prompt for it
if [ -z "$LMS_API_KEY" ]; then
echo ""
echo " LM Studio requires an API key."
echo " Find it in LM Studio > Developer tab > API key"
echo " Or set LMS_API_KEY env var before running."
echo ""
printf " Enter LM Studio API key (or press Enter to skip): "
read -r LMS_API_KEY
if [ -z "$LMS_API_KEY" ]; then
dim "Skipping LM Studio benchmark."
LMS_REACHABLE=0
fi
fi
fi
if [ "$LMS_REACHABLE" -eq 1 ] && [ -n "$LMS_API_KEY" ]; then
echo ""
printf "%-10s %5s %5s %10s %10s %10s\n" \
"Prompt" "In" "Out" "Decode" "TTFT" "Rndtrip"
printf "%-10s %5s %5s %10s %10s %10s\n" \
"" "tok" "tok" "(t/s)" "(ms)" "(ms)"
printf '%.0s─' {1..55}; echo ""
declare -a LMS_LATENCIES LMS_TPS_ARR LMS_TTFT_ARR
LMS_JSON_ENTRIES=""
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
NAME="${PROMPT_NAMES[$i]}"
PROMPT="${PROMPTS[$i]}"
T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
LMS_RESP=$(curl -s --max-time 120 "http://localhost:$LMS_PORT/api/v1/chat" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $LMS_API_KEY" \
-d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"You are a helpful assistant. Be concise.\",\"input\":\"$PROMPT\"}" 2>&1)
T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
LMS_MS=$(echo "$T0 $T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
eval "$(echo "$LMS_RESP" | python3 -c "
import sys, json
try:
r = json.load(sys.stdin)
text = r.get('output', [{}])[0].get('content', '').replace(chr(10),' ').replace('\"', '')
s = r.get('stats', {})
tps = s.get('tokens_per_second', 0)
ttft = int(s.get('time_to_first_token_seconds', 0) * 1000)
in_tok = s.get('input_tokens', 0)
out_tok = s.get('total_output_tokens', 0)
print(f'LMS_TEXT=\"{text}\"')
print(f'LMS_TPS={tps:.1f}')
print(f'LMS_TTFT={ttft}')
print(f'LMS_IN={in_tok}')
print(f'LMS_OUT={out_tok}')
except Exception as e:
print(f'LMS_TEXT=\"(parse error)\"')
print('LMS_TPS=0')
print('LMS_TTFT=0')
print('LMS_IN=0')
print('LMS_OUT=0')
" 2>/dev/null)"
printf "%-10s %5s %5s %10s %10s %10s\n" "$NAME" "$LMS_IN" "$LMS_OUT" "$LMS_TPS" "$LMS_TTFT" "$LMS_MS"
echo "$LMS_TEXT"
echo ""
LMS_LATENCIES+=("$LMS_MS")
LMS_TPS_ARR+=("$LMS_TPS")
LMS_TTFT_ARR+=("$LMS_TTFT")
LMS_JSON_ENTRIES="$LMS_JSON_ENTRIES{\"name\":\"$NAME\",\"latency_ms\":$LMS_MS,\"tps\":$LMS_TPS,\"ttft_ms\":$LMS_TTFT,\"input_tokens\":$LMS_IN,\"output_tokens\":$LMS_OUT},"
done
printf '%.0s─' {1..55}; echo ""
# Averages (awk, no python)
LMS_AVG_LAT=$(shell_avg_int "${LMS_LATENCIES[@]}")
LMS_AVG_TPS=$(shell_avg "${LMS_TPS_ARR[@]}")
LMS_AVG_TTFT=$(shell_avg_int "${LMS_TTFT_ARR[@]}")
printf "%-10s %5s %5s %10s %10s %10s\n" "Average" "" "" "$LMS_AVG_TPS" "$LMS_AVG_TTFT" "$LMS_AVG_LAT"
echo ""
# Side-by-side comparison
info "=== Side-by-Side Comparison ==="
dim "(Round-trip = wall-clock from client, apples-to-apples)"
echo ""
printf "%-24s %15s %15s\n" "" "ANE (qwen_ane)" "LM Studio"
printf '%.0s─' {1..56}; echo ""
printf "%-24s %12s t/s %12s t/s\n" "Decode speed" "$AVG_D" "$LMS_AVG_TPS"
printf "%-24s %12s t/s %12s\n" "Prefill speed" "$AVG_P" "N/A"
printf "%-24s %12s ms %12s ms\n" "TTFT" "$AVG_TTFT" "$LMS_AVG_TTFT"
printf "%-24s %12s ms %12s ms\n" "Avg round-trip" "$AVG_RT" "$LMS_AVG_LAT"
printf "%-24s %12s ms %12s ms\n" " (server-only)" "$AVG_INF" "N/A"
printf "%-24s %12s ms %12s\n" "Cold start" "$COLD_MS" "N/A"
printf "%-24s %15s %15s\n" "Precision" "F32 (from BF16)" "GGUF quantized"
printf "%-24s %15s %15s\n" "Accelerator" "Neural Engine" "CPU/GPU"
printf "%-24s %15s %15s\n" "Timing method" "Wall-clock" "Wall-clock"
echo ""
# Append LM Studio block to JSON results (pure shell, no python)
# Remove trailing "}" and newline, append lm_studio object
LMS_JSON_BLOCK=",
\"lm_studio\": {
\"port\": $LMS_PORT,
\"model\": \"$LMS_MODEL\",
\"avg_latency_ms\": $LMS_AVG_LAT,
\"avg_tps\": $LMS_AVG_TPS,
\"avg_ttft_ms\": $LMS_AVG_TTFT,
\"results\": [${LMS_JSON_ENTRIES%,}]
}
}"
# Replace the final "}" with the LM Studio block
sed -i '' '$ s/}$//' "$RESULTS_JSON"
printf '%s\n' "$LMS_JSON_BLOCK" >> "$RESULTS_JSON"
dim "LM Studio results added to $RESULTS_JSON"
else
info "=== LM Studio Comparison ==="
echo ""
if [ "$LMS_REACHABLE" -eq 0 ]; then
echo " LM Studio server not detected on localhost:$LMS_PORT"
echo ""
echo " To enable automatic comparison:"
echo " 1. Open LM Studio, download Qwen2.5-0.5B-Instruct (GGUF)"
echo " 2. Load the model, go to Developer tab > Start Server"
echo " 3. Re-run this benchmark"
echo ""
echo " Or set env vars: LMS_PORT=1234 LMS_API_KEY=your-key ./benchmark.sh"
fi
echo ""
echo " Manual test:"
echo " curl http://localhost:1234/api/v1/chat \\"
echo " -H 'Content-Type: application/json' \\"
echo " -H 'Authorization: Bearer YOUR_API_KEY' \\"
echo " -d '{\"model\":\"qwen2.5-0.5b-instruct\",\"system_prompt\":\"You are a helpful assistant.\",\"input\":\"What is 2+2?\"}'"
echo ""
echo " ANE (this benchmark): prefill=${AVG_P} t/s, decode=${AVG_D} t/s, inference=${AVG_INF}ms"
echo ""
echo " Note: LM Studio uses quantized GGUF (CPU/GPU) while we use"
echo " BF16 weights (full precision) running on the Neural Engine."
fi
echo ""

221
inference/http_server.h Normal file
View File

@ -0,0 +1,221 @@
// http_server.h -- Minimal HTTP/1.1 server for ANE inference API
// Handles GET /health and POST /v1/completions using raw POSIX sockets.
// No external dependencies.
#pragma once
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <signal.h>
#include <time.h>
#define HTTP_MAX_REQUEST 65536
#define HTTP_MAX_RESPONSE 262144
#define HTTP_MAX_BODY 65536
// --- HTTP request parsing ---
typedef struct {
char method[8]; // GET, POST, etc.
char path[256]; // /v1/completions, /health, etc.
char body[HTTP_MAX_BODY];
int body_len;
int content_length;
} HttpRequest;
static int http_parse_request(const char *raw, int raw_len, HttpRequest *req) {
memset(req, 0, sizeof(HttpRequest));
// Parse request line: METHOD PATH HTTP/1.1\r\n
const char *p = raw;
int i = 0;
while (*p && *p != ' ' && i < 7) req->method[i++] = *p++;
req->method[i] = '\0';
if (*p == ' ') p++;
i = 0;
while (*p && *p != ' ' && *p != '?' && i < 255) req->path[i++] = *p++;
req->path[i] = '\0';
// Skip to end of request line
while (*p && *p != '\n') p++;
if (*p) p++;
// Parse headers (only need Content-Length)
req->content_length = 0;
while (*p && !(*p == '\r' && *(p+1) == '\n') && *p != '\n') {
if (strncasecmp(p, "Content-Length:", 15) == 0) {
req->content_length = atoi(p + 15);
}
while (*p && *p != '\n') p++;
if (*p) p++;
}
// Skip blank line
if (*p == '\r') p++;
if (*p == '\n') p++;
// Copy body
int remaining = raw_len - (int)(p - raw);
req->body_len = remaining < HTTP_MAX_BODY - 1 ? remaining : HTTP_MAX_BODY - 1;
if (req->body_len > 0) memcpy(req->body, p, req->body_len);
req->body[req->body_len] = '\0';
return 0;
}
// --- HTTP response sending ---
static void http_send(int fd, int status, const char *status_text,
const char *content_type, const char *body, int body_len) {
char header[1024];
int hlen = snprintf(header, sizeof(header),
"HTTP/1.1 %d %s\r\n"
"Content-Type: %s\r\n"
"Content-Length: %d\r\n"
"Access-Control-Allow-Origin: *\r\n"
"Access-Control-Allow-Methods: POST, GET, OPTIONS\r\n"
"Access-Control-Allow-Headers: Content-Type\r\n"
"Connection: close\r\n"
"\r\n",
status, status_text, content_type, body_len);
write(fd, header, hlen);
if (body_len > 0) write(fd, body, body_len);
}
static void http_send_json(int fd, int status, const char *json) {
const char *status_text = "OK";
if (status == 400) status_text = "Bad Request";
else if (status == 404) status_text = "Not Found";
else if (status == 503) status_text = "Service Unavailable";
http_send(fd, status, status_text, "application/json", json, (int)strlen(json));
}
// --- Minimal JSON field extraction ---
static int http_json_get_string(const char *json, const char *key,
char *out, int max_out) {
char search[256];
snprintf(search, sizeof(search), "\"%s\"", key);
const char *p = strstr(json, search);
if (!p) return -1;
p += strlen(search);
while (*p && (*p == ' ' || *p == ':' || *p == '\t')) p++;
if (*p != '"') return -1;
p++;
int n = 0;
while (*p && *p != '"' && n < max_out - 1) {
if (*p == '\\') {
p++;
switch (*p) {
case 'n': out[n++] = '\n'; break;
case 't': out[n++] = '\t'; break;
case '"': out[n++] = '"'; break;
case '\\': out[n++] = '\\'; break;
default: out[n++] = *p;
}
} else {
out[n++] = *p;
}
p++;
}
out[n] = '\0';
return n;
}
static int http_json_get_int(const char *json, const char *key, int default_val) {
char search[256];
snprintf(search, sizeof(search), "\"%s\"", key);
const char *p = strstr(json, search);
if (!p) return default_val;
p += strlen(search);
while (*p && (*p == ' ' || *p == ':' || *p == '\t')) p++;
if (*p == '-' || (*p >= '0' && *p <= '9'))
return (int)strtol(p, NULL, 10);
return default_val;
}
// --- TCP server ---
typedef void (*HttpHandler)(int client_fd, HttpRequest *req, void *ctx);
static int http_serve(int port, HttpHandler handler, void *ctx) {
int srv = socket(AF_INET, SOCK_STREAM, 0);
if (srv < 0) { perror("socket"); return -1; }
int opt = 1;
setsockopt(srv, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
struct sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
addr.sin_port = htons(port);
if (bind(srv, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
perror("bind"); close(srv); return -1;
}
if (listen(srv, 8) < 0) {
perror("listen"); close(srv); return -1;
}
printf("HTTP server listening on http://127.0.0.1:%d\n", port);
printf(" POST /v1/completions {\"prompt\": \"...\", \"max_tokens\": 50}\n");
printf(" GET /health\n");
printf("READY\n");
fflush(stdout);
while (1) {
int client = accept(srv, NULL, NULL);
if (client < 0) { perror("accept"); continue; }
// Read full request (headers + body)
char buf[HTTP_MAX_REQUEST];
int total = 0;
int headers_done = 0;
int content_length = 0;
int body_start = 0;
while (total < HTTP_MAX_REQUEST - 1) {
ssize_t n = read(client, buf + total, HTTP_MAX_REQUEST - 1 - total);
if (n <= 0) break;
total += n;
buf[total] = '\0';
if (!headers_done) {
char *hend = strstr(buf, "\r\n\r\n");
if (hend) {
headers_done = 1;
body_start = (int)(hend - buf) + 4;
// Extract Content-Length
char *cl = strcasestr(buf, "Content-Length:");
if (cl) content_length = atoi(cl + 15);
}
}
if (headers_done) {
int body_received = total - body_start;
if (body_received >= content_length) break;
}
}
HttpRequest req;
http_parse_request(buf, total, &req);
// Handle OPTIONS preflight
if (strcmp(req.method, "OPTIONS") == 0) {
http_send(client, 204, "No Content", "text/plain", "", 0);
close(client);
continue;
}
handler(client, &req, ctx);
close(client);
}
return 0;
}

View File

@ -1,8 +1,9 @@
// main.m -- Qwen2.5-0.5B inference on Apple Neural Engine
// Supports three modes:
// Supports four modes:
// 1. Single-shot: ./qwen_ane weights.bin "token_ids" [max_tokens]
// 2. Stdin server: ./qwen_ane weights.bin --server
// 3. Socket server: ./qwen_ane weights.bin --server /tmp/qwen_ane.sock
// 4. HTTP API: ./qwen_ane weights.bin --http 8000 --model-dir ~/models/Qwen2.5-0.5B-Instruct
//
// Build:
// xcrun clang -O2 -framework Foundation -framework IOSurface \
@ -19,10 +20,14 @@
#include <unistd.h>
#include <signal.h>
#include "qwen_ane_infer.h"
#include "tokenizer.h"
#include "http_server.h"
int g_fp16_io = 0;
static QwenModel g_model;
static const char *g_sock_path = NULL;
static Tokenizer g_tokenizer;
static int g_tokenizer_loaded = 0;
static void cleanup_socket(void) {
if (g_sock_path) unlink(g_sock_path);
@ -280,15 +285,123 @@ static void run_socket_server(const char *sock_path) {
}
}
// --- HTTP API handler ---
static void http_api_handler(int client_fd, HttpRequest *req, void *ctx) {
(void)ctx;
if (strcmp(req->method, "GET") == 0 && strcmp(req->path, "/health") == 0) {
http_send_json(client_fd, 200, "{\"status\":\"ok\",\"mode\":\"http\"}");
return;
}
if (strcmp(req->method, "POST") != 0 || strcmp(req->path, "/v1/completions") != 0) {
http_send_json(client_fd, 404, "{\"error\":\"not found, use POST /v1/completions\"}");
return;
}
if (req->body_len == 0) {
http_send_json(client_fd, 400, "{\"error\":\"empty body\"}");
return;
}
char prompt[32768];
if (http_json_get_string(req->body, "prompt", prompt, sizeof(prompt)) < 0) {
http_send_json(client_fd, 400, "{\"error\":\"missing 'prompt' field\"}");
return;
}
int max_tokens = http_json_get_int(req->body, "max_tokens", 50);
if (max_tokens > 512) max_tokens = 512;
if (max_tokens < 1) max_tokens = 1;
char system_prompt[4096];
if (http_json_get_string(req->body, "system", system_prompt, sizeof(system_prompt)) < 0)
strcpy(system_prompt, "You are a helpful assistant. Be concise.");
// Time tokenization separately
struct timespec t_tok0, t_tok1, t_gen0, t_gen1, t_det0, t_det1;
clock_gettime(CLOCK_MONOTONIC, &t_tok0);
int input_ids[4096];
int n_input = tok_encode_chat(&g_tokenizer, system_prompt, prompt, input_ids, 4096);
clock_gettime(CLOCK_MONOTONIC, &t_tok1);
double tokenize_ms = timespec_diff(&t_tok0, &t_tok1) * 1000.0;
if (n_input == 0) {
http_send_json(client_fd, 400, "{\"error\":\"tokenization produced no tokens\"}");
return;
}
// Pure inference timing
clock_gettime(CLOCK_MONOTONIC, &t_gen0);
int out_ids[4096];
double p_tps, d_tps;
int n_out = generate(input_ids, n_input, max_tokens, out_ids, 4096, &p_tps, &d_tps);
clock_gettime(CLOCK_MONOTONIC, &t_gen1);
double inference_ms = timespec_diff(&t_gen0, &t_gen1) * 1000.0;
// Prefill time = inference of prompt tokens only (from generate's internal timing)
double prefill_s = p_tps > 0 ? n_input / p_tps : 0;
double ttft_ms = prefill_s * 1000.0;
// Time detokenization separately
clock_gettime(CLOCK_MONOTONIC, &t_det0);
char decoded[65536];
tok_decode(&g_tokenizer, out_ids, n_out, decoded, sizeof(decoded));
clock_gettime(CLOCK_MONOTONIC, &t_det1);
double detokenize_ms = timespec_diff(&t_det0, &t_det1) * 1000.0;
double total_ms = tokenize_ms + inference_ms + detokenize_ms;
// Escape the decoded text for JSON
char escaped[131072];
int ei = 0;
for (int i = 0; decoded[i] && ei < (int)sizeof(escaped) - 6; i++) {
switch (decoded[i]) {
case '"': escaped[ei++] = '\\'; escaped[ei++] = '"'; break;
case '\\': escaped[ei++] = '\\'; escaped[ei++] = '\\'; break;
case '\n': escaped[ei++] = '\\'; escaped[ei++] = 'n'; break;
case '\r': escaped[ei++] = '\\'; escaped[ei++] = 'r'; break;
case '\t': escaped[ei++] = '\\'; escaped[ei++] = 't'; break;
default:
if ((unsigned char)decoded[i] < 0x20) {
ei += snprintf(escaped + ei, 7, "\\u%04x", (unsigned char)decoded[i]);
} else {
escaped[ei++] = decoded[i];
}
}
}
escaped[ei] = '\0';
// Build JSON response with detailed timing breakdown
char resp[HTTP_MAX_RESPONSE];
snprintf(resp, sizeof(resp),
"{\"text\":\"%s\",\"prompt_tokens\":%d,\"gen_tokens\":%d,"
"\"prefill_tps\":%.1f,\"decode_tps\":%.1f,"
"\"tokenize_ms\":%.1f,\"inference_ms\":%.1f,\"detokenize_ms\":%.1f,"
"\"ttft_ms\":%.1f,\"total_ms\":%.1f}",
escaped, n_input, n_out, p_tps, d_tps,
tokenize_ms, inference_ms, detokenize_ms, ttft_ms, total_ms);
http_send_json(client_fd, 200, resp);
printf("[http] prompt=%d gen=%d prefill=%.1f decode=%.1f t/s | tok=%.1f inf=%.1f detok=%.1f ms\n",
n_input, n_out, p_tps, d_tps, tokenize_ms, inference_ms, detokenize_ms);
fflush(stdout);
qwen_reset(&g_model);
}
int main(int argc, char **argv) {
@autoreleasepool {
if (argc < 2) {
fprintf(stderr,
"Usage:\n"
" %s <weights.bin> \"token_ids\" [max_tokens] (single-shot)\n"
" %s <weights.bin> --server (stdin loop)\n"
" %s <weights.bin> --server /tmp/qwen_ane.sock (socket server)\n",
argv[0], argv[0], argv[0]);
" %s <weights.bin> \"token_ids\" [max_tokens] (single-shot)\n"
" %s <weights.bin> --server (stdin loop)\n"
" %s <weights.bin> --server /tmp/qwen_ane.sock (socket server)\n"
" %s <weights.bin> --http 8000 --model-dir ~/models/Qwen2.5 (HTTP API)\n",
argv[0], argv[0], argv[0], argv[0]);
return 1;
}
@ -300,6 +413,7 @@ int main(int argc, char **argv) {
if (load_weights(argv[1]) != 0) return 1;
qwen_alloc(&g_model);
qwen_rope_init();
printf("Compiling ANE kernels (169 total)...\n");
struct timespec t0, t1;
@ -309,17 +423,143 @@ int main(int argc, char **argv) {
double compile_sec = timespec_diff(&t0, &t1);
printf("Compile time: %.1fs\n\n", compile_sec);
// Check for --server flag
// Parse flags
int server_mode = 0;
int http_port = 0;
int test_ane = 0;
const char *sock_path = NULL;
const char *model_dir = NULL;
for (int i = 2; i < argc; i++) {
if (strcmp(argv[i], "--server") == 0) {
server_mode = 1;
if (i + 1 < argc && argv[i+1][0] != '-')
sock_path = argv[++i];
} else if (strcmp(argv[i], "--http") == 0) {
if (i + 1 < argc) http_port = atoi(argv[++i]);
else { fprintf(stderr, "--http requires a port number\n"); return 1; }
} else if (strcmp(argv[i], "--model-dir") == 0) {
if (i + 1 < argc) model_dir = argv[++i];
else { fprintf(stderr, "--model-dir requires a path\n"); return 1; }
} else if (strcmp(argv[i], "--test-ane") == 0) {
test_ane = 1;
}
}
// ANE vs CPU correctness test
if (test_ane) {
printf("=== ANE vs CPU Projection Test ===\n\n");
// Use a realistic input: embed token 2610 ("What"), RMSNorm it
int test_token = 2610;
memcpy(g_model.x, g_model.embed + test_token * QWEN_DIM, QWEN_DIM * sizeof(float));
qwen_rmsnorm(g_model.xb, g_model.x, g_model.rms_att[0], QWEN_DIM);
// Also prepare a realistic Q output for the O projection test
cpu_project(g_model.wq[0], g_model.xb, g_model.q, QWEN_DIM, QWEN_Q_DIM);
float *cpu_out = (float*)calloc(QWEN_HIDDEN, sizeof(float));
float *ane_out = (float*)calloc(QWEN_HIDDEN, sizeof(float));
struct {
const char *name;
ANEKernel *kernel;
const float *weights;
int in_dim, out_dim;
} tests[] = {
{"L0 Q proj", g_model.k_q[0], g_model.wq[0], QWEN_DIM, QWEN_Q_DIM},
{"L0 K proj", g_model.k_k[0], g_model.wk[0], QWEN_DIM, QWEN_KV_DIM},
{"L0 V proj", g_model.k_v[0], g_model.wv[0], QWEN_DIM, QWEN_KV_DIM},
{"L0 O proj", g_model.k_o[0], g_model.wo[0], QWEN_Q_DIM, QWEN_DIM},
{"L0 Gate", g_model.k_gate[0], g_model.w_gate[0], QWEN_DIM, QWEN_HIDDEN},
{"L0 Up", g_model.k_up[0], g_model.w_up[0], QWEN_DIM, QWEN_HIDDEN},
{"L0 Down", g_model.k_down[0], g_model.w_down[0], QWEN_HIDDEN, QWEN_DIM},
{"LM Head c0", g_model.k_lmhead[0], g_model.embed, QWEN_DIM, QWEN_LM_CHUNK_SIZE},
};
int n_tests = sizeof(tests) / sizeof(tests[0]);
int all_pass = 1;
for (int t = 0; t < n_tests; t++) {
if (!tests[t].kernel) {
printf(" %-14s SKIP (kernel not compiled)\n", tests[t].name);
continue;
}
const float *input;
if (tests[t].in_dim == QWEN_Q_DIM) {
input = g_model.q;
} else if (tests[t].in_dim == QWEN_HIDDEN) {
cpu_project(g_model.w_gate[0], g_model.xb, g_model.hb, QWEN_DIM, QWEN_HIDDEN);
input = g_model.hb;
} else {
input = g_model.xb;
}
cpu_project(tests[t].weights, input, cpu_out, tests[t].in_dim, tests[t].out_dim);
// ANE projection with return-value check
ane_write_input(tests[t].kernel, 0, input, tests[t].in_dim * sizeof(float));
bool ane_ok = ane_run(tests[t].kernel);
ane_read_output(tests[t].kernel, 0, ane_out, tests[t].out_dim * sizeof(float));
if (!ane_ok) printf(" !! ANE execution returned false\n");
float max_diff = 0, sum_diff = 0;
float cpu_norm = 0, ane_norm = 0;
for (int i = 0; i < tests[t].out_dim; i++) {
float d = fabsf(cpu_out[i] - ane_out[i]);
if (d > max_diff) max_diff = d;
sum_diff += d;
cpu_norm += cpu_out[i] * cpu_out[i];
ane_norm += ane_out[i] * ane_out[i];
}
float avg_diff = sum_diff / tests[t].out_dim;
float rel_err = (sqrtf(cpu_norm) > 0) ?
sqrtf(sum_diff * sum_diff / tests[t].out_dim) / sqrtf(cpu_norm / tests[t].out_dim) : 0;
int pass = (max_diff < 0.5f && rel_err < 0.05f);
if (!pass) all_pass = 0;
printf(" %-14s [%d→%d] max_diff=%.6f avg_diff=%.6f rel_err=%.4f %s\n",
tests[t].name, tests[t].in_dim, tests[t].out_dim,
max_diff, avg_diff, rel_err,
pass ? "PASS" : "FAIL");
printf(" CPU first4: %.6f %.6f %.6f %.6f norm=%.4f\n",
cpu_out[0], cpu_out[1], cpu_out[2], cpu_out[3], sqrtf(cpu_norm));
printf(" ANE first4: %.6f %.6f %.6f %.6f norm=%.4f\n",
ane_out[0], ane_out[1], ane_out[2], ane_out[3], sqrtf(ane_norm));
}
printf("\n%s\n", all_pass ?
"ALL TESTS PASSED -- ANE projections match CPU (within FP16 tolerance)" :
"SOME TESTS FAILED -- ANE projections have accuracy issues");
// If all pass, benchmark one layer ANE vs CPU speed
if (all_pass) {
printf("\n=== Speed comparison (1000 iterations, L0 Q proj %d→%d) ===\n",
QWEN_DIM, QWEN_Q_DIM);
struct timespec ts0, ts1;
clock_gettime(CLOCK_MONOTONIC, &ts0);
for (int i = 0; i < 1000; i++)
cpu_project(g_model.wq[0], g_model.xb, cpu_out, QWEN_DIM, QWEN_Q_DIM);
clock_gettime(CLOCK_MONOTONIC, &ts1);
double cpu_us = timespec_diff(&ts0, &ts1) * 1e6 / 1000;
clock_gettime(CLOCK_MONOTONIC, &ts0);
for (int i = 0; i < 1000; i++)
ane_project(g_model.k_q[0], g_model.xb, ane_out, QWEN_DIM, QWEN_Q_DIM);
clock_gettime(CLOCK_MONOTONIC, &ts1);
double ane_us = timespec_diff(&ts0, &ts1) * 1e6 / 1000;
printf(" CPU: %.1f us/call\n", cpu_us);
printf(" ANE: %.1f us/call\n", ane_us);
printf(" Ratio: %.2fx %s\n", cpu_us / ane_us,
ane_us < cpu_us ? "(ANE faster)" : "(CPU faster)");
}
free(cpu_out);
free(ane_out);
return all_pass ? 0 : 1;
}
if (server_mode) {
if (sock_path)
run_socket_server(sock_path);
@ -328,6 +568,31 @@ int main(int argc, char **argv) {
return 0;
}
// HTTP API mode
if (http_port > 0) {
if (!model_dir) {
// Default to ~/models/Qwen2.5-0.5B-Instruct
static char default_dir[4096];
const char *home = getenv("HOME");
snprintf(default_dir, sizeof(default_dir), "%s/models/Qwen2.5-0.5B-Instruct", home ? home : ".");
model_dir = default_dir;
}
printf("Loading tokenizer from %s...\n", model_dir);
if (tok_init(&g_tokenizer, model_dir) != 0) {
fprintf(stderr, "Failed to load tokenizer from %s\n", model_dir);
return 1;
}
g_tokenizer_loaded = 1;
printf("Tokenizer ready.\n\n");
signal(SIGINT, handle_signal);
signal(SIGTERM, handle_signal);
http_serve(http_port, http_api_handler, NULL);
tok_free(&g_tokenizer);
return 0;
}
// Single-shot mode (original behavior)
if (argc < 3) {
fprintf(stderr, "Error: provide token IDs or --server\n");

View File

@ -26,6 +26,12 @@ static ANEKernel *compile_conv_kernel(const float *weights, int in_ch, int out_c
#include <math.h>
#include <string.h>
#include <time.h>
#include <arm_neon.h>
#include <Accelerate/Accelerate.h>
#ifndef QWEN_DEBUG
#define QWEN_DEBUG 0
#endif
// Qwen2.5-0.5B-Instruct architecture
#define QWEN_DIM 896
@ -96,73 +102,106 @@ typedef struct {
float *logits; // [vocab]
} QwenModel;
// ── CPU ops ──────────────────────────────────────────────────────────
// ── Precomputed RoPE table ───────────────────────────────────────────
static float g_rope_cos[QWEN_MAX_SEQ][QWEN_HEAD_DIM / 2];
static float g_rope_sin[QWEN_MAX_SEQ][QWEN_HEAD_DIM / 2];
static int g_rope_initialized = 0;
static void qwen_rope_init(void) {
if (g_rope_initialized) return;
int half = QWEN_HEAD_DIM / 2;
for (int pos = 0; pos < QWEN_MAX_SEQ; pos++) {
for (int i = 0; i < half; i++) {
float freq = 1.0f / powf(QWEN_ROPE_THETA, (float)(2 * i) / QWEN_HEAD_DIM);
float angle = pos * freq;
g_rope_cos[pos][i] = cosf(angle);
g_rope_sin[pos][i] = sinf(angle);
}
}
g_rope_initialized = 1;
}
// ── CPU ops (vectorized with NEON + vDSP) ────────────────────────────
static void qwen_rmsnorm(float *out, const float *x, const float *w, int D) {
float ss = 0;
for (int i = 0; i < D; i++) ss += x[i] * x[i];
float ss;
vDSP_svesq(x, 1, &ss, (vDSP_Length)D);
ss = 1.0f / sqrtf(ss / D + QWEN_RMS_EPS);
for (int i = 0; i < D; i++) out[i] = x[i] * ss * w[i];
vDSP_vsmul(x, 1, &ss, out, 1, (vDSP_Length)D);
vDSP_vmul(out, 1, w, 1, out, 1, (vDSP_Length)D);
}
static void qwen_rope(float *q, float *k, int pos, int n_q_heads, int n_kv_heads, int head_dim) {
// Qwen uses rotate_half RoPE (NOT interleaved pairs):
// rotate_half(x) = [-x[dim/2:], x[:dim/2]]
// q_embed = q * cos + rotate_half(q) * sin
// cos/sin have shape [head_dim/2] and are applied to both halves
int half = head_dim / 2;
const float *cv = g_rope_cos[pos];
const float *sv = g_rope_sin[pos];
// Precompute cos/sin for this position (head_dim/2 frequencies)
float cos_v[half], sin_v[half];
for (int i = 0; i < half; i++) {
float freq = 1.0f / powf(QWEN_ROPE_THETA, (float)(2 * i) / head_dim);
float angle = pos * freq;
cos_v[i] = cosf(angle);
sin_v[i] = sinf(angle);
}
// Apply to Q heads
for (int h = 0; h < n_q_heads; h++) {
float *qh = q + h * head_dim;
for (int i = 0; i < half; i++) {
float q_first = qh[i];
float q_second = qh[i + half];
// rotate_half: [-q_second, q_first]
qh[i] = q_first * cos_v[i] + (-q_second) * sin_v[i];
qh[i + half] = q_second * cos_v[i] + q_first * sin_v[i];
int i = 0;
for (; i + 3 < half; i += 4) {
float32x4_t first = vld1q_f32(qh + i);
float32x4_t second = vld1q_f32(qh + i + half);
float32x4_t c = vld1q_f32(cv + i);
float32x4_t s = vld1q_f32(sv + i);
vst1q_f32(qh + i, vmlsq_f32(vmulq_f32(first, c), second, s));
vst1q_f32(qh + i + half, vmlaq_f32(vmulq_f32(second, c), first, s));
}
for (; i < half; i++) {
float f = qh[i], se = qh[i + half];
qh[i] = f * cv[i] - se * sv[i];
qh[i + half] = se * cv[i] + f * sv[i];
}
}
// Apply to K heads
for (int h = 0; h < n_kv_heads; h++) {
float *kh = k + h * head_dim;
for (int i = 0; i < half; i++) {
float k_first = kh[i];
float k_second = kh[i + half];
kh[i] = k_first * cos_v[i] + (-k_second) * sin_v[i];
kh[i + half] = k_second * cos_v[i] + k_first * sin_v[i];
int i = 0;
for (; i + 3 < half; i += 4) {
float32x4_t first = vld1q_f32(kh + i);
float32x4_t second = vld1q_f32(kh + i + half);
float32x4_t c = vld1q_f32(cv + i);
float32x4_t s = vld1q_f32(sv + i);
vst1q_f32(kh + i, vmlsq_f32(vmulq_f32(first, c), second, s));
vst1q_f32(kh + i + half, vmlaq_f32(vmulq_f32(second, c), first, s));
}
for (; i < half; i++) {
float f = kh[i], se = kh[i + half];
kh[i] = f * cv[i] - se * sv[i];
kh[i + half] = se * cv[i] + f * sv[i];
}
}
}
static void qwen_silu(float *x, int n) {
for (int i = 0; i < n; i++)
int i = 0;
float32x4_t one = vdupq_n_f32(1.0f);
for (; i + 3 < n; i += 4) {
float32x4_t v = vld1q_f32(x + i);
float neg[4];
vst1q_f32(neg, vnegq_f32(v));
float exp_neg[4];
for (int j = 0; j < 4; j++) exp_neg[j] = expf(neg[j]);
float32x4_t denom = vaddq_f32(one, vld1q_f32(exp_neg));
vst1q_f32(x + i, vdivq_f32(v, denom));
}
for (; i < n; i++)
x[i] = x[i] / (1.0f + expf(-x[i]));
}
// ── ANE projection helper (single token: spatial=1) ─────────────────
static inline bool ane_run(ANEKernel *k) { return ane_eval(k); }
static void ane_project(ANEKernel *kernel, const float *in, float *out,
int in_dim, int out_dim) {
// For single-token inference: spatial=1
ane_write_input(kernel, 0, in, in_dim * sizeof(float));
ane_eval(kernel);
ane_run(kernel);
ane_read_output(kernel, 0, out, out_dim * sizeof(float));
}
// CPU matmul via Accelerate BLAS: y = W @ x, W[out_dim, in_dim]
#include <Accelerate/Accelerate.h>
static void cpu_project(const float *W, const float *x, float *y, int in_dim, int out_dim) {
// y = W @ x where W is [out_dim, in_dim] row-major
// cblas_sgemv: y = alpha * A * x + beta * y
@ -189,13 +228,14 @@ static int qwen_forward(QwenModel *m, int token) {
// Attention RMSNorm
qwen_rmsnorm(m->xb, m->x, m->rms_att[l], D);
// Debug: print first layer input/output norms
#if QWEN_DEBUG
if (l == 0 && pos == 0) {
float xnorm = 0, qnorm = 0;
float xnorm = 0;
for (int i = 0; i < D; i++) xnorm += m->xb[i] * m->xb[i];
printf(" L0 RMSNorm out norm=%.4f (first 4: %.4f %.4f %.4f %.4f)\n",
sqrtf(xnorm), m->xb[0], m->xb[1], m->xb[2], m->xb[3]);
}
#endif
// QKV projections (ANE) + bias
#if USE_ANE_PROJECTIONS
@ -207,23 +247,20 @@ static int qwen_forward(QwenModel *m, int token) {
cpu_project(m->wk[l], m->xb, m->k, D, QWEN_KV_DIM);
cpu_project(m->wv[l], m->xb, m->v, D, QWEN_KV_DIM);
#endif
// Apply Q/K biases
if (m->q_bias[l]) {
for (int i = 0; i < QWEN_Q_DIM; i++) m->q[i] += m->q_bias[l][i];
}
if (m->k_bias[l]) {
for (int i = 0; i < QWEN_KV_DIM; i++) m->k[i] += m->k_bias[l][i];
}
if (m->v_bias[l]) {
for (int i = 0; i < QWEN_KV_DIM; i++) m->v[i] += m->v_bias[l][i];
}
// Apply Q/K/V biases (vectorized)
if (m->q_bias[l])
vDSP_vadd(m->q, 1, m->q_bias[l], 1, m->q, 1, (vDSP_Length)QWEN_Q_DIM);
if (m->k_bias[l])
vDSP_vadd(m->k, 1, m->k_bias[l], 1, m->k, 1, (vDSP_Length)QWEN_KV_DIM);
if (m->v_bias[l])
vDSP_vadd(m->v, 1, m->v_bias[l], 1, m->v, 1, (vDSP_Length)QWEN_KV_DIM);
#if QWEN_DEBUG
if (l == 0 && pos == 0) {
float qn = 0;
for (int i = 0; i < QWEN_Q_DIM; i++) qn += m->q[i] * m->q[i];
printf(" L0 ANE Q norm=%.4f (first 4: %.4f %.4f %.4f %.4f)\n",
sqrtf(qn), m->q[0], m->q[1], m->q[2], m->q[3]);
// CPU reference
float cpu_q[4] = {0};
for (int i = 0; i < 4; i++) {
for (int j = 0; j < D; j++)
@ -233,6 +270,7 @@ static int qwen_forward(QwenModel *m, int token) {
printf(" L0 CPU Q first 4: %.4f %.4f %.4f %.4f\n",
cpu_q[0], cpu_q[1], cpu_q[2], cpu_q[3]);
}
#endif
// RoPE
qwen_rope(m->q, m->k, pos, QWEN_HEADS, QWEN_KV_HEADS, QWEN_HEAD_DIM);
@ -251,29 +289,30 @@ static int qwen_forward(QwenModel *m, int token) {
for (int h = 0; h < QWEN_HEADS; h++) {
int kv_h = h / QWEN_GQA_FACTOR;
float *qh = m->q + h * QWEN_HEAD_DIM;
float *att_h = m->att + h * QWEN_MAX_SEQ;
int seq_len = pos + 1;
// Attention scores: Q @ K^T for all positions up to pos
// Attention scores: Q @ K^T
float max_score = -1e9f;
for (int t = 0; t <= pos; t++) {
float *kt = m->kv_cache_k[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM;
// Use BLAS dot product for precision
float score = cblas_sdot(QWEN_HEAD_DIM, qh, 1, kt, 1);
m->att[h * QWEN_MAX_SEQ + t] = score * scale;
if (score * scale > max_score) max_score = score * scale;
att_h[t] = score * scale;
if (att_h[t] > max_score) max_score = att_h[t];
}
// Softmax (double accumulation for precision)
double sum = 0;
for (int t = 0; t <= pos; t++) {
m->att[h * QWEN_MAX_SEQ + t] = expf(m->att[h * QWEN_MAX_SEQ + t] - max_score);
sum += (double)m->att[h * QWEN_MAX_SEQ + t];
}
float inv_sum = (float)(1.0 / sum);
for (int t = 0; t <= pos; t++)
m->att[h * QWEN_MAX_SEQ + t] *= inv_sum;
// Softmax: subtract max, exp, normalize (vDSP)
float neg_max = -max_score;
vDSP_vsadd(att_h, 1, &neg_max, att_h, 1, (vDSP_Length)seq_len);
int n_exp = seq_len;
vvexpf(att_h, att_h, &n_exp);
float sum;
vDSP_sve(att_h, 1, &sum, (vDSP_Length)seq_len);
float inv_sum = 1.0f / sum;
vDSP_vsmul(att_h, 1, &inv_sum, att_h, 1, (vDSP_Length)seq_len);
// Weighted sum of V: attn_out[h] += att[t] * V[t] for each t
// Weighted sum of V
for (int t = 0; t <= pos; t++) {
float a = m->att[h * QWEN_MAX_SEQ + t];
float a = att_h[t];
float *vt = m->kv_cache_v[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM;
cblas_saxpy(QWEN_HEAD_DIM, a, vt, 1,
attn_out + h * QWEN_HEAD_DIM, 1);
@ -287,9 +326,10 @@ static int qwen_forward(QwenModel *m, int token) {
cpu_project(m->wo[l], attn_out, o_out, QWEN_Q_DIM, D);
#endif
// Residual
for (int i = 0; i < D; i++) m->x[i] += o_out[i];
// Residual (vectorized)
vDSP_vadd(m->x, 1, o_out, 1, m->x, 1, (vDSP_Length)D);
#if QWEN_DEBUG
if (l == 0 && pos == 0) {
float pan = 0;
for (int i = 0; i < D; i++) pan += m->x[i] * m->x[i];
@ -300,6 +340,7 @@ static int qwen_forward(QwenModel *m, int token) {
printf(" L0 o_proj out norm=%.4f first4=[%.6f, %.6f, %.6f, %.6f]\n",
sqrtf(on), o_out[0], o_out[1], o_out[2], o_out[3]);
}
#endif
// FFN RMSNorm
qwen_rmsnorm(m->xb, m->x, m->rms_ffn[l], D);
@ -313,6 +354,7 @@ static int qwen_forward(QwenModel *m, int token) {
cpu_project(m->w_up[l], m->xb, m->hb2, D, HD);
#endif
#if QWEN_DEBUG
if (l == 0 && pos == 0) {
float gn = 0, un = 0;
for (int i = 0; i < HD; i++) { gn += m->hb[i]*m->hb[i]; un += m->hb2[i]*m->hb2[i]; }
@ -320,9 +362,11 @@ static int qwen_forward(QwenModel *m, int token) {
printf(" L0 gate first4=[%.6f, %.6f, %.6f, %.6f]\n",
m->hb[0], m->hb[1], m->hb[2], m->hb[3]);
}
#endif
qwen_silu(m->hb, HD);
for (int i = 0; i < HD; i++) m->hb[i] *= m->hb2[i];
// SiLU(gate) * up (vectorized element-wise multiply)
vDSP_vmul(m->hb, 1, m->hb2, 1, m->hb, 1, (vDSP_Length)HD);
float ffn_out[QWEN_DIM];
#if USE_ANE_PROJECTIONS
@ -331,38 +375,39 @@ static int qwen_forward(QwenModel *m, int token) {
cpu_project(m->w_down[l], m->hb, ffn_out, HD, D);
#endif
// Residual
for (int i = 0; i < D; i++) m->x[i] += ffn_out[i];
// Residual (vectorized)
vDSP_vadd(m->x, 1, ffn_out, 1, m->x, 1, (vDSP_Length)D);
// Debug: hidden state after each layer (first 3 layers, first token only)
#if QWEN_DEBUG
if (l < 3 && pos == 0) {
float hn = 0;
for (int i = 0; i < D; i++) hn += m->x[i] * m->x[i];
printf(" C hidden[%d] norm=%.4f first4=[%.4f, %.4f, %.4f, %.4f]\n",
l+1, sqrtf(hn), m->x[0], m->x[1], m->x[2], m->x[3]);
}
#endif
}
// Final RMSNorm
qwen_rmsnorm(m->xb, m->x, m->rms_final, D);
// Debug: check final hidden state before LM head
#if QWEN_DEBUG
if (m->pos < 2) {
float fn = 0;
for (int i = 0; i < D; i++) fn += m->xb[i] * m->xb[i];
printf(" Final hidden norm=%.4f (first 4: %.6f %.6f %.6f %.6f)\n",
sqrtf(fn), m->xb[0], m->xb[1], m->xb[2], m->xb[3]);
}
#endif
// LM head via Accelerate BLAS: logits = embed @ xb
// embed is [vocab, dim] row-major
cblas_sgemv(CblasRowMajor, CblasNoTrans,
QWEN_VOCAB, D,
1.0f, m->embed, D,
m->xb, 1,
0.0f, m->logits, 1);
// Debug: check logits
#if QWEN_DEBUG
if (m->pos < 2) {
float lmax = m->logits[0], lmin = m->logits[0];
int nonzero = 0;
@ -373,24 +418,21 @@ static int qwen_forward(QwenModel *m, int token) {
}
printf(" Logits: min=%.4f max=%.4f nonzero=%d/%d\n", lmin, lmax, nonzero, QWEN_VOCAB);
}
#endif
m->pos++;
// Argmax
int max_idx = 0;
float max_val = m->logits[0];
for (int i = 1; i < QWEN_VOCAB; i++) {
if (m->logits[i] > max_val) {
max_val = m->logits[i];
max_idx = i;
}
}
return max_idx;
// Argmax (vDSP, single call over 151936 elements)
float max_val;
vDSP_Length max_idx_vdsp;
vDSP_maxvi(m->logits, 1, &max_val, &max_idx_vdsp, (vDSP_Length)QWEN_VOCAB);
return (int)max_idx_vdsp;
}
// ── Compile all ANE kernels ──────────────────────────────────────────
static void qwen_compile_kernels(QwenModel *m) {
#if USE_ANE_PROJECTIONS
int D = QWEN_DIM, HD = QWEN_HIDDEN;
printf("Compiling %d ANE kernels...\n", QWEN_LAYERS * 7 + 1);
for (int l = 0; l < QWEN_LAYERS; l++) {
@ -404,7 +446,6 @@ static void qwen_compile_kernels(QwenModel *m) {
printf(" Layer %d/%d compiled\r", l+1, QWEN_LAYERS);
fflush(stdout);
}
// LM head (tied = embedding, chunked into 16 pieces)
for (int c = 0; c < QWEN_LM_CHUNKS; c++) {
float *chunk_weights = m->embed + c * QWEN_LM_CHUNK_SIZE * D;
m->k_lmhead[c] = compile_conv_kernel(chunk_weights, D, QWEN_LM_CHUNK_SIZE, 1);
@ -413,6 +454,10 @@ static void qwen_compile_kernels(QwenModel *m) {
}
}
printf("\nAll kernels compiled.\n");
#else
printf("CPU-only mode (ANE kernel compilation skipped).\n");
(void)m;
#endif
}
// ── Allocate buffers ─────────────────────────────────────────────────

158
inference/setup.sh Executable file
View File

@ -0,0 +1,158 @@
#!/bin/bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
MODEL_ID="Qwen/Qwen2.5-0.5B-Instruct"
MODEL_DIR="$HOME/models/Qwen2.5-0.5B-Instruct"
WEIGHTS_BIN="$SCRIPT_DIR/qwen05b.bin"
BINARY="$SCRIPT_DIR/qwen_ane"
VENV_DIR="$SCRIPT_DIR/.venv"
EXPECTED_WEIGHT_SIZE=1976131100
info() { printf "\033[1;34m==> %s\033[0m\n" "$1"; }
ok() { printf "\033[1;32m ✓ %s\033[0m\n" "$1"; }
warn() { printf "\033[1;33m ! %s\033[0m\n" "$1"; }
fail() { printf "\033[1;31m ✗ %s\033[0m\n" "$1"; exit 1; }
info "ANE Inference Setup"
echo "Model: $MODEL_ID"
echo "Target: $SCRIPT_DIR"
echo ""
# --- Step 1: Prerequisites ---
info "Checking prerequisites..."
if ! command -v xcrun &>/dev/null; then
fail "Xcode Command Line Tools not found. Install with: xcode-select --install"
fi
ok "xcrun clang available"
if ! command -v python3 &>/dev/null; then
fail "Python 3 not found"
fi
PY_VER=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
PY_MAJOR=$(echo "$PY_VER" | cut -d. -f1)
PY_MINOR=$(echo "$PY_VER" | cut -d. -f2)
if [ "$PY_MAJOR" -lt 3 ] || ([ "$PY_MAJOR" -eq 3 ] && [ "$PY_MINOR" -lt 11 ]); then
fail "Python 3.11+ required (found $PY_VER). coremltools needs 3.11-3.13."
fi
ok "Python $PY_VER"
# --- Step 2: Virtual environment ---
info "Setting up Python environment..."
if [ ! -d "$VENV_DIR" ]; then
python3 -m venv "$VENV_DIR"
ok "Created venv at $VENV_DIR"
else
ok "Venv already exists"
fi
source "$VENV_DIR/bin/activate"
pip install --quiet --upgrade pip
pip install --quiet safetensors torch transformers huggingface-hub
ok "Python dependencies installed"
# --- Step 3: Download model ---
info "Downloading model from HuggingFace..."
if [ -f "$MODEL_DIR/model.safetensors" ] && [ -f "$MODEL_DIR/tokenizer.json" ]; then
ok "Model already downloaded at $MODEL_DIR"
else
mkdir -p "$MODEL_DIR"
if command -v huggingface-cli &>/dev/null; then
huggingface-cli download "$MODEL_ID" --local-dir "$MODEL_DIR"
else
python3 -c "
from huggingface_hub import snapshot_download
snapshot_download('$MODEL_ID', local_dir='$MODEL_DIR')
"
fi
ok "Model downloaded to $MODEL_DIR"
fi
# Verify key files exist
for f in model.safetensors tokenizer.json vocab.json merges.txt config.json; do
if [ ! -f "$MODEL_DIR/$f" ]; then
fail "Missing $f in $MODEL_DIR"
fi
done
ok "All model files present"
# --- Step 4: Convert weights ---
info "Converting weights to binary format..."
if [ -f "$WEIGHTS_BIN" ]; then
ACTUAL_SIZE=$(stat -f%z "$WEIGHTS_BIN" 2>/dev/null || stat -c%s "$WEIGHTS_BIN" 2>/dev/null)
if [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE" ]; then
ok "Weights already converted ($((ACTUAL_SIZE / 1024 / 1024)) MB)"
else
warn "Weight file exists but wrong size ($ACTUAL_SIZE vs $EXPECTED_WEIGHT_SIZE), reconverting"
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN"
ok "Weights converted"
fi
else
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN"
ok "Weights converted"
fi
# --- Step 5: Build binary ---
info "Building qwen_ane binary..."
NEEDS_BUILD=0
if [ ! -f "$BINARY" ]; then
NEEDS_BUILD=1
elif [ "$SCRIPT_DIR/main.m" -nt "$BINARY" ] || \
[ "$SCRIPT_DIR/qwen_ane_infer.h" -nt "$BINARY" ] || \
[ "$SCRIPT_DIR/tokenizer.h" -nt "$BINARY" ] 2>/dev/null || \
[ "$SCRIPT_DIR/http_server.h" -nt "$BINARY" ] 2>/dev/null; then
NEEDS_BUILD=1
warn "Source files newer than binary, rebuilding"
fi
if [ "$NEEDS_BUILD" -eq 1 ]; then
xcrun clang -O2 -framework Foundation -framework IOSurface \
-framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
-o "$BINARY" "$SCRIPT_DIR/main.m"
ok "Binary built: $BINARY"
else
ok "Binary up to date"
fi
# --- Step 6: Smoke test ---
info "Running smoke test..."
# Quick single-shot test with known token IDs for "system\nYou are a helpful assistant."
TEST_OUTPUT=$("$BINARY" "$WEIGHTS_BIN" "151644 8948 198" 3 2>&1 || true)
if echo "$TEST_OUTPUT" | grep -q "OUT:"; then
ok "Smoke test passed (model generates output)"
else
warn "Smoke test: no output tokens detected (this may be OK on first run)"
echo " Output was: $(echo "$TEST_OUTPUT" | tail -3)"
fi
# --- Done ---
echo ""
info "Setup complete!"
echo ""
echo " Binary: $BINARY"
echo " Weights: $WEIGHTS_BIN ($(du -h "$WEIGHTS_BIN" | cut -f1) )"
echo " Model: $MODEL_DIR"
echo ""
echo "Quick start:"
echo " # Single prompt (slow, compiles every time)"
echo " python3 $SCRIPT_DIR/run.py \"What is 2+2?\""
echo ""
echo " # Server mode (fast, compile once)"
echo " $BINARY $WEIGHTS_BIN --server /tmp/qwen_ane.sock &"
echo " python3 $SCRIPT_DIR/run.py \"What is 2+2?\""
echo ""
echo " # HTTP API (fast, no Python needed for queries)"
echo " $BINARY $WEIGHTS_BIN --http 8000 --model-dir $MODEL_DIR"
echo " curl http://localhost:8000/v1/completions -d '{\"prompt\":\"Hi\",\"max_tokens\":20}'"
echo ""
echo " # Run throughput benchmark"
echo " $SCRIPT_DIR/benchmark.sh"

657
inference/tokenizer.h Normal file
View File

@ -0,0 +1,657 @@
// tokenizer.h -- Byte-level BPE tokenizer for Qwen2.5 in pure C
// Loads vocab.json + merges.txt from HuggingFace model directory.
// Implements GPT-style byte-level BPE (same algorithm as tiktoken/llama.cpp).
#pragma once
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#define TOK_MAX_VOCAB 152000
#define TOK_MAX_MERGES 152000
#define TOK_MAX_TOKEN_LEN 256
#define TOK_HASH_SIZE (1 << 20) // ~1M buckets
// Special token IDs for Qwen2.5
#define TOK_IM_START 151644
#define TOK_IM_END 151645
#define TOK_ENDOFTEXT 151643
// --- Byte-to-unicode mapping (GPT-2 standard) ---
// Maps byte values 0-255 to unicode codepoints used in the BPE vocab.
// Printable ASCII stays the same; non-printable bytes map to U+0100..U+0143.
static int g_byte_to_unicode[256];
static int g_unicode_to_byte[65536];
static void tok_init_byte_mapping(void) {
int n = 0;
for (int b = 0; b < 256; b++) {
if ((b >= 0x21 && b <= 0x7E) || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF)) {
g_byte_to_unicode[b] = b;
} else {
g_byte_to_unicode[b] = 256 + n;
n++;
}
}
memset(g_unicode_to_byte, 0xFF, sizeof(g_unicode_to_byte));
for (int b = 0; b < 256; b++)
g_unicode_to_byte[g_byte_to_unicode[b]] = b;
}
// --- UTF-8 helpers ---
static int utf8_encode(int codepoint, char *out) {
if (codepoint < 0x80) {
out[0] = (char)codepoint;
return 1;
} else if (codepoint < 0x800) {
out[0] = (char)(0xC0 | (codepoint >> 6));
out[1] = (char)(0x80 | (codepoint & 0x3F));
return 2;
} else if (codepoint < 0x10000) {
out[0] = (char)(0xE0 | (codepoint >> 12));
out[1] = (char)(0x80 | ((codepoint >> 6) & 0x3F));
out[2] = (char)(0x80 | (codepoint & 0x3F));
return 3;
}
out[0] = (char)(0xF0 | (codepoint >> 18));
out[1] = (char)(0x80 | ((codepoint >> 12) & 0x3F));
out[2] = (char)(0x80 | ((codepoint >> 6) & 0x3F));
out[3] = (char)(0x80 | (codepoint & 0x3F));
return 4;
}
static int utf8_decode(const char *s, int *codepoint) {
unsigned char c = (unsigned char)s[0];
if (c < 0x80) { *codepoint = c; return 1; }
if ((c & 0xE0) == 0xC0) {
*codepoint = ((c & 0x1F) << 6) | (s[1] & 0x3F);
return 2;
}
if ((c & 0xF0) == 0xE0) {
*codepoint = ((c & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F);
return 3;
}
*codepoint = ((c & 0x07) << 18) | ((s[1] & 0x3F) << 12) | ((s[2] & 0x3F) << 6) | (s[3] & 0x3F);
return 4;
}
// --- Hash map: string -> int ---
typedef struct {
char *key;
int value;
} TokHashEntry;
typedef struct {
TokHashEntry *entries;
int capacity;
} TokHashMap;
static unsigned int tok_hash(const char *s) {
unsigned int h = 5381;
while (*s) h = ((h << 5) + h) ^ (unsigned char)*s++;
return h;
}
static void tok_hashmap_init(TokHashMap *m, int capacity) {
m->capacity = capacity;
m->entries = (TokHashEntry*)calloc(capacity, sizeof(TokHashEntry));
}
static void tok_hashmap_set(TokHashMap *m, const char *key, int value) {
unsigned int idx = tok_hash(key) % m->capacity;
while (m->entries[idx].key) {
if (strcmp(m->entries[idx].key, key) == 0) {
m->entries[idx].value = value;
return;
}
idx = (idx + 1) % m->capacity;
}
m->entries[idx].key = strdup(key);
m->entries[idx].value = value;
}
static int tok_hashmap_get(TokHashMap *m, const char *key, int default_val) {
unsigned int idx = tok_hash(key) % m->capacity;
while (m->entries[idx].key) {
if (strcmp(m->entries[idx].key, key) == 0)
return m->entries[idx].value;
idx = (idx + 1) % m->capacity;
}
return default_val;
}
static void tok_hashmap_free(TokHashMap *m) {
for (int i = 0; i < m->capacity; i++)
if (m->entries[i].key) free(m->entries[i].key);
free(m->entries);
m->entries = NULL;
m->capacity = 0;
}
// --- Merge pair ---
typedef struct {
char *a;
char *b;
} TokMerge;
// --- Tokenizer state ---
typedef struct {
TokHashMap vocab; // token string -> id
char **id_to_token; // id -> token string (for decoding)
int vocab_size;
TokMerge *merges;
int n_merges;
TokHashMap merge_rank; // "a b" -> rank (lower = higher priority)
// Special tokens
int im_start;
int im_end;
int eos;
} Tokenizer;
// --- JSON string parsing (minimal, handles unicode escapes) ---
static int tok_parse_json_string(const char *s, char *out, int max_out) {
if (*s != '"') return -1;
s++;
int n = 0;
while (*s && *s != '"' && n < max_out - 1) {
if (*s == '\\') {
s++;
switch (*s) {
case '"': out[n++] = '"'; break;
case '\\': out[n++] = '\\'; break;
case '/': out[n++] = '/'; break;
case 'n': out[n++] = '\n'; break;
case 'r': out[n++] = '\r'; break;
case 't': out[n++] = '\t'; break;
case 'u': {
char hex[5] = {s[1], s[2], s[3], s[4], 0};
int cp = (int)strtol(hex, NULL, 16);
n += utf8_encode(cp, out + n);
s += 4;
break;
}
default: out[n++] = *s;
}
} else {
out[n++] = *s;
}
s++;
}
out[n] = '\0';
return n;
}
// --- Load vocab.json ---
// Format: {"token_string": id, ...}
static int tok_load_vocab(Tokenizer *t, const char *path) {
FILE *f = fopen(path, "r");
if (!f) { fprintf(stderr, "Cannot open vocab: %s\n", path); return -1; }
fseek(f, 0, SEEK_END);
long fsize = ftell(f);
fseek(f, 0, SEEK_SET);
char *data = (char*)malloc(fsize + 1);
fread(data, 1, fsize, f);
data[fsize] = '\0';
fclose(f);
tok_hashmap_init(&t->vocab, TOK_HASH_SIZE);
t->id_to_token = (char**)calloc(TOK_MAX_VOCAB, sizeof(char*));
t->vocab_size = 0;
char *p = data;
// Skip opening {
while (*p && *p != '{') p++;
if (*p) p++;
char key_buf[TOK_MAX_TOKEN_LEN];
while (*p) {
while (*p && (*p == ' ' || *p == '\n' || *p == '\r' || *p == '\t' || *p == ',')) p++;
if (*p == '}' || !*p) break;
int klen = tok_parse_json_string(p, key_buf, sizeof(key_buf));
if (klen < 0) break;
// Skip past closing quote
p++; // opening "
while (*p) {
if (*p == '\\') { p += 2; continue; }
if (*p == '"') { p++; break; }
p++;
}
// Skip colon and whitespace
while (*p && (*p == ' ' || *p == ':')) p++;
int id = (int)strtol(p, &p, 10);
if (id >= 0 && id < TOK_MAX_VOCAB) {
tok_hashmap_set(&t->vocab, key_buf, id);
t->id_to_token[id] = strdup(key_buf);
if (id >= t->vocab_size) t->vocab_size = id + 1;
}
}
free(data);
printf(" Vocab: %d tokens\n", t->vocab_size);
return 0;
}
// --- Load merges.txt ---
// Format: one merge per line, "tokenA tokenB" (space-separated)
// First line may be a header starting with #
static int tok_load_merges(Tokenizer *t, const char *path) {
FILE *f = fopen(path, "r");
if (!f) { fprintf(stderr, "Cannot open merges: %s\n", path); return -1; }
t->merges = (TokMerge*)malloc(TOK_MAX_MERGES * sizeof(TokMerge));
tok_hashmap_init(&t->merge_rank, TOK_HASH_SIZE);
t->n_merges = 0;
char line[4096];
while (fgets(line, sizeof(line), f)) {
// Strip newline
int len = (int)strlen(line);
while (len > 0 && (line[len-1] == '\n' || line[len-1] == '\r')) line[--len] = '\0';
if (len == 0) continue;
if (line[0] == '#') continue; // skip header
// Split on first space
char *space = strchr(line, ' ');
if (!space) continue;
*space = '\0';
t->merges[t->n_merges].a = strdup(line);
t->merges[t->n_merges].b = strdup(space + 1);
// Store merge rank: "a b" -> rank
*space = ' '; // restore
tok_hashmap_set(&t->merge_rank, line, t->n_merges);
t->n_merges++;
if (t->n_merges >= TOK_MAX_MERGES) break;
}
fclose(f);
printf(" Merges: %d rules\n", t->n_merges);
return 0;
}
// --- Add special tokens ---
static void tok_add_special_tokens(Tokenizer *t) {
struct { const char *text; int id; } specials[] = {
{"<|endoftext|>", 151643},
{"<|im_start|>", 151644},
{"<|im_end|>", 151645},
};
for (int i = 0; i < 3; i++) {
tok_hashmap_set(&t->vocab, specials[i].text, specials[i].id);
if (specials[i].id < TOK_MAX_VOCAB) {
if (t->id_to_token[specials[i].id]) free(t->id_to_token[specials[i].id]);
t->id_to_token[specials[i].id] = strdup(specials[i].text);
}
if (specials[i].id >= t->vocab_size) t->vocab_size = specials[i].id + 1;
}
t->im_start = 151644;
t->im_end = 151645;
t->eos = 151643;
}
// --- Initialize tokenizer ---
static int tok_init(Tokenizer *t, const char *model_dir) {
char path[4096];
tok_init_byte_mapping();
snprintf(path, sizeof(path), "%s/vocab.json", model_dir);
if (tok_load_vocab(t, path) != 0) return -1;
snprintf(path, sizeof(path), "%s/merges.txt", model_dir);
if (tok_load_merges(t, path) != 0) return -1;
tok_add_special_tokens(t);
return 0;
}
static void tok_free(Tokenizer *t) {
tok_hashmap_free(&t->vocab);
tok_hashmap_free(&t->merge_rank);
if (t->id_to_token) {
for (int i = 0; i < t->vocab_size; i++)
if (t->id_to_token[i]) free(t->id_to_token[i]);
free(t->id_to_token);
}
if (t->merges) {
for (int i = 0; i < t->n_merges; i++) {
free(t->merges[i].a);
free(t->merges[i].b);
}
free(t->merges);
}
}
// --- BPE encoding ---
// Convert a raw byte string to its byte-level unicode representation (UTF-8).
// Each input byte is mapped through g_byte_to_unicode, then encoded as UTF-8.
static int tok_bytes_to_unicode_str(const char *input, int input_len, char *out, int max_out) {
int n = 0;
for (int i = 0; i < input_len && n < max_out - 4; i++) {
unsigned char b = (unsigned char)input[i];
int cp = g_byte_to_unicode[b];
n += utf8_encode(cp, out + n);
}
out[n] = '\0';
return n;
}
// A BPE word is a list of token strings (initially one per byte-level char).
typedef struct {
char **tokens;
int count;
int capacity;
} BPEWord;
static void bpe_word_init(BPEWord *w) {
w->capacity = 64;
w->tokens = (char**)malloc(w->capacity * sizeof(char*));
w->count = 0;
}
static void bpe_word_push(BPEWord *w, const char *s) {
if (w->count >= w->capacity) {
w->capacity *= 2;
w->tokens = (char**)realloc(w->tokens, w->capacity * sizeof(char*));
}
w->tokens[w->count++] = strdup(s);
}
static void bpe_word_free(BPEWord *w) {
for (int i = 0; i < w->count; i++) free(w->tokens[i]);
free(w->tokens);
}
// Apply BPE merges to a word (list of token strings).
static void bpe_merge(BPEWord *w, Tokenizer *t) {
while (w->count > 1) {
// Find the pair with lowest merge rank
int best_rank = t->n_merges + 1;
int best_idx = -1;
char pair_key[TOK_MAX_TOKEN_LEN * 2 + 2];
for (int i = 0; i < w->count - 1; i++) {
snprintf(pair_key, sizeof(pair_key), "%s %s", w->tokens[i], w->tokens[i+1]);
int rank = tok_hashmap_get(&t->merge_rank, pair_key, t->n_merges + 1);
if (rank < best_rank) {
best_rank = rank;
best_idx = i;
}
}
if (best_idx < 0) break; // no more merges
// Merge tokens[best_idx] and tokens[best_idx+1]
char merged[TOK_MAX_TOKEN_LEN * 2 + 1];
snprintf(merged, sizeof(merged), "%s%s", w->tokens[best_idx], w->tokens[best_idx+1]);
free(w->tokens[best_idx]);
free(w->tokens[best_idx+1]);
w->tokens[best_idx] = strdup(merged);
// Shift remaining tokens left
for (int i = best_idx + 1; i < w->count - 1; i++)
w->tokens[i] = w->tokens[i+1];
w->count--;
}
}
// Pre-tokenize: split on word boundaries (simplified GPT-style).
// Splits on transitions between: letters, digits, spaces, punctuation.
// Each "word" includes leading space if present (byte-level BPE convention).
typedef struct {
char **words;
int count;
int capacity;
} WordList;
static void wordlist_init(WordList *wl) {
wl->capacity = 256;
wl->words = (char**)malloc(wl->capacity * sizeof(char*));
wl->count = 0;
}
static void wordlist_push(WordList *wl, const char *s, int len) {
if (wl->count >= wl->capacity) {
wl->capacity *= 2;
wl->words = (char**)realloc(wl->words, wl->capacity * sizeof(char*));
}
char *copy = (char*)malloc(len + 1);
memcpy(copy, s, len);
copy[len] = '\0';
wl->words[wl->count++] = copy;
}
static void wordlist_free(WordList *wl) {
for (int i = 0; i < wl->count; i++) free(wl->words[i]);
free(wl->words);
}
static int is_letter(unsigned char c) {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c >= 0x80;
}
static int is_digit(unsigned char c) {
return c >= '0' && c <= '9';
}
static void tok_pre_tokenize(const char *text, WordList *out) {
wordlist_init(out);
int len = (int)strlen(text);
int i = 0;
while (i < len) {
int start = i;
if (text[i] == ' ') {
// Space + following word/punct
i++;
if (i < len && is_letter((unsigned char)text[i])) {
while (i < len && is_letter((unsigned char)text[i])) i++;
} else if (i < len && is_digit((unsigned char)text[i])) {
while (i < len && is_digit((unsigned char)text[i])) i++;
} else if (i < len && text[i] != ' ') {
i++; // single punct after space
}
wordlist_push(out, text + start, i - start);
} else if (is_letter((unsigned char)text[i])) {
while (i < len && is_letter((unsigned char)text[i])) i++;
wordlist_push(out, text + start, i - start);
} else if (is_digit((unsigned char)text[i])) {
while (i < len && is_digit((unsigned char)text[i])) i++;
wordlist_push(out, text + start, i - start);
} else if (text[i] == '\n' || text[i] == '\r') {
while (i < len && (text[i] == '\n' || text[i] == '\r')) i++;
wordlist_push(out, text + start, i - start);
} else {
i++;
wordlist_push(out, text + start, 1);
}
}
}
// --- Main encode function ---
// Returns number of token IDs written. Caller provides output buffer.
static int tok_encode(Tokenizer *t, const char *text, int *ids, int max_ids) {
int n_ids = 0;
// Pre-tokenize into words
WordList words;
tok_pre_tokenize(text, &words);
for (int w = 0; w < words.count && n_ids < max_ids; w++) {
// Convert word bytes to byte-level unicode string
char unicode_str[TOK_MAX_TOKEN_LEN * 4];
int wlen = (int)strlen(words.words[w]);
tok_bytes_to_unicode_str(words.words[w], wlen, unicode_str, sizeof(unicode_str));
// Split unicode string into individual unicode chars
BPEWord bpe;
bpe_word_init(&bpe);
const char *p = unicode_str;
while (*p) {
int cp;
int cplen = utf8_decode(p, &cp);
char single[8];
int slen = utf8_encode(cp, single);
single[slen] = '\0';
bpe_word_push(&bpe, single);
p += cplen;
}
// Apply BPE merges
bpe_merge(&bpe, t);
// Look up each resulting token in vocab
for (int i = 0; i < bpe.count && n_ids < max_ids; i++) {
int id = tok_hashmap_get(&t->vocab, bpe.tokens[i], -1);
if (id >= 0) {
ids[n_ids++] = id;
} else {
// Unknown token -- encode each byte-level char as individual token
const char *bp = bpe.tokens[i];
while (*bp && n_ids < max_ids) {
int bcp;
int bcplen = utf8_decode(bp, &bcp);
char single[8];
int slen = utf8_encode(bcp, single);
single[slen] = '\0';
int byte_id = tok_hashmap_get(&t->vocab, single, -1);
if (byte_id >= 0) ids[n_ids++] = byte_id;
bp += bcplen;
}
}
}
bpe_word_free(&bpe);
}
wordlist_free(&words);
return n_ids;
}
// --- Encode with special tokens ---
// Splits text on special token patterns, encodes non-special parts with BPE.
static int tok_encode_with_special(Tokenizer *t, const char *text, int *ids, int max_ids) {
struct { const char *text; int id; } specials[] = {
{"<|im_start|>", TOK_IM_START},
{"<|im_end|>", TOK_IM_END},
{"<|endoftext|>", TOK_ENDOFTEXT},
};
int n_specials = 3;
int n_ids = 0;
const char *p = text;
while (*p && n_ids < max_ids) {
// Check if current position matches a special token
int matched = 0;
for (int s = 0; s < n_specials; s++) {
int slen = (int)strlen(specials[s].text);
if (strncmp(p, specials[s].text, slen) == 0) {
ids[n_ids++] = specials[s].id;
p += slen;
matched = 1;
break;
}
}
if (matched) continue;
// Find next special token
const char *next_special = NULL;
for (int s = 0; s < n_specials; s++) {
const char *found = strstr(p, specials[s].text);
if (found && (!next_special || found < next_special))
next_special = found;
}
// Encode the text up to the next special (or end)
int chunk_len = next_special ? (int)(next_special - p) : (int)strlen(p);
if (chunk_len > 0) {
char *chunk = (char*)malloc(chunk_len + 1);
memcpy(chunk, p, chunk_len);
chunk[chunk_len] = '\0';
n_ids += tok_encode(t, chunk, ids + n_ids, max_ids - n_ids);
free(chunk);
}
p += chunk_len;
}
return n_ids;
}
// --- Decode token IDs to text ---
static int tok_decode(Tokenizer *t, const int *ids, int n_ids, char *out, int max_out) {
int n = 0;
for (int i = 0; i < n_ids; i++) {
int id = ids[i];
// Skip special tokens in output
if (id == TOK_IM_START || id == TOK_IM_END || id == TOK_ENDOFTEXT)
continue;
if (id < 0 || id >= t->vocab_size || !t->id_to_token[id])
continue;
const char *tok_str = t->id_to_token[id];
// Convert byte-level unicode token back to raw bytes
const char *p = tok_str;
while (*p && n < max_out - 1) {
int cp;
int cplen = utf8_decode(p, &cp);
int byte_val = g_unicode_to_byte[cp < 65536 ? cp : 0];
if (byte_val >= 0 && byte_val < 256) {
out[n++] = (char)byte_val;
} else {
// Not a byte-mapped char, copy UTF-8 directly
for (int j = 0; j < cplen && n < max_out - 1; j++)
out[n++] = p[j];
}
p += cplen;
}
}
out[n] = '\0';
return n;
}
// --- Chat template ---
// Formats: <|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n
static int tok_apply_chat_template(const char *system_prompt, const char *user_prompt,
char *out, int max_out) {
if (!system_prompt) system_prompt = "You are a helpful assistant.";
return snprintf(out, max_out,
"<|im_start|>system\n%s<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n",
system_prompt, user_prompt);
}
// --- Convenience: encode a chat prompt ---
static int tok_encode_chat(Tokenizer *t, const char *system_prompt, const char *user_prompt,
int *ids, int max_ids) {
char templated[65536];
tok_apply_chat_template(system_prompt, user_prompt, templated, sizeof(templated));
return tok_encode_with_special(t, templated, ids, max_ids);
}