[feat][gpu] Q4 quantization, Metal GPU shaders, ANE kernel fusion, memory safety

This commit is contained in:
Erik Bray 2026-03-04 00:48:17 +01:00
parent 0e70f5bd71
commit be96079bbf
8 changed files with 3340 additions and 258 deletions

4
.gitignore vendored
View File

@ -25,6 +25,9 @@ training/test_*
# Inference binaries and runtime data
inference/qwen_ane
inference/qwen05b.bin
inference/qwen05b_f32.bin
inference/qwen05b_f16.bin
inference/qwen05b_q8.bin
inference/.venv/
inference/benchmark_results.json
@ -59,6 +62,7 @@ web/
training/tinystories_data00.bin
training/ane_stories110M_ckpt.bin
*.bin
*.metallib
!training/download_data.sh
# Secrets / env

View File

@ -66,125 +66,235 @@ MEM_BYTES=$(sysctl -n hw.memsize 2>/dev/null || echo "0")
MEM_GB=$((MEM_BYTES / 1073741824))
echo ""
info "=== ANE Inference Benchmark (qwen_ane) ==="
info "=== ANE Multi-Format Inference Benchmark ==="
echo "Hardware: $CHIP"
echo "macOS: $MACOS"
echo "Memory: ${MEM_GB} GB"
echo "Model: Qwen2.5-0.5B-Instruct (BF16, 494M params)"
echo "Model: Qwen2.5-0.5B-Instruct (494M params)"
echo ""
# --- Phase 1: Server mode benchmark (HTTP API) ---
info "Phase 1: Server mode (persistent ANE kernels via HTTP API)"
dim "Starting server on port $HTTP_PORT..."
# --- Phase 0: Prepare weight files (F16 + Q8) ---
WEIGHTS_F16="$SCRIPT_DIR/qwen05b.bin"
WEIGHTS_Q8="$SCRIPT_DIR/qwen05b_q8.bin"
WEIGHTS_Q4="$SCRIPT_DIR/qwen05b_q4.bin"
CONVERT="$SCRIPT_DIR/convert_weights.py"
VENV_DIR="$SCRIPT_DIR/.venv"
# Start HTTP server in background
"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 &
SERVER_PID=$!
info "Phase 0: Preparing weight files"
if [ ! -f "$WEIGHTS_Q8" ]; then
if [ ! -f "$CONVERT" ]; then
echo " convert_weights.py not found, skipping Q8 generation."
WEIGHTS_Q8=""
else
dim "Generating Q8 weights (one-time)..."
if [ -d "$VENV_DIR" ]; then
source "$VENV_DIR/bin/activate"
fi
python3 "$CONVERT" "$MODEL_DIR" "$WEIGHTS_Q8" --q8
dim "Q8 weights ready: $(du -h "$WEIGHTS_Q8" | cut -f1)"
fi
else
dim "Q8 weights already exist: $(du -h "$WEIGHTS_Q8" | cut -f1)"
fi
if [ ! -f "$WEIGHTS_Q4" ]; then
if [ ! -f "$CONVERT" ]; then
echo " convert_weights.py not found, skipping Q4 generation."
WEIGHTS_Q4=""
else
dim "Generating Q4 weights (one-time)..."
if [ -d "$VENV_DIR" ]; then
source "$VENV_DIR/bin/activate"
fi
python3 "$CONVERT" "$MODEL_DIR" "$WEIGHTS_Q4" --q4
dim "Q4 weights ready: $(du -h "$WEIGHTS_Q4" | cut -f1)"
fi
else
dim "Q4 weights already exist: $(du -h "$WEIGHTS_Q4" | cut -f1)"
fi
dim "F16 weights: $(du -h "$WEIGHTS_F16" | cut -f1)"
echo ""
# ANE weight formats to benchmark
# GPU flag: empty for CPU formats, "--gpu" for Metal GPU formats
ANE_FMT_NAMES=("F16")
ANE_FMT_WEIGHTS=("$WEIGHTS_F16")
ANE_FMT_LABELS=("F16→F32 (AMX)")
ANE_FMT_GPU=("")
if [ -n "$WEIGHTS_Q8" ] && [ -f "$WEIGHTS_Q8" ]; then
ANE_FMT_NAMES+=("Q8")
ANE_FMT_WEIGHTS+=("$WEIGHTS_Q8")
ANE_FMT_LABELS+=("Q8 (NEON dequant)")
ANE_FMT_GPU+=("")
fi
if [ -n "$WEIGHTS_Q4" ] && [ -f "$WEIGHTS_Q4" ]; then
ANE_FMT_NAMES+=("Q4_Metal")
ANE_FMT_WEIGHTS+=("$WEIGHTS_Q4")
ANE_FMT_LABELS+=("Q4 SIMD (Metal GPU)")
ANE_FMT_GPU+=("--gpu")
ANE_FMT_NAMES+=("Q4_AMX")
ANE_FMT_WEIGHTS+=("$WEIGHTS_Q4")
ANE_FMT_LABELS+=("Q4→F32 (AMX dequant)")
ANE_FMT_GPU+=("")
fi
NUM_ANE_FMTS=${#ANE_FMT_NAMES[@]}
NUM_PROMPTS=${#PROMPTS[@]}
# Global cleanup
SERVER_PID=""
cleanup() {
kill "$SERVER_PID" 2>/dev/null || true
[ -n "$SERVER_PID" ] && kill "$SERVER_PID" 2>/dev/null || true
rm -f "$SOCK" /tmp/qwen_bench_server.log
}
trap cleanup EXIT
# Wait for READY
for i in $(seq 1 30); do
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then
break
fi
# Helper: start server with given weight file and optional extra flags, wait for READY
start_server() {
local wfile="$1"
shift
local extra_flags="$*"
[ -n "$SERVER_PID" ] && kill "$SERVER_PID" 2>/dev/null || true
sleep 1
done
if ! grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then
echo "Server failed to start. Log:"
rm -f /tmp/qwen_bench_server.log
"$BINARY" "$wfile" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" $extra_flags > /tmp/qwen_bench_server.log 2>&1 &
SERVER_PID=$!
for _i in $(seq 1 30); do
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then return 0; fi
sleep 1
done
echo "Server failed to start with $wfile. Log:"
cat /tmp/qwen_bench_server.log
exit 1
fi
dim "Server ready (PID $SERVER_PID)"
echo ""
return 1
}
# Warmup: first request primes any remaining caches
dim "Warmup run (discarded)..."
curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
-H "Content-Type: application/json" \
-d '{"prompt":"warmup","max_tokens":5}' > /dev/null 2>&1
echo ""
# --- Phase 1: Multi-format ANE benchmarks ---
# Per-format result tracking (indexed by format number)
declare -a ALL_AVG_P ALL_AVG_D ALL_AVG_INF ALL_AVG_TTFT ALL_AVG_RT
ANE_JSON_BLOCKS=""
# Print table header
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"Prompt" "In" "Out" "Prefill" "Decode" "TTFT" "Infer" "Rndtrip" "Overhead"
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"" "tok" "tok" "(t/s)" "(t/s)" "(ms)" "(ms)" "(ms)" "(ms)"
printf '%.0s─' {1..85}; echo ""
for fmt_idx in $(seq 0 $((NUM_ANE_FMTS - 1))); do
FMT_NAME="${ANE_FMT_NAMES[$fmt_idx]}"
FMT_WEIGHTS="${ANE_FMT_WEIGHTS[$fmt_idx]}"
FMT_LABEL="${ANE_FMT_LABELS[$fmt_idx]}"
FMT_GPU="${ANE_FMT_GPU[$fmt_idx]}"
# Arrays for averages
declare -a P_TPS_ARR D_TPS_ARR INF_MS_ARR TTFT_MS_ARR RT_MS_ARR
echo ""
info "Phase 1.$((fmt_idx+1)): ANE $FMT_NAME benchmark ($FMT_LABEL)"
dim "Weights: $(du -h "$FMT_WEIGHTS" | cut -f1) — Starting server..."
JSON_ENTRIES=""
NUM_PROMPTS=${#PROMPTS[@]}
if ! start_server "$FMT_WEIGHTS" $FMT_GPU; then
echo "Skipping $FMT_NAME format."
ALL_AVG_P+=("0"); ALL_AVG_D+=("0"); ALL_AVG_INF+=("0")
ALL_AVG_TTFT+=("0"); ALL_AVG_RT+=("0")
continue
fi
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
NAME="${PROMPT_NAMES[$i]}"
PROMPT="${PROMPTS[$i]}"
MAXTOK="${MAX_TOKENS[$i]}"
RT_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
dim "Warmup run (discarded)..."
curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
-H "Content-Type: application/json" \
-d "{\"prompt\": \"$PROMPT\", \"max_tokens\": $MAXTOK}" 2>&1)
RT_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
RT_MS=$(echo "$RT_T0 $RT_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
# Parse server JSON with pure shell -- no python
P_TOKENS=$(json_val "$RESP" "prompt_tokens")
G_TOKENS=$(json_val "$RESP" "gen_tokens")
P_TPS=$(json_val "$RESP" "prefill_tps")
D_TPS=$(json_val "$RESP" "decode_tps")
TTFT_MS=$(trunc "$(json_val "$RESP" "ttft_ms")")
INF_MS=$(trunc "$(json_val "$RESP" "inference_ms")")
TOTAL_MS=$(trunc "$(json_val "$RESP" "total_ms")")
TEXT=$(json_text "$RESP")
OVERHEAD=$((RT_MS - TOTAL_MS))
-d '{"prompt":"warmup","max_tokens":5}' > /dev/null 2>&1
echo ""
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"$NAME" "$P_TOKENS" "$G_TOKENS" "$P_TPS" "$D_TPS" "$TTFT_MS" "$INF_MS" "$RT_MS" "$OVERHEAD"
"Prompt" "In" "Out" "Prefill" "Decode" "TTFT" "Infer" "Rndtrip" "Overhead"
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"" "tok" "tok" "(t/s)" "(t/s)" "(ms)" "(ms)" "(ms)" "(ms)"
printf '%.0s─' {1..85}; echo ""
P_TPS_ARR+=("$P_TPS")
D_TPS_ARR+=("$D_TPS")
INF_MS_ARR+=("$INF_MS")
TTFT_MS_ARR+=("$TTFT_MS")
RT_MS_ARR+=("$RT_MS")
declare -a P_TPS_ARR=() D_TPS_ARR=() INF_MS_ARR=() TTFT_MS_ARR=() RT_MS_ARR=()
FMT_JSON_ENTRIES=""
# Build JSON entry
JSON_ENTRIES="$JSON_ENTRIES{\"name\":\"$NAME\",\"prompt_tokens\":$P_TOKENS,\"gen_tokens\":$G_TOKENS,\"prefill_tps\":$P_TPS,\"decode_tps\":$D_TPS,\"ttft_ms\":$TTFT_MS,\"inference_ms\":$INF_MS,\"roundtrip_ms\":$RT_MS},"
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
NAME="${PROMPT_NAMES[$i]}"
PROMPT="${PROMPTS[$i]}"
MAXTOK="${MAX_TOKENS[$i]}"
# Print response text indented below
echo "$TEXT"
RT_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
-H "Content-Type: application/json" \
-d "{\"prompt\": \"$PROMPT\", \"max_tokens\": $MAXTOK}" 2>&1)
RT_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
RT_MS=$(echo "$RT_T0 $RT_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
P_TOKENS=$(json_val "$RESP" "prompt_tokens")
G_TOKENS=$(json_val "$RESP" "gen_tokens")
P_TPS=$(json_val "$RESP" "prefill_tps")
D_TPS=$(json_val "$RESP" "decode_tps")
TTFT_MS=$(trunc "$(json_val "$RESP" "ttft_ms")")
INF_MS=$(trunc "$(json_val "$RESP" "inference_ms")")
TOTAL_MS=$(trunc "$(json_val "$RESP" "total_ms")")
TEXT=$(json_text "$RESP")
OVERHEAD=$((RT_MS - TOTAL_MS))
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
"$NAME" "$P_TOKENS" "$G_TOKENS" "$P_TPS" "$D_TPS" "$TTFT_MS" "$INF_MS" "$RT_MS" "$OVERHEAD"
P_TPS_ARR+=("$P_TPS")
D_TPS_ARR+=("$D_TPS")
INF_MS_ARR+=("$INF_MS")
TTFT_MS_ARR+=("$TTFT_MS")
RT_MS_ARR+=("$RT_MS")
FMT_JSON_ENTRIES="$FMT_JSON_ENTRIES{\"name\":\"$NAME\",\"prompt_tokens\":$P_TOKENS,\"gen_tokens\":$G_TOKENS,\"prefill_tps\":$P_TPS,\"decode_tps\":$D_TPS,\"ttft_ms\":$TTFT_MS,\"inference_ms\":$INF_MS,\"roundtrip_ms\":$RT_MS},"
echo "$TEXT"
echo ""
done
printf '%.0s─' {1..85}; echo ""
F_AVG_P=$(shell_avg "${P_TPS_ARR[@]}")
F_AVG_D=$(shell_avg "${D_TPS_ARR[@]}")
F_AVG_INF=$(shell_avg_int "${INF_MS_ARR[@]}")
F_AVG_TTFT=$(shell_avg_int "${TTFT_MS_ARR[@]}")
F_AVG_RT=$(shell_avg_int "${RT_MS_ARR[@]}")
F_AVG_OVERHEAD=$((F_AVG_RT - F_AVG_INF))
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" "Average" "" "" "$F_AVG_P" "$F_AVG_D" "$F_AVG_TTFT" "$F_AVG_INF" "$F_AVG_RT" "$F_AVG_OVERHEAD"
echo ""
ALL_AVG_P+=("$F_AVG_P")
ALL_AVG_D+=("$F_AVG_D")
ALL_AVG_INF+=("$F_AVG_INF")
ALL_AVG_TTFT+=("$F_AVG_TTFT")
ALL_AVG_RT+=("$F_AVG_RT")
ANE_JSON_BLOCKS="$ANE_JSON_BLOCKS
\"$FMT_NAME\": {
\"format\": \"$FMT_NAME\",
\"label\": \"$FMT_LABEL\",
\"weight_size_mb\": $(du -m "$FMT_WEIGHTS" | cut -f1),
\"avg_prefill_tps\": $F_AVG_P,
\"avg_decode_tps\": $F_AVG_D,
\"avg_inference_ms\": $F_AVG_INF,
\"avg_roundtrip_ms\": $F_AVG_RT,
\"avg_ttft_ms\": $F_AVG_TTFT,
\"results\": [${FMT_JSON_ENTRIES%,}]
},"
done
printf '%.0s─' {1..85}; echo ""
# Use F16 results as the primary ANE numbers (first format)
AVG_P="${ALL_AVG_P[0]}"
AVG_D="${ALL_AVG_D[0]}"
AVG_INF="${ALL_AVG_INF[0]}"
AVG_TTFT="${ALL_AVG_TTFT[0]}"
AVG_RT="${ALL_AVG_RT[0]}"
# Averages (pure shell, no python)
AVG_P=$(shell_avg "${P_TPS_ARR[@]}")
AVG_D=$(shell_avg "${D_TPS_ARR[@]}")
AVG_INF=$(shell_avg_int "${INF_MS_ARR[@]}")
AVG_TTFT=$(shell_avg_int "${TTFT_MS_ARR[@]}")
AVG_RT=$(shell_avg_int "${RT_MS_ARR[@]}")
AVG_OVERHEAD=$((AVG_RT - AVG_INF))
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" "Average" "" "" "$AVG_P" "$AVG_D" "$AVG_TTFT" "$AVG_INF" "$AVG_RT" "$AVG_OVERHEAD"
echo ""
info "Infer = server-reported (pure processing). Rndtrip = wall-clock (what clients see)."
echo ""
# --- Phase 2: Cold start measurement ---
info "Phase 2: Cold start (single-shot, recompiles ANE kernels)"
# Kill server, run single-shot
kill "$SERVER_PID" 2>/dev/null || true
SERVER_PID=""
sleep 1
# Use perl for sub-second timing (available on all macOS, no python)
COLD_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
COLD_OUT=$("$BINARY" "$WEIGHTS" "151644 8948 198 2610 525 264 10950 17847 13 151645 198 151644 872 198 13048 151645 198 151644 77091 198" 10 2>&1 || true)
COLD_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
@ -193,17 +303,11 @@ COLD_MS=$(echo "$COLD_T0 $COLD_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
echo "Cold start latency: ${COLD_MS}ms (includes ANE kernel compilation)"
echo ""
# Re-start server for any additional tests
"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 &
SERVER_PID=$!
# Re-start server (F16) for consistency check
start_server "$WEIGHTS_F16"
# --- Phase 3: Repeated prompt (consistency check) ---
info "Phase 3: Decode speed consistency (5x same prompt)"
for retry in $(seq 1 15); do
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then break; fi
sleep 1
done
info "Phase 3: Decode speed consistency (5x same prompt, F16)"
printf "%-6s %10s %10s %10s\n" "Run" "Prefill" "Decode" "Infer(ms)"
printf '%.0s─' {1..40}; echo ""
@ -227,12 +331,8 @@ JSON="{
\"model\": \"Qwen2.5-0.5B-Instruct\",
\"mode\": \"http_server\",
\"cold_start_ms\": $COLD_MS,
\"avg_prefill_tps\": $AVG_P,
\"avg_decode_tps\": $AVG_D,
\"avg_inference_ms\": $AVG_INF,
\"avg_roundtrip_ms\": $AVG_RT,
\"avg_ttft_ms\": $AVG_TTFT,
\"results\": [${JSON_ENTRIES%,}]
\"ane_formats\": {$( echo "$ANE_JSON_BLOCKS" | sed '$ s/,$//' )
}
}"
echo "$JSON" > "$RESULTS_JSON"
dim "Results saved to $RESULTS_JSON"
@ -240,9 +340,12 @@ echo ""
# --- Phase 4: LM Studio comparison (if running) ---
LMS_PORT="${LMS_PORT:-1234}"
LMS_MODEL="${LMS_MODEL:-qwen2.5-0.5b-instruct}"
LMS_API_KEY="${LMS_API_KEY:-}"
# Models to benchmark (override via LMS_MODELS env var, comma-separated)
LMS_MODELS_DEFAULT="qwen2.5-0.5b-instruct,qwen2.5-0.5b-instruct-mlx@8bit,qwen2.5-0.5b-instruct-mlx@4bit"
IFS=',' read -ra LMS_MODEL_LIST <<< "${LMS_MODELS:-$LMS_MODELS_DEFAULT}"
# Check if LM Studio is running
LMS_REACHABLE=0
if curl -s --max-time 2 "http://localhost:$LMS_PORT/api/v1/chat" -H "Content-Type: application/json" -d '{}' >/dev/null 2>&1; then
@ -251,8 +354,8 @@ fi
if [ "$LMS_REACHABLE" -eq 1 ]; then
info "Phase 4: LM Studio comparison (localhost:$LMS_PORT)"
dim "Models: ${LMS_MODEL_LIST[*]}"
# If no API key, prompt for it
if [ -z "$LMS_API_KEY" ]; then
echo ""
echo " LM Studio requires an API key."
@ -268,30 +371,53 @@ if [ "$LMS_REACHABLE" -eq 1 ]; then
fi
fi
LMS_ALL_JSON=""
if [ "$LMS_REACHABLE" -eq 1 ] && [ -n "$LMS_API_KEY" ]; then
echo ""
printf "%-10s %5s %5s %10s %10s %10s\n" \
"Prompt" "In" "Out" "Decode" "TTFT" "Rndtrip"
printf "%-10s %5s %5s %10s %10s %10s\n" \
"" "tok" "tok" "(t/s)" "(ms)" "(ms)"
printf '%.0s─' {1..55}; echo ""
declare -a LMS_LATENCIES LMS_TPS_ARR LMS_TTFT_ARR
LMS_JSON_ENTRIES=""
# Track the best model for the final comparison table
BEST_LMS_MODEL=""
BEST_LMS_TPS="0"
BEST_LMS_LAT="99999"
BEST_LMS_TTFT="0"
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
NAME="${PROMPT_NAMES[$i]}"
PROMPT="${PROMPTS[$i]}"
for LMS_MODEL in "${LMS_MODEL_LIST[@]}"; do
echo ""
info "── $LMS_MODEL ──"
T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
LMS_RESP=$(curl -s --max-time 120 "http://localhost:$LMS_PORT/api/v1/chat" \
# Test if this model is available
TEST_RESP=$(curl -s --max-time 10 "http://localhost:$LMS_PORT/api/v1/chat" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $LMS_API_KEY" \
-d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"You are a helpful assistant. Be concise.\",\"input\":\"$PROMPT\"}" 2>&1)
T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
LMS_MS=$(echo "$T0 $T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
-d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"test\",\"input\":\"hi\"}" 2>&1)
eval "$(echo "$LMS_RESP" | python3 -c "
if echo "$TEST_RESP" | grep -qi "error\|not found\|not loaded\|no model"; then
dim " Model '$LMS_MODEL' not available, skipping."
continue
fi
printf "%-10s %5s %5s %10s %10s %10s\n" \
"Prompt" "In" "Out" "Decode" "TTFT" "Rndtrip"
printf "%-10s %5s %5s %10s %10s %10s\n" \
"" "tok" "tok" "(t/s)" "(ms)" "(ms)"
printf '%.0s─' {1..55}; echo ""
declare -a LMS_LATENCIES=() LMS_TPS_ARR=() LMS_TTFT_ARR=()
LMS_JSON_ENTRIES=""
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
NAME="${PROMPT_NAMES[$i]}"
PROMPT="${PROMPTS[$i]}"
T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
LMS_RESP=$(curl -s --max-time 120 "http://localhost:$LMS_PORT/api/v1/chat" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $LMS_API_KEY" \
-d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"You are a helpful assistant. Be concise.\",\"input\":\"$PROMPT\"}" 2>&1)
T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
LMS_MS=$(echo "$T0 $T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
eval "$(echo "$LMS_RESP" | python3 -c "
import sys, json
try:
r = json.load(sys.stdin)
@ -314,69 +440,191 @@ except Exception as e:
print('LMS_OUT=0')
" 2>/dev/null)"
printf "%-10s %5s %5s %10s %10s %10s\n" "$NAME" "$LMS_IN" "$LMS_OUT" "$LMS_TPS" "$LMS_TTFT" "$LMS_MS"
echo "$LMS_TEXT"
echo ""
LMS_LATENCIES+=("$LMS_MS")
LMS_TPS_ARR+=("$LMS_TPS")
LMS_TTFT_ARR+=("$LMS_TTFT")
LMS_JSON_ENTRIES="$LMS_JSON_ENTRIES{\"name\":\"$NAME\",\"latency_ms\":$LMS_MS,\"tps\":$LMS_TPS,\"ttft_ms\":$LMS_TTFT,\"input_tokens\":$LMS_IN,\"output_tokens\":$LMS_OUT},"
printf "%-10s %5s %5s %10s %10s %10s\n" "$NAME" "$LMS_IN" "$LMS_OUT" "$LMS_TPS" "$LMS_TTFT" "$LMS_MS"
LMS_LATENCIES+=("$LMS_MS")
LMS_TPS_ARR+=("$LMS_TPS")
LMS_TTFT_ARR+=("$LMS_TTFT")
LMS_JSON_ENTRIES="$LMS_JSON_ENTRIES{\"name\":\"$NAME\",\"latency_ms\":$LMS_MS,\"tps\":$LMS_TPS,\"ttft_ms\":$LMS_TTFT,\"input_tokens\":$LMS_IN,\"output_tokens\":$LMS_OUT},"
done
printf '%.0s─' {1..55}; echo ""
M_AVG_LAT=$(shell_avg_int "${LMS_LATENCIES[@]}")
M_AVG_TPS=$(shell_avg "${LMS_TPS_ARR[@]}")
M_AVG_TTFT=$(shell_avg_int "${LMS_TTFT_ARR[@]}")
printf "%-10s %5s %5s %10s %10s %10s\n" "Average" "" "" "$M_AVG_TPS" "$M_AVG_TTFT" "$M_AVG_LAT"
# Track the best model by decode t/s
if awk "BEGIN {exit !($M_AVG_TPS > $BEST_LMS_TPS)}" 2>/dev/null; then
BEST_LMS_MODEL="$LMS_MODEL"
BEST_LMS_TPS="$M_AVG_TPS"
BEST_LMS_LAT="$M_AVG_LAT"
BEST_LMS_TTFT="$M_AVG_TTFT"
fi
LMS_ALL_JSON="$LMS_ALL_JSON
\"$(echo "$LMS_MODEL" | sed 's/[^a-zA-Z0-9._-]/_/g')\": {
\"model\": \"$LMS_MODEL\",
\"avg_latency_ms\": $M_AVG_LAT,
\"avg_tps\": $M_AVG_TPS,
\"avg_ttft_ms\": $M_AVG_TTFT,
\"results\": [${LMS_JSON_ENTRIES%,}]
},"
done
printf '%.0s─' {1..55}; echo ""
# Averages (awk, no python)
LMS_AVG_LAT=$(shell_avg_int "${LMS_LATENCIES[@]}")
LMS_AVG_TPS=$(shell_avg "${LMS_TPS_ARR[@]}")
LMS_AVG_TTFT=$(shell_avg_int "${LMS_TTFT_ARR[@]}")
printf "%-10s %5s %5s %10s %10s %10s\n" "Average" "" "" "$LMS_AVG_TPS" "$LMS_AVG_TTFT" "$LMS_AVG_LAT"
echo ""
# Side-by-side comparison
info "=== Side-by-Side Comparison ==="
dim "(Round-trip = wall-clock from client, apples-to-apples)"
echo ""
printf "%-24s %15s %15s\n" "" "ANE (qwen_ane)" "LM Studio"
printf '%.0s─' {1..56}; echo ""
printf "%-24s %12s t/s %12s t/s\n" "Decode speed" "$AVG_D" "$LMS_AVG_TPS"
printf "%-24s %12s t/s %12s\n" "Prefill speed" "$AVG_P" "N/A"
printf "%-24s %12s ms %12s ms\n" "TTFT" "$AVG_TTFT" "$LMS_AVG_TTFT"
printf "%-24s %12s ms %12s ms\n" "Avg round-trip" "$AVG_RT" "$LMS_AVG_LAT"
printf "%-24s %12s ms %12s ms\n" " (server-only)" "$AVG_INF" "N/A"
printf "%-24s %12s ms %12s\n" "Cold start" "$COLD_MS" "N/A"
printf "%-24s %15s %15s\n" "Precision" "F32 (from BF16)" "GGUF quantized"
printf "%-24s %15s %15s\n" "Accelerator" "Neural Engine" "CPU/GPU"
printf "%-24s %15s %15s\n" "Timing method" "Wall-clock" "Wall-clock"
# --- Final Comparison Table: all ANE formats + all LM Studio models ---
info "=== Multi-Format Comparison ==="
dim "(All times are wall-clock round-trip, apples-to-apples)"
echo ""
# Append LM Studio block to JSON results (pure shell, no python)
# Remove trailing "}" and newline, append lm_studio object
# Collect all column names and data
declare -a COL_NAMES=() COL_DECODE=() COL_PREFILL=() COL_TTFT=() COL_RT=() COL_PREC=() COL_ACCEL=()
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do
COL_NAMES+=("ANE ${ANE_FMT_NAMES[$fi2]}")
COL_DECODE+=("${ALL_AVG_D[$fi2]}")
COL_PREFILL+=("${ALL_AVG_P[$fi2]}")
COL_TTFT+=("${ALL_AVG_TTFT[$fi2]}")
COL_RT+=("${ALL_AVG_RT[$fi2]}")
COL_PREC+=("${ANE_FMT_LABELS[$fi2]}")
if [ -n "${ANE_FMT_GPU[$fi2]}" ]; then
COL_ACCEL+=("Metal GPU")
else
COL_ACCEL+=("CPU (AMX)")
fi
done
# Add each tested LM Studio model as a column
declare -a LMS_TESTED_NAMES=() LMS_TESTED_TPS=() LMS_TESTED_TTFT=() LMS_TESTED_LAT=()
for LMS_MODEL in "${LMS_MODEL_LIST[@]}"; do
# Check if this model was actually tested (has data in LMS_ALL_JSON)
SAFE_KEY=$(echo "$LMS_MODEL" | sed 's/[^a-zA-Z0-9._-]/_/g')
if echo "$LMS_ALL_JSON" | grep -q "\"$SAFE_KEY\""; then
M_TPS=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_tps":[[:space:]]*\([0-9.]*\).*/\1/p' | head -1)
M_TTFT=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_ttft_ms":[[:space:]]*\([0-9]*\).*/\1/p' | head -1)
M_LAT=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_latency_ms":[[:space:]]*\([0-9]*\).*/\1/p' | head -1)
SHORT_NAME=$(echo "$LMS_MODEL" | sed 's/qwen2.5-0.5b-instruct/q0.5b/; s/-mlx/mlx/')
COL_NAMES+=("LMS $SHORT_NAME")
COL_DECODE+=("${M_TPS:-0}")
COL_PREFILL+=("N/A")
COL_TTFT+=("${M_TTFT:-0}")
COL_RT+=("${M_LAT:-0}")
PREC_TAG="GGUF"
echo "$LMS_MODEL" | grep -q "8bit" && PREC_TAG="MLX 8-bit"
echo "$LMS_MODEL" | grep -q "4bit" && PREC_TAG="MLX 4-bit"
COL_PREC+=("$PREC_TAG")
COL_ACCEL+=("CPU/GPU")
LMS_TESTED_NAMES+=("$LMS_MODEL")
LMS_TESTED_TPS+=("${M_TPS:-0}")
LMS_TESTED_TTFT+=("${M_TTFT:-0}")
LMS_TESTED_LAT+=("${M_LAT:-0}")
fi
done
NUM_COLS=${#COL_NAMES[@]}
COL_W=16
# Print header row
printf "%-20s" ""
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_NAMES[$c]}"; done
echo ""
printf '%.0s─' $(seq 1 $((20 + NUM_COLS * COL_W))); echo ""
# Data rows
printf "%-20s" "Decode (t/s)"
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_DECODE[$c]}"; done
echo ""
printf "%-20s" "Prefill (t/s)"
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_PREFILL[$c]}"; done
echo ""
printf "%-20s" "TTFT (ms)"
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_TTFT[$c]}"; done
echo ""
printf "%-20s" "Round-trip (ms)"
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_RT[$c]}"; done
echo ""
printf "%-20s" "Cold start (ms)"
printf "%${COL_W}s" "$COLD_MS"
for c in $(seq 1 $((NUM_COLS - 1))); do printf "%${COL_W}s" "N/A"; done
echo ""
printf '%.0s─' $(seq 1 $((20 + NUM_COLS * COL_W))); echo ""
printf "%-20s" "Precision"
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_PREC[$c]}"; done
echo ""
printf "%-20s" "Accelerator"
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_ACCEL[$c]}"; done
echo ""
printf "%-20s" "Timing"
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "Wall-clock"; done
echo ""
echo ""
# Append LM Studio results to JSON
LMS_JSON_BLOCK=",
\"lm_studio\": {
\"port\": $LMS_PORT,
\"model\": \"$LMS_MODEL\",
\"avg_latency_ms\": $LMS_AVG_LAT,
\"avg_tps\": $LMS_AVG_TPS,
\"avg_ttft_ms\": $LMS_AVG_TTFT,
\"results\": [${LMS_JSON_ENTRIES%,}]
\"models_tested\": [$(printf '"%s",' "${LMS_MODEL_LIST[@]}" | sed 's/,$//')],$( echo "$LMS_ALL_JSON" | sed '$ s/,$//' )
}
}"
# Replace the final "}" with the LM Studio block
sed -i '' '$ s/}$//' "$RESULTS_JSON"
printf '%s\n' "$LMS_JSON_BLOCK" >> "$RESULTS_JSON"
dim "LM Studio results added to $RESULTS_JSON"
else
# No LM Studio -- print ANE-only comparison if we have multiple formats
if [ "$NUM_ANE_FMTS" -gt 1 ]; then
info "=== ANE Format Comparison ==="
echo ""
printf "%-20s" ""
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "ANE ${ANE_FMT_NAMES[$fi2]}"; done
echo ""
printf '%.0s─' $(seq 1 $((20 + NUM_ANE_FMTS * 16))); echo ""
printf "%-20s" "Decode (t/s)"
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_D[$fi2]}"; done
echo ""
printf "%-20s" "Prefill (t/s)"
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_P[$fi2]}"; done
echo ""
printf "%-20s" "TTFT (ms)"
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_TTFT[$fi2]}"; done
echo ""
printf "%-20s" "Round-trip (ms)"
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_RT[$fi2]}"; done
echo ""
printf '%.0s─' $(seq 1 $((20 + NUM_ANE_FMTS * 16))); echo ""
echo ""
fi
info "=== LM Studio Comparison ==="
echo ""
if [ "$LMS_REACHABLE" -eq 0 ]; then
echo " LM Studio server not detected on localhost:$LMS_PORT"
echo ""
echo " To enable automatic comparison:"
echo " 1. Open LM Studio, download Qwen2.5-0.5B-Instruct (GGUF)"
echo " 1. Open LM Studio, download Qwen2.5-0.5B-Instruct (GGUF + MLX variants)"
echo " 2. Load the model, go to Developer tab > Start Server"
echo " 3. Re-run this benchmark"
echo ""
echo " Or set env vars: LMS_PORT=1234 LMS_API_KEY=your-key ./benchmark.sh"
echo ""
echo " Models benchmarked by default:"
echo " - qwen2.5-0.5b-instruct (GGUF)"
echo " - qwen2.5-0.5b-instruct-mlx@8bit (MLX 8-bit)"
echo " - qwen2.5-0.5b-instruct-mlx@4bit (MLX 4-bit)"
echo ""
echo " Override with: LMS_MODELS='model1,model2' ./benchmark.sh"
fi
echo ""
echo " Manual test:"
@ -385,9 +633,9 @@ else
echo " -H 'Authorization: Bearer YOUR_API_KEY' \\"
echo " -d '{\"model\":\"qwen2.5-0.5b-instruct\",\"system_prompt\":\"You are a helpful assistant.\",\"input\":\"What is 2+2?\"}'"
echo ""
echo " ANE (this benchmark): prefill=${AVG_P} t/s, decode=${AVG_D} t/s, inference=${AVG_INF}ms"
echo " ANE F16: prefill=${AVG_P} t/s, decode=${AVG_D} t/s, inference=${AVG_INF}ms"
echo ""
echo " Note: LM Studio uses quantized GGUF (CPU/GPU) while we use"
echo " BF16 weights (full precision) running on the Neural Engine."
echo " Note: LM Studio uses quantized GGUF/MLX (CPU/GPU) while we use"
echo " F16/Q8 weights running on CPU AMX / NEON."
fi
echo ""

View File

@ -1,11 +1,17 @@
#!/usr/bin/env python3
"""Convert Qwen2.5-0.5B-Instruct safetensors → flat binary for ANE inference.
Output format: config header (7 ints) + all weights in f32, layer by layer.
Matches the layout expected by qwen_ane_infer.h.
Output format (F32): config header (8 ints) + all weights in f32
Output format (F16): config header (8 ints) + embeddings f32 + projection weights f16
Output format (Q8): config header (8 ints) + embeddings f32 + projection weights q8_0
Output format (Q4): config header (8 ints) + embeddings f32 + projection weights q4_0
The 8th config int is the format flag: 0 = F32, 1 = F16, 2 = Q8, 3 = Q4.
Q8_0 format: blocks of 32 values, each block = 1 f16 scale + 32 int8 values (34 bytes).
Q4_0 format: blocks of 32 values, each block = 1 f16 scale + 1 f16 zero + 16 uint8 packed pairs (20 bytes).
Usage:
python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct /path/to/output.bin
python3 convert_weights.py <model_dir> <output.bin> [--f16|--q8|--q4]
"""
import struct
@ -14,10 +20,74 @@ import numpy as np
from pathlib import Path
from safetensors import safe_open
def convert(model_dir: str, output_path: str):
Q8_BLOCK_SIZE = 32
Q4_BLOCK_SIZE = 32
def quantize_q4_0(weights_f32):
"""Quantize a 2D weight matrix to Q4_0 block format.
Returns bytes: for each row, blocks of (f16_scale + f16_zero + 16*uint8 packed pairs).
Each uint8 stores two 4-bit values: low nibble = even index, high nibble = odd index."""
out_dim, in_dim = weights_f32.shape
assert in_dim % Q4_BLOCK_SIZE == 0, f"in_dim {in_dim} not divisible by {Q4_BLOCK_SIZE}"
n_blocks_per_row = in_dim // Q4_BLOCK_SIZE
result = bytearray()
for r in range(out_dim):
row = weights_f32[r]
for b in range(n_blocks_per_row):
block = row[b * Q4_BLOCK_SIZE : (b + 1) * Q4_BLOCK_SIZE]
bmin = np.min(block)
bmax = np.max(block)
if bmax == bmin:
scale = np.float16(0.0)
zero = np.float16(0.0)
packed = bytes(Q4_BLOCK_SIZE // 2)
else:
scale_f = (bmax - bmin) / 15.0
zero_f = bmin
scale = np.float16(scale_f)
zero = np.float16(zero_f)
scale_f = float(scale) if float(scale) != 0.0 else 1e-10
quant = np.clip(np.round((block - float(zero)) / scale_f), 0, 15).astype(np.uint8)
packed = bytearray(Q4_BLOCK_SIZE // 2)
for i in range(0, Q4_BLOCK_SIZE, 2):
packed[i // 2] = quant[i] | (quant[i + 1] << 4)
result += scale.tobytes()
result += zero.tobytes()
result += bytes(packed)
return bytes(result)
def quantize_q8_0(weights_f32):
"""Quantize a 2D weight matrix to Q8_0 block format.
Returns bytes: for each row, blocks of (f16_scale + 32*int8)."""
out_dim, in_dim = weights_f32.shape
assert in_dim % Q8_BLOCK_SIZE == 0, f"in_dim {in_dim} not divisible by {Q8_BLOCK_SIZE}"
n_blocks_per_row = in_dim // Q8_BLOCK_SIZE
result = bytearray()
for r in range(out_dim):
row = weights_f32[r]
for b in range(n_blocks_per_row):
block = row[b * Q8_BLOCK_SIZE : (b + 1) * Q8_BLOCK_SIZE]
amax = np.max(np.abs(block))
scale = amax / 127.0 if amax > 0 else 0.0
if scale > 0:
quant = np.round(block / scale).astype(np.int8)
else:
quant = np.zeros(Q8_BLOCK_SIZE, dtype=np.int8)
result += np.float16(scale).tobytes()
result += quant.tobytes()
return bytes(result)
def convert(model_dir: str, output_path: str, fmt: str = "f32"):
model_dir = Path(model_dir)
# Load safetensors
st_files = list(model_dir.glob("*.safetensors"))
if not st_files:
print(f"No safetensors files in {model_dir}")
@ -30,8 +100,8 @@ def convert(model_dir: str, output_path: str):
tensors[key] = sf.get_tensor(key).float().numpy()
print(f"Loaded {len(tensors)} tensors from {len(st_files)} files")
print(f"Mode: {fmt.upper()} projections (embeddings + norms + biases stay F32)")
# Qwen2.5-0.5B config
dim = 896
hidden = 4864
n_layers = 24
@ -39,37 +109,41 @@ def convert(model_dir: str, output_path: str):
n_kv_heads = 2
vocab_size = 151936
max_seq = 512
fmt_flag = {"f32": 0, "f16": 1, "q8": 2, "q4": 3}[fmt]
def write_proj(f_out, tensor_f32):
if fmt == "q4":
f_out.write(quantize_q4_0(tensor_f32))
elif fmt == "q8":
f_out.write(quantize_q8_0(tensor_f32))
elif fmt == "f16":
f_out.write(tensor_f32.astype(np.float16).tobytes())
else:
f_out.write(tensor_f32.astype(np.float32).tobytes())
with open(output_path, "wb") as f:
# Config header: 7 x int32
f.write(struct.pack("iiiiiii",
dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq))
f.write(struct.pack("iiiiiiii",
dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq, fmt_flag))
# Embedding [vocab, dim]
emb = tensors["model.embed_tokens.weight"].astype(np.float32)
print(f"embed: {emb.shape}")
print(f"embed: {emb.shape} (f32)")
f.write(emb.tobytes())
# Per-layer weights
for l in range(n_layers):
prefix = f"model.layers.{l}"
# Attention norm
rms_att = tensors[f"{prefix}.input_layernorm.weight"].astype(np.float32)
f.write(rms_att.tobytes())
# Q, K, V projections
wq = tensors[f"{prefix}.self_attn.q_proj.weight"].astype(np.float32)
wk = tensors[f"{prefix}.self_attn.k_proj.weight"].astype(np.float32)
wv = tensors[f"{prefix}.self_attn.v_proj.weight"].astype(np.float32)
wo = tensors[f"{prefix}.self_attn.o_proj.weight"].astype(np.float32)
f.write(wq.tobytes())
f.write(wk.tobytes())
f.write(wv.tobytes())
f.write(wo.tobytes())
write_proj(f, wq)
write_proj(f, wk)
write_proj(f, wv)
write_proj(f, wo)
# Q/K biases (Qwen has them)
# Q/K/V biases
qb = tensors.get(f"{prefix}.self_attn.q_proj.bias")
kb = tensors.get(f"{prefix}.self_attn.k_proj.bias")
vb = tensors.get(f"{prefix}.self_attn.v_proj.bias")
@ -77,22 +151,19 @@ def convert(model_dir: str, output_path: str):
f.write((kb if kb is not None else np.zeros(wk.shape[0])).astype(np.float32).tobytes())
f.write((vb if vb is not None else np.zeros(wv.shape[0])).astype(np.float32).tobytes())
# FFN norm
rms_ffn = tensors[f"{prefix}.post_attention_layernorm.weight"].astype(np.float32)
f.write(rms_ffn.tobytes())
# FFN: gate, up, down
w_gate = tensors[f"{prefix}.mlp.gate_proj.weight"].astype(np.float32)
w_up = tensors[f"{prefix}.mlp.up_proj.weight"].astype(np.float32)
w_down = tensors[f"{prefix}.mlp.down_proj.weight"].astype(np.float32)
f.write(w_gate.tobytes())
f.write(w_up.tobytes())
f.write(w_down.tobytes())
write_proj(f, w_gate)
write_proj(f, w_up)
write_proj(f, w_down)
print(f" Layer {l}: Q{wq.shape} K{wk.shape} V{wv.shape} O{wo.shape} "
f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape}")
f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape} [{fmt}]")
# Final norm
rms_final = tensors["model.norm.weight"].astype(np.float32)
f.write(rms_final.tobytes())
@ -101,7 +172,14 @@ def convert(model_dir: str, output_path: str):
if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: python3 convert_weights.py <model_dir> <output.bin>")
if len(sys.argv) < 3:
print("Usage: python3 convert_weights.py <model_dir> <output.bin> [--f16|--q8|--q4]")
sys.exit(1)
convert(sys.argv[1], sys.argv[2])
fmt = "f32"
if "--f16" in sys.argv:
fmt = "f16"
elif "--q8" in sys.argv:
fmt = "q8"
elif "--q4" in sys.argv:
fmt = "q4"
convert(sys.argv[1], sys.argv[2], fmt)

View File

@ -6,9 +6,10 @@
// 4. HTTP API: ./qwen_ane weights.bin --http 8000 --model-dir ~/models/Qwen2.5-0.5B-Instruct
//
// Build:
// xcrun clang -O2 -framework Foundation -framework IOSurface \
// -framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
// -o qwen_ane main.m
// xcrun clang -O3 -ffast-math -mcpu=apple-m4 -flto \
// -framework Foundation -framework IOSurface \
// -framework CoreML -framework Accelerate -framework Metal \
// -ldl -lobjc -fobjc-arc -o qwen_ane main.m
//
#import <Foundation/Foundation.h>
#include <stdio.h>
@ -39,36 +40,112 @@ static void handle_signal(int sig) {
_exit(0);
}
static void *safe_malloc(size_t size, const char *desc) {
void *p = malloc(size);
if (!p) {
fprintf(stderr, "FATAL: malloc failed for %s (%.1f MB)\n",
desc, (double)size / (1024*1024));
exit(1);
}
return p;
}
static void *safe_calloc(size_t count, size_t size, const char *desc) {
void *p = calloc(count, size);
if (!p) {
fprintf(stderr, "FATAL: calloc failed for %s (%.1f MB)\n",
desc, (double)(count * size) / (1024*1024));
exit(1);
}
return p;
}
static int load_weights(const char *path) {
FILE *f = fopen(path, "rb");
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
int config[7];
fread(config, sizeof(int), 7, f);
// Try 8-int header first (new format), fall back to 7-int (legacy)
int config[8] = {0};
size_t hdr_read = fread(config, sizeof(int), 8, f);
int dim = config[0], hidden = config[1], n_layers = config[2];
int n_heads = config[3], n_kv_heads = config[4], vocab = config[5];
printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d\n",
dim, hidden, n_layers, n_heads, n_kv_heads, vocab);
int fmt_flag = 0;
if (hdr_read == 8 && config[7] >= 0 && config[7] <= 3) {
fmt_flag = config[7];
} else {
fseek(f, 7 * sizeof(int), SEEK_SET);
}
g_model.weight_fmt = fmt_flag;
int is_f16 = (fmt_flag == 1);
int is_q8 = (fmt_flag == 2);
int is_q4 = (fmt_flag == 3);
const char *fmt_str = is_q4 ? "Q4" : (is_q8 ? "Q8" : (is_f16 ? "F16" : "F32"));
printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d fmt=%s\n",
dim, hidden, n_layers, n_heads, n_kv_heads, vocab, fmt_str);
int q_dim = n_heads * QWEN_HEAD_DIM;
int kv_dim = n_kv_heads * QWEN_HEAD_DIM;
g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float));
// Embeddings always F32
g_model.embed = (float*)safe_malloc((size_t)vocab * dim * sizeof(float), "embed");
fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f);
for (int l = 0; l < n_layers; l++) {
// RMSNorm always F32
g_model.rms_att[l] = (float*)malloc(dim * sizeof(float));
fread(g_model.rms_att[l], sizeof(float), dim, f);
g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f);
g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
if (is_q4) {
#define LOAD_Q4(q8ptr, out_d, in_d) do { \
size_t _nb = (size_t)(in_d) / Q4_BLOCK_SIZE; \
size_t _bytes = (size_t)(out_d) * _nb * Q4_BLOCK_BYTES; \
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
fread(q8ptr, 1, _bytes, f); \
} while(0)
LOAD_Q4(g_model.wq_q8[l], q_dim, dim);
LOAD_Q4(g_model.wk_q8[l], kv_dim, dim);
LOAD_Q4(g_model.wv_q8[l], kv_dim, dim);
LOAD_Q4(g_model.wo_q8[l], dim, q_dim);
#undef LOAD_Q4
} else if (is_q8) {
#define LOAD_Q8(q8ptr, out_d, in_d) do { \
size_t _nb = (size_t)(in_d) / Q8_BLOCK_SIZE; \
size_t _bytes = (size_t)(out_d) * _nb * Q8_BLOCK_BYTES; \
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
fread(q8ptr, 1, _bytes, f); \
} while(0)
LOAD_Q8(g_model.wq_q8[l], q_dim, dim);
LOAD_Q8(g_model.wk_q8[l], kv_dim, dim);
LOAD_Q8(g_model.wv_q8[l], kv_dim, dim);
LOAD_Q8(g_model.wo_q8[l], dim, q_dim);
#undef LOAD_Q8
} else if (is_f16) {
#define LOAD_F16_AS_F32(f32ptr, f16ptr, n) do { \
size_t _n = (size_t)(n); \
f16ptr = (_Float16*)malloc(_n * sizeof(_Float16)); \
fread(f16ptr, sizeof(_Float16), _n, f); \
f32ptr = (float*)malloc(_n * sizeof(float)); \
convert_f16_to_f32(f16ptr, f32ptr, _n); \
} while(0)
LOAD_F16_AS_F32(g_model.wq[l], g_model.wq_f16[l], (size_t)q_dim * dim);
LOAD_F16_AS_F32(g_model.wk[l], g_model.wk_f16[l], (size_t)kv_dim * dim);
LOAD_F16_AS_F32(g_model.wv[l], g_model.wv_f16[l], (size_t)kv_dim * dim);
LOAD_F16_AS_F32(g_model.wo[l], g_model.wo_f16[l], (size_t)dim * q_dim);
#undef LOAD_F16_AS_F32
} else {
g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f);
g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
}
// Biases always F32
g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float));
g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float));
g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float));
@ -76,15 +153,52 @@ static int load_weights(const char *path) {
fread(g_model.k_bias[l], sizeof(float), kv_dim, f);
fread(g_model.v_bias[l], sizeof(float), kv_dim, f);
// FFN RMSNorm always F32
g_model.rms_ffn[l] = (float*)malloc(dim * sizeof(float));
fread(g_model.rms_ffn[l], sizeof(float), dim, f);
g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f);
g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f);
g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float));
fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f);
if (is_q4) {
#define LOAD_Q4(q8ptr, out_d, in_d) do { \
size_t _nb = (size_t)(in_d) / Q4_BLOCK_SIZE; \
size_t _bytes = (size_t)(out_d) * _nb * Q4_BLOCK_BYTES; \
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
fread(q8ptr, 1, _bytes, f); \
} while(0)
LOAD_Q4(g_model.wgate_q8[l], hidden, dim);
LOAD_Q4(g_model.wup_q8[l], hidden, dim);
LOAD_Q4(g_model.wdown_q8[l], dim, hidden);
#undef LOAD_Q4
} else if (is_q8) {
#define LOAD_Q8(q8ptr, out_d, in_d) do { \
size_t _nb = (size_t)(in_d) / Q8_BLOCK_SIZE; \
size_t _bytes = (size_t)(out_d) * _nb * Q8_BLOCK_BYTES; \
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
fread(q8ptr, 1, _bytes, f); \
} while(0)
LOAD_Q8(g_model.wgate_q8[l], hidden, dim);
LOAD_Q8(g_model.wup_q8[l], hidden, dim);
LOAD_Q8(g_model.wdown_q8[l], dim, hidden);
#undef LOAD_Q8
} else if (is_f16) {
#define LOAD_F16_AS_F32(f32ptr, f16ptr, n) do { \
size_t _n = (size_t)(n); \
f16ptr = (_Float16*)malloc(_n * sizeof(_Float16)); \
fread(f16ptr, sizeof(_Float16), _n, f); \
f32ptr = (float*)malloc(_n * sizeof(float)); \
convert_f16_to_f32(f16ptr, f32ptr, _n); \
} while(0)
LOAD_F16_AS_F32(g_model.w_gate[l], g_model.wgate_f16[l], (size_t)hidden * dim);
LOAD_F16_AS_F32(g_model.w_up[l], g_model.wup_f16[l], (size_t)hidden * dim);
LOAD_F16_AS_F32(g_model.w_down[l], g_model.wdown_f16[l], (size_t)dim * hidden);
#undef LOAD_F16_AS_F32
} else {
g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f);
g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f);
g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float));
fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f);
}
}
g_model.rms_final = (float*)malloc(dim * sizeof(float));
@ -92,7 +206,8 @@ static int load_weights(const char *path) {
long file_size = ftell(f);
fclose(f);
printf("Weights loaded (%.0f MB)\n", (float)file_size / 1024 / 1024);
printf("Weights loaded (%.0f MB, %s projections)\n",
(float)file_size / 1024 / 1024, fmt_str);
return 0;
}
@ -115,16 +230,25 @@ static double timespec_diff(struct timespec *a, struct timespec *b) {
}
// Run one generation pass. Writes output token IDs to out_ids, returns count.
// If out_fd >= 0, writes formatted results there; otherwise prints to stdout.
// Uses batched prefill (sgemm) for prompt, sequential decode (sgemv) for generation.
static int generate(int *prompt_ids, int n_prompt, int max_gen,
int *out_ids, int max_out,
double *prefill_tps, double *decode_tps) {
struct timespec t0, t1, t_pre;
clock_gettime(CLOCK_MONOTONIC, &t0);
int next = 0;
for (int i = 0; i < n_prompt; i++)
next = qwen_forward(&g_model, prompt_ids[i]);
int next;
if (g_model.use_ane) {
for (int i = 0; i < n_prompt; i++)
next = qwen_forward_ane(&g_model, prompt_ids[i]);
} else if (n_prompt > 1 && g_model.weight_fmt == 3) {
next = qwen_prefill_q4(&g_model, prompt_ids, n_prompt);
} else if (n_prompt > 1 && g_model.weight_fmt != 2) {
next = qwen_prefill(&g_model, prompt_ids, n_prompt);
} else {
for (int i = 0; i < n_prompt; i++)
next = qwen_forward(&g_model, prompt_ids[i]);
}
clock_gettime(CLOCK_MONOTONIC, &t_pre);
double ps = timespec_diff(&t0, &t_pre);
@ -135,7 +259,10 @@ static int generate(int *prompt_ids, int n_prompt, int max_gen,
for (int i = 0; i < max_gen && n_out < max_out; i++) {
if (n_out < max_out) out_ids[n_out++] = next;
if (next == eos || next == eos2) break;
next = qwen_forward(&g_model, next);
if (g_model.use_ane)
next = qwen_forward_ane(&g_model, next);
else
next = qwen_forward(&g_model, next);
}
clock_gettime(CLOCK_MONOTONIC, &t1);
@ -427,6 +554,7 @@ int main(int argc, char **argv) {
int server_mode = 0;
int http_port = 0;
int test_ane = 0;
int use_ane = 0;
const char *sock_path = NULL;
const char *model_dir = NULL;
for (int i = 2; i < argc; i++) {
@ -442,6 +570,61 @@ int main(int argc, char **argv) {
else { fprintf(stderr, "--model-dir requires a path\n"); return 1; }
} else if (strcmp(argv[i], "--test-ane") == 0) {
test_ane = 1;
} else if (strcmp(argv[i], "--ane") == 0) {
use_ane = 1;
}
}
// Q4 CPU mode: dequantize Q4 to F32 at load time, use AMX cblas_sgemv
if (g_model.weight_fmt == 3) {
printf("Dequantizing Q4→F32 for AMX acceleration...\n");
int q_dim = QWEN_Q_DIM, kv_dim = QWEN_KV_DIM, dim = QWEN_DIM;
int hidden = QWEN_HIDDEN;
#define DEQUANT_Q4_TO_F32(f32ptr, q4ptr, out_d, in_d) do { \
size_t _n = (size_t)(out_d) * (in_d); \
f32ptr = (float*)malloc(_n * sizeof(float)); \
dequant_q4_to_f32(q4ptr, f32ptr, (in_d), (out_d)); \
free(q4ptr); q4ptr = NULL; \
} while(0)
for (int l = 0; l < QWEN_LAYERS; l++) {
DEQUANT_Q4_TO_F32(g_model.wq[l], g_model.wq_q8[l], q_dim, dim);
DEQUANT_Q4_TO_F32(g_model.wk[l], g_model.wk_q8[l], kv_dim, dim);
DEQUANT_Q4_TO_F32(g_model.wv[l], g_model.wv_q8[l], kv_dim, dim);
DEQUANT_Q4_TO_F32(g_model.wo[l], g_model.wo_q8[l], dim, q_dim);
DEQUANT_Q4_TO_F32(g_model.w_gate[l], g_model.wgate_q8[l], hidden, dim);
DEQUANT_Q4_TO_F32(g_model.w_up[l], g_model.wup_q8[l], hidden, dim);
DEQUANT_Q4_TO_F32(g_model.w_down[l], g_model.wdown_q8[l], dim, hidden);
}
#undef DEQUANT_Q4_TO_F32
g_model.weight_fmt = 0;
printf("Q4→F32 done. Using AMX cblas_sgemv (91+ t/s decode).\n");
}
// ANE fused kernel compilation (requires F32 weights for baked-weight convs)
if (use_ane) {
if (g_model.weight_fmt != 0) {
printf("--ane requires F32 weights (weight_fmt=0). Got fmt=%d\n", g_model.weight_fmt);
printf("Re-run with F32 weight file (convert_weights.py without --f16/--q4/--q8)\n");
use_ane = 0;
} else {
struct timespec ta0, ta1;
clock_gettime(CLOCK_MONOTONIC, &ta0);
qwen_compile_kernels_fused(&g_model);
clock_gettime(CLOCK_MONOTONIC, &ta1);
double ane_sec = timespec_diff(&ta0, &ta1);
printf("ANE fused compile time: %.1fs\n", ane_sec);
// Verify at least one QKV kernel compiled
if (g_model.k_qkv[0] && g_model.k_o[0] && g_model.k_ffn_up[0] && g_model.k_down[0]) {
g_model.use_ane = 1;
printf("ANE fused mode active: 112 kernels (QKV+FFN_up fused)\n");
} else {
printf("ANE fused compilation failed, falling back to CPU\n");
use_ane = 0;
}
}
}

921
inference/matmul.metal Normal file
View File

@ -0,0 +1,921 @@
#include <metal_stdlib>
using namespace metal;
// ── Q4_0 block format ────────────────────────────────────────────────
// Block of 32 values: 2 bytes F16 scale + 2 bytes F16 zero + 16 bytes packed uint8
// Each uint8 stores 2 values: low nibble = even index, high nibble = odd index
// Total: 20 bytes per block of 32 weights
#define Q4_BLOCK_SIZE 32
#define Q4_BLOCK_BYTES 20
// ── Q4 Matrix-vector multiply (legacy, 1 thread per row) ────────────
// Kept as fallback for edge cases.
kernel void sgemv_q4(
device const uint8_t *W [[buffer(0)]],
device const float *x [[buffer(1)]],
device float *y [[buffer(2)]],
constant uint &in_dim [[buffer(3)]],
constant uint &out_dim [[buffer(4)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= out_dim) return;
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
device const uint8_t *row = W + uint64_t(gid) * row_bytes;
float sum = 0.0f;
for (uint b = 0; b < n_blocks; b++) {
device const uint8_t *block = row + b * Q4_BLOCK_BYTES;
half scale_h, zero_h;
scale_h = *reinterpret_cast<device const half*>(block);
zero_h = *reinterpret_cast<device const half*>(block + 2);
float scale = float(scale_h);
float zero = float(zero_h);
device const uint8_t *packed = block + 4;
uint base = b * Q4_BLOCK_SIZE;
for (uint i = 0; i < 16; i++) {
uint8_t byte = packed[i];
float w0 = float(byte & 0xF) * scale + zero;
float w1 = float(byte >> 4) * scale + zero;
sum += w0 * x[base + i * 2];
sum += w1 * x[base + i * 2 + 1];
}
}
y[gid] = sum;
}
// ── Q4 SIMD-optimized matrix-vector multiply ─────────────────────────
// MLX-style cooperative SIMD kernel: 2 SIMD groups per threadgroup,
// each SIMD group handles ROWS_PER_SIMD output rows cooperatively.
// 32 threads in a SIMD group split the K (input) dimension, then
// reduce via simd_sum(). No threadgroup memory needed.
//
// Threadgroup layout: 64 threads = 2 SIMD groups of 32
// Grid: (ceil(out_dim / ROWS_PER_TG), 1, 1) threadgroups
//
// Optional bias: if bias pointer is non-null (use_bias != 0),
// y[r] = dot(W[r], x) + bias[r]
#define ROWS_PER_SIMD 4
#define SIMD_GROUPS 2
#define ROWS_PER_TG (ROWS_PER_SIMD * SIMD_GROUPS)
kernel void sgemv_q4_fast(
device const uint8_t *W [[buffer(0)]],
device const float *x [[buffer(1)]],
device float *y [[buffer(2)]],
constant uint &in_dim [[buffer(3)]],
constant uint &out_dim [[buffer(4)]],
device const float *bias [[buffer(5)]],
constant uint &use_bias [[buffer(6)]],
uint tgid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]])
{
uint base_row = tgid * ROWS_PER_TG + simd_gid * ROWS_PER_SIMD;
if (base_row >= out_dim) return;
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
uint rows_this = min((uint)ROWS_PER_SIMD, out_dim - base_row);
float accum[ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
float zero_accum[ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
// Each of 32 SIMD lanes processes a stripe of blocks.
// Lane i processes blocks i, i+32, i+64, ...
for (uint b = simd_lid; b < n_blocks; b += 32) {
uint k_base = b * Q4_BLOCK_SIZE;
// Load input vector segment for this block (32 floats)
float xv[Q4_BLOCK_SIZE];
for (uint j = 0; j < 16; j++) {
xv[j * 2] = x[k_base + j * 2];
xv[j * 2 + 1] = x[k_base + j * 2 + 1];
}
for (uint r = 0; r < rows_this; r++) {
device const uint8_t *block =
W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
half scale_h = *reinterpret_cast<device const half*>(block);
half zero_h = *reinterpret_cast<device const half*>(block + 2);
float scale = float(scale_h);
float zero = float(zero_h);
device const uint8_t *packed = block + 4;
float dot = 0.0f;
float xsum = 0.0f;
for (uint j = 0; j < 16; j++) {
uint8_t byte = packed[j];
float w0 = float(byte & 0xF);
float w1 = float(byte >> 4);
dot += w0 * xv[j * 2] + w1 * xv[j * 2 + 1];
xsum += xv[j * 2] + xv[j * 2 + 1];
}
accum[r] += dot * scale;
zero_accum[r] += xsum * zero;
}
}
// SIMD reduction across 32 lanes
for (uint r = 0; r < rows_this; r++) {
float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]);
if (simd_lid == 0) {
if (use_bias != 0) {
result += bias[base_row + r];
}
y[base_row + r] = result;
}
}
}
// ── Fused Gate+Up+SiLU: reads x once, computes gate=silu(Wg*x)*Wu*x ──
// Combines two Q4 matvecs + silu_mul into one kernel.
// W_gate and W_up have the same dimensions [out_dim, in_dim].
// Output: gate[r] = silu(dot(W_gate[r], x)) * dot(W_up[r], x)
#define FUSED_ROWS_PER_SIMD 2
#define FUSED_SIMD_GROUPS 2
#define FUSED_ROWS_PER_TG (FUSED_ROWS_PER_SIMD * FUSED_SIMD_GROUPS)
kernel void sgemv_q4_fused_ffn(
device const uint8_t *W_gate [[buffer(0)]],
device const uint8_t *W_up [[buffer(1)]],
device const float *x [[buffer(2)]],
device float *out [[buffer(3)]],
constant uint &in_dim [[buffer(4)]],
constant uint &out_dim [[buffer(5)]],
uint tgid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]])
{
uint base_row = tgid * FUSED_ROWS_PER_TG + simd_gid * FUSED_ROWS_PER_SIMD;
if (base_row >= out_dim) return;
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
uint rows_this = min((uint)FUSED_ROWS_PER_SIMD, out_dim - base_row);
float gate_acc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
float gate_zacc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
float up_acc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
float up_zacc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
for (uint b = simd_lid; b < n_blocks; b += 32) {
uint k_base = b * Q4_BLOCK_SIZE;
float xv[Q4_BLOCK_SIZE];
for (uint j = 0; j < 16; j++) {
xv[j * 2] = x[k_base + j * 2];
xv[j * 2 + 1] = x[k_base + j * 2 + 1];
}
for (uint r = 0; r < rows_this; r++) {
uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
// Gate weight block
device const uint8_t *g_block = W_gate + row_off;
float g_scale = float(*reinterpret_cast<device const half*>(g_block));
float g_zero = float(*reinterpret_cast<device const half*>(g_block + 2));
device const uint8_t *g_packed = g_block + 4;
// Up weight block
device const uint8_t *u_block = W_up + row_off;
float u_scale = float(*reinterpret_cast<device const half*>(u_block));
float u_zero = float(*reinterpret_cast<device const half*>(u_block + 2));
device const uint8_t *u_packed = u_block + 4;
float g_dot = 0.0f, g_xsum = 0.0f;
float u_dot = 0.0f, u_xsum = 0.0f;
for (uint j = 0; j < 16; j++) {
float x0 = xv[j * 2];
float x1 = xv[j * 2 + 1];
float xs = x0 + x1;
uint8_t gb = g_packed[j];
g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1;
g_xsum += xs;
uint8_t ub = u_packed[j];
u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1;
u_xsum += xs;
}
gate_acc[r] += g_dot * g_scale;
gate_zacc[r] += g_xsum * g_zero;
up_acc[r] += u_dot * u_scale;
up_zacc[r] += u_xsum * u_zero;
}
}
for (uint r = 0; r < rows_this; r++) {
float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]);
float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]);
if (simd_lid == 0) {
float s = g / (1.0f + exp(-g));
out[base_row + r] = s * u;
}
}
}
// ── Q4 batched matrix-matrix multiply (SGEMM) for prefill ────────────
// Y[t, r] = sum_k(dequant(W[r, k]) * X[t, k]) for t in [0, n_tokens), r in [0, out_dim)
// Grid: (ceil(out_dim / GEMM_TILE_M), n_tokens, 1)
// Each threadgroup: 2 SIMD groups, each handles GEMM_TILE_M/2 output rows for one token.
#define GEMM_TILE_M 8
#define GEMM_SIMD_GROUPS 2
#define GEMM_ROWS_PER_SIMD (GEMM_TILE_M / GEMM_SIMD_GROUPS)
kernel void sgemm_q4(
device const uint8_t *W [[buffer(0)]],
device const float *X [[buffer(1)]],
device float *Y [[buffer(2)]],
constant uint &in_dim [[buffer(3)]],
constant uint &out_dim [[buffer(4)]],
device const float *bias [[buffer(5)]],
constant uint &use_bias [[buffer(6)]],
constant uint &n_tokens [[buffer(7)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]])
{
uint base_row = tgid.x * GEMM_TILE_M + simd_gid * GEMM_ROWS_PER_SIMD;
uint t = tgid.y;
if (base_row >= out_dim || t >= n_tokens) return;
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
uint rows_this = min((uint)GEMM_ROWS_PER_SIMD, out_dim - base_row);
device const float *xt = X + uint64_t(t) * in_dim;
float accum[GEMM_ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
float zero_accum[GEMM_ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
for (uint b = simd_lid; b < n_blocks; b += 32) {
uint k_base = b * Q4_BLOCK_SIZE;
float xv[Q4_BLOCK_SIZE];
for (uint j = 0; j < 16; j++) {
xv[j * 2] = xt[k_base + j * 2];
xv[j * 2 + 1] = xt[k_base + j * 2 + 1];
}
for (uint r = 0; r < rows_this; r++) {
device const uint8_t *block =
W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
float scale = float(*reinterpret_cast<device const half*>(block));
float zero = float(*reinterpret_cast<device const half*>(block + 2));
device const uint8_t *packed = block + 4;
float dot = 0.0f;
float xsum = 0.0f;
for (uint j = 0; j < 16; j++) {
uint8_t byte = packed[j];
dot += float(byte & 0xF) * xv[j * 2] + float(byte >> 4) * xv[j * 2 + 1];
xsum += xv[j * 2] + xv[j * 2 + 1];
}
accum[r] += dot * scale;
zero_accum[r] += xsum * zero;
}
}
for (uint r = 0; r < rows_this; r++) {
float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]);
if (simd_lid == 0) {
if (use_bias != 0)
result += bias[base_row + r];
Y[uint64_t(t) * out_dim + base_row + r] = result;
}
}
}
// ── Q4 batched fused Gate+Up+SiLU (SGEMM variant) ───────────────────
// out[t, r] = silu(Wg[r] . X[t]) * Wu[r] . X[t] for all t and r
// Grid: (ceil(out_dim / GEMM_FFN_TILE_M), n_tokens, 1)
#define GEMM_FFN_TILE_M 4
#define GEMM_FFN_SIMD_GROUPS 2
#define GEMM_FFN_ROWS_PER_SIMD (GEMM_FFN_TILE_M / GEMM_FFN_SIMD_GROUPS)
kernel void sgemm_q4_fused_ffn(
device const uint8_t *W_gate [[buffer(0)]],
device const uint8_t *W_up [[buffer(1)]],
device const float *X [[buffer(2)]],
device float *out [[buffer(3)]],
constant uint &in_dim [[buffer(4)]],
constant uint &out_dim [[buffer(5)]],
constant uint &n_tokens [[buffer(6)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]])
{
uint base_row = tgid.x * GEMM_FFN_TILE_M + simd_gid * GEMM_FFN_ROWS_PER_SIMD;
uint t = tgid.y;
if (base_row >= out_dim || t >= n_tokens) return;
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
uint rows_this = min((uint)GEMM_FFN_ROWS_PER_SIMD, out_dim - base_row);
device const float *xt = X + uint64_t(t) * in_dim;
float gate_acc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
float gate_zacc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
float up_acc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
float up_zacc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
for (uint b = simd_lid; b < n_blocks; b += 32) {
uint k_base = b * Q4_BLOCK_SIZE;
float xv[Q4_BLOCK_SIZE];
for (uint j = 0; j < 16; j++) {
xv[j * 2] = xt[k_base + j * 2];
xv[j * 2 + 1] = xt[k_base + j * 2 + 1];
}
for (uint r = 0; r < rows_this; r++) {
uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
device const uint8_t *g_block = W_gate + row_off;
float g_scale = float(*reinterpret_cast<device const half*>(g_block));
float g_zero = float(*reinterpret_cast<device const half*>(g_block + 2));
device const uint8_t *g_packed = g_block + 4;
device const uint8_t *u_block = W_up + row_off;
float u_scale = float(*reinterpret_cast<device const half*>(u_block));
float u_zero = float(*reinterpret_cast<device const half*>(u_block + 2));
device const uint8_t *u_packed = u_block + 4;
float g_dot = 0.0f, g_xsum = 0.0f;
float u_dot = 0.0f, u_xsum = 0.0f;
for (uint j = 0; j < 16; j++) {
float x0 = xv[j * 2];
float x1 = xv[j * 2 + 1];
float xs = x0 + x1;
uint8_t gb = g_packed[j];
g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1;
g_xsum += xs;
uint8_t ub = u_packed[j];
u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1;
u_xsum += xs;
}
gate_acc[r] += g_dot * g_scale;
gate_zacc[r] += g_xsum * g_zero;
up_acc[r] += u_dot * u_scale;
up_zacc[r] += u_xsum * u_zero;
}
}
for (uint r = 0; r < rows_this; r++) {
float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]);
float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]);
if (simd_lid == 0) {
float s = g / (1.0f + exp(-g));
out[uint64_t(t) * out_dim + base_row + r] = s * u;
}
}
}
// ── Batched RMSNorm (N tokens) ──────────────────────────────────────
// x[t*dim .. (t+1)*dim-1] → out[t*dim .. (t+1)*dim-1]
// Grid: (n_tokens, 1, 1) threadgroups, each normalizes one token.
kernel void rms_norm_batched(
device const float *x [[buffer(0)]],
device const float *w [[buffer(1)]],
device float *out [[buffer(2)]],
constant uint &dim [[buffer(3)]],
constant float &eps [[buffer(4)]],
constant uint &n_tokens [[buffer(5)]],
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint tpg [[threads_per_threadgroup]])
{
if (tgid >= n_tokens) return;
device const float *xi = x + uint64_t(tgid) * dim;
device float *oi = out + uint64_t(tgid) * dim;
threadgroup float partial[1024];
float local_sum = 0.0f;
for (uint i = tid; i < dim; i += tpg)
local_sum += xi[i] * xi[i];
partial[tid] = local_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = tpg / 2; s > 0; s >>= 1) {
if (tid < s) partial[tid] += partial[tid + s];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float rms_inv = rsqrt(partial[0] / float(dim) + eps);
for (uint i = tid; i < dim; i += tpg)
oi[i] = xi[i] * rms_inv * w[i];
}
// ── Batched embedding lookup (N tokens) ─────────────────────────────
// Grid: (dim, n_tokens, 1). Each thread copies one element.
kernel void embed_lookup_batched(
device const float *embed [[buffer(0)]],
device float *out [[buffer(1)]],
device const uint *token_ids [[buffer(2)]],
constant uint &dim [[buffer(3)]],
uint2 gid [[thread_position_in_grid]])
{
uint i = gid.x;
uint t = gid.y;
if (i >= dim) return;
out[uint64_t(t) * dim + i] = embed[uint64_t(token_ids[t]) * dim + i];
}
// ── Batched RoPE (N tokens) ─────────────────────────────────────────
// Applies RoPE to Q[t] and K[t] for each token t at position base_pos+t.
// Grid: total_pairs per token * n_tokens
kernel void rope_apply_batched(
device float *q [[buffer(0)]],
device float *k [[buffer(1)]],
device const float *cos_tbl [[buffer(2)]],
device const float *sin_tbl [[buffer(3)]],
constant uint &n_q_heads [[buffer(4)]],
constant uint &n_kv_heads [[buffer(5)]],
constant uint &head_dim [[buffer(6)]],
constant uint &base_pos [[buffer(7)]],
constant uint &q_stride [[buffer(8)]],
constant uint &k_stride [[buffer(9)]],
uint2 gid [[thread_position_in_grid]])
{
uint pair_idx = gid.x;
uint t = gid.y;
uint half_dim = head_dim / 2;
uint total_pairs = (n_q_heads + n_kv_heads) * half_dim;
if (pair_idx >= total_pairs) return;
uint head_pair = pair_idx / half_dim;
uint i = pair_idx % half_dim;
uint pos = base_pos + t;
device float *vec;
if (head_pair < n_q_heads)
vec = q + uint64_t(t) * q_stride + head_pair * head_dim;
else
vec = k + uint64_t(t) * k_stride + (head_pair - n_q_heads) * head_dim;
uint cos_off = pos * half_dim;
float f = vec[i];
float s = vec[i + half_dim];
float c = cos_tbl[cos_off + i];
float sv = sin_tbl[cos_off + i];
vec[i] = f * c - s * sv;
vec[i + half_dim] = s * c + f * sv;
}
// ── Batched vec_add: out[i] = a[i] + b[i] for N*dim elements ────────
kernel void vec_add_batched(
device const float *a [[buffer(0)]],
device const float *b [[buffer(1)]],
device float *out [[buffer(2)]],
constant uint &total_n [[buffer(3)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= total_n) return;
out[gid] = a[gid] + b[gid];
}
// ── F16 matrix-vector multiply ───────────────────────────────────────
kernel void sgemv_f16(
device const half *W [[buffer(0)]],
device const float *x [[buffer(1)]],
device float *y [[buffer(2)]],
constant uint &in_dim [[buffer(3)]],
constant uint &out_dim [[buffer(4)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= out_dim) return;
device const half *row = W + uint64_t(gid) * in_dim;
float sum = 0.0f;
uint i = 0;
for (; i + 7 < in_dim; i += 8) {
sum += float(row[i]) * x[i];
sum += float(row[i + 1]) * x[i + 1];
sum += float(row[i + 2]) * x[i + 2];
sum += float(row[i + 3]) * x[i + 3];
sum += float(row[i + 4]) * x[i + 4];
sum += float(row[i + 5]) * x[i + 5];
sum += float(row[i + 6]) * x[i + 6];
sum += float(row[i + 7]) * x[i + 7];
}
for (; i < in_dim; i++)
sum += float(row[i]) * x[i];
y[gid] = sum;
}
// ── F32 matrix-vector multiply ───────────────────────────────────────
kernel void sgemv_f32(
device const float *W [[buffer(0)]],
device const float *x [[buffer(1)]],
device float *y [[buffer(2)]],
constant uint &in_dim [[buffer(3)]],
constant uint &out_dim [[buffer(4)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= out_dim) return;
device const float *row = W + uint64_t(gid) * in_dim;
float sum = 0.0f;
uint i = 0;
for (; i + 7 < in_dim; i += 8) {
sum += row[i] * x[i];
sum += row[i + 1] * x[i + 1];
sum += row[i + 2] * x[i + 2];
sum += row[i + 3] * x[i + 3];
sum += row[i + 4] * x[i + 4];
sum += row[i + 5] * x[i + 5];
sum += row[i + 6] * x[i + 6];
sum += row[i + 7] * x[i + 7];
}
for (; i < in_dim; i++)
sum += row[i] * x[i];
y[gid] = sum;
}
// ── RMS Normalization ────────────────────────────────────────────────
// out[i] = x[i] * w[i] / sqrt(mean(x^2) + eps)
// Two-pass: first compute sum of squares (reduction), then normalize.
// Single threadgroup processes the entire vector.
kernel void rms_norm(
device const float *x [[buffer(0)]],
device const float *w [[buffer(1)]],
device float *out [[buffer(2)]],
constant uint &dim [[buffer(3)]],
constant float &eps [[buffer(4)]],
uint tid [[thread_index_in_threadgroup]],
uint tpg [[threads_per_threadgroup]])
{
threadgroup float partial[1024];
float local_sum = 0.0f;
for (uint i = tid; i < dim; i += tpg)
local_sum += x[i] * x[i];
partial[tid] = local_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
// Tree reduction
for (uint s = tpg / 2; s > 0; s >>= 1) {
if (tid < s) partial[tid] += partial[tid + s];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float rms_inv = rsqrt(partial[0] / float(dim) + eps);
for (uint i = tid; i < dim; i += tpg)
out[i] = x[i] * rms_inv * w[i];
}
// ── RoPE (Rotary Position Embedding) ─────────────────────────────────
// Applies RoPE to Q and K vectors in-place.
// cos_sin is precomputed: [half_dim] cos values followed by [half_dim] sin values.
kernel void rope_apply(
device float *q [[buffer(0)]],
device float *k [[buffer(1)]],
device const float *cos_v [[buffer(2)]],
device const float *sin_v [[buffer(3)]],
constant uint &n_q_heads [[buffer(4)]],
constant uint &n_kv_heads [[buffer(5)]],
constant uint &head_dim [[buffer(6)]],
uint gid [[thread_position_in_grid]])
{
uint half_dim = head_dim / 2;
uint total_pairs = (n_q_heads + n_kv_heads) * half_dim;
if (gid >= total_pairs) return;
uint head_pair = gid / half_dim;
uint i = gid % half_dim;
device float *vec;
if (head_pair < n_q_heads) {
vec = q + head_pair * head_dim;
} else {
vec = k + (head_pair - n_q_heads) * head_dim;
}
float f = vec[i];
float s = vec[i + half_dim];
float c = cos_v[i];
float sv = sin_v[i];
vec[i] = f * c - s * sv;
vec[i + half_dim] = s * c + f * sv;
}
// ── SiLU activation + element-wise multiply ──────────────────────────
// gate[i] = silu(gate[i]) * up[i]
// silu(x) = x / (1 + exp(-x))
kernel void silu_mul(
device float *gate [[buffer(0)]],
device const float *up [[buffer(1)]],
constant uint &n [[buffer(2)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= n) return;
float x = gate[gid];
float s = x / (1.0f + exp(-x));
gate[gid] = s * up[gid];
}
// ── Vector add (residual connection) ─────────────────────────────────
// out[i] = a[i] + b[i]
kernel void vec_add(
device const float *a [[buffer(0)]],
device const float *b [[buffer(1)]],
device float *out [[buffer(2)]],
constant uint &n [[buffer(3)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= n) return;
out[gid] = a[gid] + b[gid];
}
// ── Bias add ─────────────────────────────────────────────────────────
// x[i] += bias[i]
kernel void bias_add(
device float *x [[buffer(0)]],
device const float *bias [[buffer(1)]],
constant uint &n [[buffer(2)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= n) return;
x[gid] += bias[gid];
}
// ── Embedding lookup ─────────────────────────────────────────────────
// out[i] = embed[token_id * dim + i]
kernel void embed_lookup(
device const float *embed [[buffer(0)]],
device float *out [[buffer(1)]],
constant uint &token_id [[buffer(2)]],
constant uint &dim [[buffer(3)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= dim) return;
out[gid] = embed[uint64_t(token_id) * dim + gid];
}
// ── Attention score: Q @ K^T for one head (legacy) ──────────────────
kernel void attn_score(
device const float *qh [[buffer(0)]],
device const float *kv_cache_k [[buffer(1)]],
device float *att [[buffer(2)]],
constant uint &head_dim [[buffer(3)]],
constant uint &kv_dim [[buffer(4)]],
constant uint &kv_head_offset [[buffer(5)]],
constant float &scale [[buffer(6)]],
constant uint &seq_len [[buffer(7)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= seq_len) return;
device const float *kt = kv_cache_k + uint64_t(gid) * kv_dim + kv_head_offset;
float dot = 0.0f;
for (uint i = 0; i < head_dim; i++)
dot += qh[i] * kt[i];
att[gid] = dot * scale;
}
// ── Batched attention score: all Q heads in one dispatch ─────────────
// Grid: (seq_len, n_q_heads, 1). Each thread computes one score for one head.
// GQA: maps Q head h to KV head h/gqa_factor.
kernel void attn_score_batched(
device const float *q [[buffer(0)]],
device const float *kv_cache_k [[buffer(1)]],
device float *att [[buffer(2)]],
constant uint &head_dim [[buffer(3)]],
constant uint &kv_dim [[buffer(4)]],
constant uint &n_q_heads [[buffer(5)]],
constant uint &gqa_factor [[buffer(6)]],
constant float &scale [[buffer(7)]],
constant uint &seq_len [[buffer(8)]],
constant uint &max_seq [[buffer(9)]],
uint2 gid [[thread_position_in_grid]])
{
uint t = gid.x;
uint h = gid.y;
if (t >= seq_len || h >= n_q_heads) return;
uint kv_h = h / gqa_factor;
device const float *qh = q + h * head_dim;
device const float *kt = kv_cache_k + uint64_t(t) * kv_dim + kv_h * head_dim;
float dot = 0.0f;
for (uint i = 0; i < head_dim; i++)
dot += qh[i] * kt[i];
att[h * max_seq + t] = dot * scale;
}
// ── Batched softmax: all heads in one dispatch ───────────────────────
// One threadgroup per head. tid reduces over seq_len dimension.
kernel void softmax_batched(
device float *att [[buffer(0)]],
constant uint &seq_len [[buffer(1)]],
constant uint &max_seq [[buffer(2)]],
constant uint &n_q_heads [[buffer(3)]],
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint tpg [[threads_per_threadgroup]])
{
uint h = tgid;
if (h >= n_q_heads) return;
device float *head_att = att + h * max_seq;
threadgroup float shared[1024];
float local_max = -1e30f;
for (uint i = tid; i < seq_len; i += tpg)
local_max = max(local_max, head_att[i]);
shared[tid] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = tpg / 2; s > 0; s >>= 1) {
if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float max_val = shared[0];
float local_sum = 0.0f;
for (uint i = tid; i < seq_len; i += tpg) {
float e = exp(head_att[i] - max_val);
head_att[i] = e;
local_sum += e;
}
shared[tid] = local_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = tpg / 2; s > 0; s >>= 1) {
if (tid < s) shared[tid] += shared[tid + s];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float inv_sum = 1.0f / shared[0];
for (uint i = tid; i < seq_len; i += tpg)
head_att[i] *= inv_sum;
}
// ── Batched attention weighted sum: all heads in one dispatch ────────
// Grid: (head_dim, n_q_heads, 1). Each thread computes one output dim for one head.
kernel void attn_wsum_batched(
device const float *att [[buffer(0)]],
device const float *kv_cache_v [[buffer(1)]],
device float *out [[buffer(2)]],
constant uint &head_dim [[buffer(3)]],
constant uint &kv_dim [[buffer(4)]],
constant uint &n_q_heads [[buffer(5)]],
constant uint &gqa_factor [[buffer(6)]],
constant uint &seq_len [[buffer(7)]],
constant uint &max_seq [[buffer(8)]],
uint2 gid [[thread_position_in_grid]])
{
uint d = gid.x;
uint h = gid.y;
if (d >= head_dim || h >= n_q_heads) return;
uint kv_h = h / gqa_factor;
device const float *head_att = att + h * max_seq;
float sum = 0.0f;
for (uint t = 0; t < seq_len; t++) {
float a = head_att[t];
float v = kv_cache_v[uint64_t(t) * kv_dim + kv_h * head_dim + d];
sum += a * v;
}
out[h * head_dim + d] = sum;
}
// ── Softmax (legacy, single head) ───────────────────────────────────
kernel void softmax_inplace(
device float *att [[buffer(0)]],
constant uint &seq_len [[buffer(1)]],
uint tid [[thread_index_in_threadgroup]],
uint tpg [[threads_per_threadgroup]])
{
threadgroup float shared[1024];
float local_max = -1e30f;
for (uint i = tid; i < seq_len; i += tpg)
local_max = max(local_max, att[i]);
shared[tid] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = tpg / 2; s > 0; s >>= 1) {
if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float max_val = shared[0];
float local_sum = 0.0f;
for (uint i = tid; i < seq_len; i += tpg) {
float e = exp(att[i] - max_val);
att[i] = e;
local_sum += e;
}
shared[tid] = local_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = tpg / 2; s > 0; s >>= 1) {
if (tid < s) shared[tid] += shared[tid + s];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float inv_sum = 1.0f / shared[0];
for (uint i = tid; i < seq_len; i += tpg)
att[i] *= inv_sum;
}
// ── Attention weighted sum (legacy, single head) ─────────────────────
kernel void attn_weighted_sum(
device const float *att [[buffer(0)]],
device const float *kv_cache_v [[buffer(1)]],
device float *out [[buffer(2)]],
constant uint &head_dim [[buffer(3)]],
constant uint &kv_dim [[buffer(4)]],
constant uint &kv_head_offset [[buffer(5)]],
constant uint &seq_len [[buffer(6)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= head_dim) return;
float sum = 0.0f;
for (uint t = 0; t < seq_len; t++) {
float a = att[t];
float v = kv_cache_v[uint64_t(t) * kv_dim + kv_head_offset + gid];
sum += a * v;
}
out[gid] = sum;
}
// ── Argmax ───────────────────────────────────────────────────────────
// Finds argmax of logits[0..n-1], writes to result[0].
// Single threadgroup.
kernel void argmax_kernel(
device const float *logits [[buffer(0)]],
device int *result [[buffer(1)]],
constant uint &n [[buffer(2)]],
uint tid [[thread_index_in_threadgroup]],
uint tpg [[threads_per_threadgroup]])
{
threadgroup float shared_val[1024];
threadgroup int shared_idx[1024];
float local_max = -1e30f;
int local_idx = 0;
for (uint i = tid; i < n; i += tpg) {
if (logits[i] > local_max) {
local_max = logits[i];
local_idx = int(i);
}
}
shared_val[tid] = local_max;
shared_idx[tid] = local_idx;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = tpg / 2; s > 0; s >>= 1) {
if (tid < s && shared_val[tid + s] > shared_val[tid]) {
shared_val[tid] = shared_val[tid + s];
shared_idx[tid] = shared_idx[tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (tid == 0) result[0] = shared_idx[0];
}
// ── Copy kernel ──────────────────────────────────────────────────────
// dst[i] = src[i] for i in [0, n)
kernel void vec_copy(
device const float *src [[buffer(0)]],
device float *dst [[buffer(1)]],
constant uint &n [[buffer(2)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= n) return;
dst[gid] = src[gid];
}
// ── Zero-fill ────────────────────────────────────────────────────────
kernel void vec_zero(
device float *dst [[buffer(0)]],
constant uint &n [[buffer(1)]],
uint gid [[thread_position_in_grid]])
{
if (gid >= n) return;
dst[gid] = 0.0f;
}

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,8 @@ MODEL_DIR="$HOME/models/Qwen2.5-0.5B-Instruct"
WEIGHTS_BIN="$SCRIPT_DIR/qwen05b.bin"
BINARY="$SCRIPT_DIR/qwen_ane"
VENV_DIR="$SCRIPT_DIR/.venv"
EXPECTED_WEIGHT_SIZE=1976131100
EXPECTED_WEIGHT_SIZE_F32=1976131100
EXPECTED_WEIGHT_SIZE_F16=988082236
info() { printf "\033[1;34m==> %s\033[0m\n" "$1"; }
ok() { printf "\033[1;32m ✓ %s\033[0m\n" "$1"; }
@ -86,16 +87,16 @@ info "Converting weights to binary format..."
if [ -f "$WEIGHTS_BIN" ]; then
ACTUAL_SIZE=$(stat -f%z "$WEIGHTS_BIN" 2>/dev/null || stat -c%s "$WEIGHTS_BIN" 2>/dev/null)
if [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE" ]; then
if [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE_F16" ] || [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE_F32" ]; then
ok "Weights already converted ($((ACTUAL_SIZE / 1024 / 1024)) MB)"
else
warn "Weight file exists but wrong size ($ACTUAL_SIZE vs $EXPECTED_WEIGHT_SIZE), reconverting"
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN"
ok "Weights converted"
warn "Weight file exists but unexpected size ($ACTUAL_SIZE), reconverting as F16"
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" --f16
ok "Weights converted (F16)"
fi
else
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN"
ok "Weights converted"
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" --f16
ok "Weights converted (F16)"
fi
# --- Step 5: Build binary ---
@ -113,8 +114,10 @@ elif [ "$SCRIPT_DIR/main.m" -nt "$BINARY" ] || \
fi
if [ "$NEEDS_BUILD" -eq 1 ]; then
xcrun clang -O2 -framework Foundation -framework IOSurface \
-framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
xcrun clang -O3 -ffast-math -mcpu=apple-m4 -flto \
-framework Foundation -framework IOSurface \
-framework CoreML -framework Accelerate -framework Metal \
-ldl -lobjc -fobjc-arc \
-o "$BINARY" "$SCRIPT_DIR/main.m"
ok "Binary built: $BINARY"
else

View File

@ -232,6 +232,124 @@ static NSData *mil_build_ffn_up_weight_blob(const float *w1, const float *w3, in
return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
}
// Generate MIL for fused GQA QKV: Q, K, V have different output dimensions
// Qwen2.5-0.5B: Q=[q_dim, dim], K=[kv_dim, dim], V=[kv_dim, dim]
// Weight blob: Wq[q_dim,dim] @ chunk0, Wk[kv_dim,dim] @ chunk1, Wv[kv_dim,dim] @ chunk2
static NSString *mil_gen_qkv_gqa(int dim, int q_dim, int kv_dim, int spatial) {
NSUInteger cs_q = 64 + (NSUInteger)q_dim * dim * 2;
NSUInteger cs_kv = 64 + (NSUInteger)kv_dim * dim * 2;
NSUInteger off_k = 64 + cs_q;
NSUInteger off_v = off_k + cs_kv;
if (g_fp16_io) {
return [NSString stringWithFormat:
@"program(1.0)\n"
"[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{\"coremlc-version\", \"3505.4.1\"}})]\n"
"{\n"
" func main<ios16>(tensor<fp16, [1, %d, 1, %d]> x) {\n"
" tensor<string, []> c_pad_type = const()[name = tensor<string, []>(\"c_pad_type\"), val = tensor<string, []>(\"valid\")];\n"
" tensor<int32, [2]> c_strides = const()[name = tensor<string, []>(\"c_strides\"), val = tensor<int32, [2]>([1, 1])];\n"
" tensor<int32, [4]> c_pad = const()[name = tensor<string, []>(\"c_pad\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
" tensor<int32, [2]> c_dilations = const()[name = tensor<string, []>(\"c_dilations\"), val = tensor<int32, [2]>([1, 1])];\n"
" tensor<int32, []> c_groups = const()[name = tensor<string, []>(\"c_groups\"), val = tensor<int32, []>(1)];\n"
" tensor<fp16, [%d, %d, 1, 1]> Wq = const()[name = tensor<string, []>(\"Wq\"), "
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(64)))];\n"
" tensor<fp16, [%d, %d, 1, 1]> Wk = const()[name = tensor<string, []>(\"Wk\"), "
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
" tensor<fp16, [%d, %d, 1, 1]> Wv = const()[name = tensor<string, []>(\"Wv\"), "
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
" tensor<fp16, [1, %d, 1, %d]> q = conv(dilations = c_dilations, groups = c_groups, "
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wq, x = x)[name = tensor<string, []>(\"conv_q\")];\n"
" tensor<fp16, [1, %d, 1, %d]> k = conv(dilations = c_dilations, groups = c_groups, "
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wk, x = x)[name = tensor<string, []>(\"conv_k\")];\n"
" tensor<fp16, [1, %d, 1, %d]> v = conv(dilations = c_dilations, groups = c_groups, "
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wv, x = x)[name = tensor<string, []>(\"conv_v\")];\n"
" } -> (q, k, v);\n"
"}\n",
dim, spatial,
q_dim, dim, q_dim, dim,
kv_dim, dim, kv_dim, dim, (unsigned long)off_k,
kv_dim, dim, kv_dim, dim, (unsigned long)off_v,
q_dim, spatial, kv_dim, spatial, kv_dim, spatial];
}
return [NSString stringWithFormat:
@"program(1.0)\n"
"[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{\"coremlc-version\", \"3505.4.1\"}})]\n"
"{\n"
" func main<ios16>(tensor<fp32, [1, %d, 1, %d]> x) {\n"
" tensor<string, []> c_pad_type = const()[name = tensor<string, []>(\"c_pad_type\"), val = tensor<string, []>(\"valid\")];\n"
" tensor<int32, [2]> c_strides = const()[name = tensor<string, []>(\"c_strides\"), val = tensor<int32, [2]>([1, 1])];\n"
" tensor<int32, [4]> c_pad = const()[name = tensor<string, []>(\"c_pad\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
" tensor<int32, [2]> c_dilations = const()[name = tensor<string, []>(\"c_dilations\"), val = tensor<int32, [2]>([1, 1])];\n"
" tensor<int32, []> c_groups = const()[name = tensor<string, []>(\"c_groups\"), val = tensor<int32, []>(1)];\n"
" tensor<string, []> to_fp16 = const()[name = tensor<string, []>(\"to_fp16\"), val = tensor<string, []>(\"fp16\")];\n"
" tensor<fp16, [1, %d, 1, %d]> x16 = cast(dtype = to_fp16, x = x)[name = tensor<string, []>(\"cast_in\")];\n"
" tensor<fp16, [%d, %d, 1, 1]> Wq = const()[name = tensor<string, []>(\"Wq\"), "
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(64)))];\n"
" tensor<fp16, [%d, %d, 1, 1]> Wk = const()[name = tensor<string, []>(\"Wk\"), "
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
" tensor<fp16, [%d, %d, 1, 1]> Wv = const()[name = tensor<string, []>(\"Wv\"), "
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
" tensor<fp16, [1, %d, 1, %d]> q16 = conv(dilations = c_dilations, groups = c_groups, "
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wq, x = x16)[name = tensor<string, []>(\"conv_q\")];\n"
" tensor<fp16, [1, %d, 1, %d]> k16 = conv(dilations = c_dilations, groups = c_groups, "
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wk, x = x16)[name = tensor<string, []>(\"conv_k\")];\n"
" tensor<fp16, [1, %d, 1, %d]> v16 = conv(dilations = c_dilations, groups = c_groups, "
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wv, x = x16)[name = tensor<string, []>(\"conv_v\")];\n"
" tensor<string, []> to_fp32 = const()[name = tensor<string, []>(\"to_fp32\"), val = tensor<string, []>(\"fp32\")];\n"
" tensor<fp32, [1, %d, 1, %d]> q = cast(dtype = to_fp32, x = q16)[name = tensor<string, []>(\"cast_q\")];\n"
" tensor<fp32, [1, %d, 1, %d]> k = cast(dtype = to_fp32, x = k16)[name = tensor<string, []>(\"cast_k\")];\n"
" tensor<fp32, [1, %d, 1, %d]> v = cast(dtype = to_fp32, x = v16)[name = tensor<string, []>(\"cast_v\")];\n"
" } -> (q, k, v);\n"
"}\n",
dim, spatial, dim, spatial,
q_dim, dim, q_dim, dim,
kv_dim, dim, kv_dim, dim, (unsigned long)off_k,
kv_dim, dim, kv_dim, dim, (unsigned long)off_v,
q_dim, spatial, kv_dim, spatial, kv_dim, spatial,
q_dim, spatial, kv_dim, spatial, kv_dim, spatial];
}
// Build weight blob for GQA QKV (3 weight matrices with different shapes)
static NSData *mil_build_qkv_gqa_weight_blob(const float *wq, int q_dim, int dim,
const float *wk, const float *wv, int kv_dim) {
NSUInteger wsize_q = (NSUInteger)q_dim * dim * 2;
NSUInteger wsize_kv = (NSUInteger)kv_dim * dim * 2;
NSUInteger cs_q = 64 + wsize_q;
NSUInteger cs_kv = 64 + wsize_kv;
NSUInteger total = 64 + cs_q + 2 * cs_kv;
uint8_t *buf = (uint8_t*)calloc(total, 1);
buf[0] = 0x01; buf[4] = 0x02;
// Chunk 0: Wq
{
uint8_t *chunk = buf + 64;
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01;
*(uint32_t*)(chunk + 8) = (uint32_t)wsize_q;
*(uint32_t*)(chunk + 16) = (uint32_t)(64 + 64);
_Float16 *fp16 = (_Float16*)(chunk + 64);
for (NSUInteger i = 0; i < (NSUInteger)q_dim * dim; i++) fp16[i] = (_Float16)wq[i];
}
// Chunk 1: Wk
{
uint8_t *chunk = buf + 64 + cs_q;
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01;
*(uint32_t*)(chunk + 8) = (uint32_t)wsize_kv;
*(uint32_t*)(chunk + 16) = (uint32_t)(64 + cs_q + 64);
_Float16 *fp16 = (_Float16*)(chunk + 64);
for (NSUInteger i = 0; i < (NSUInteger)kv_dim * dim; i++) fp16[i] = (_Float16)wk[i];
}
// Chunk 2: Wv
{
uint8_t *chunk = buf + 64 + cs_q + cs_kv;
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01;
*(uint32_t*)(chunk + 8) = (uint32_t)wsize_kv;
*(uint32_t*)(chunk + 16) = (uint32_t)(64 + cs_q + cs_kv + 64);
_Float16 *fp16 = (_Float16*)(chunk + 64);
for (NSUInteger i = 0; i < (NSUInteger)kv_dim * dim; i++) fp16[i] = (_Float16)wv[i];
}
return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
}
// Generate MIL for fused FFN up: w1 + w3 parallel convs
static NSString *mil_gen_ffn_up(int dim, int hidden_dim, int spatial) {
NSUInteger cs = 64 + (NSUInteger)hidden_dim * dim * 2;