diff --git a/.gitignore b/.gitignore index 6969a85..b2eb942 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,9 @@ training/test_* # Inference binaries and runtime data inference/qwen_ane inference/qwen05b.bin +inference/qwen05b_f32.bin +inference/qwen05b_f16.bin +inference/qwen05b_q8.bin inference/.venv/ inference/benchmark_results.json @@ -59,6 +62,7 @@ web/ training/tinystories_data00.bin training/ane_stories110M_ckpt.bin *.bin +*.metallib !training/download_data.sh # Secrets / env diff --git a/inference/benchmark.sh b/inference/benchmark.sh index b623692..3573549 100755 --- a/inference/benchmark.sh +++ b/inference/benchmark.sh @@ -66,125 +66,235 @@ MEM_BYTES=$(sysctl -n hw.memsize 2>/dev/null || echo "0") MEM_GB=$((MEM_BYTES / 1073741824)) echo "" -info "=== ANE Inference Benchmark (qwen_ane) ===" +info "=== ANE Multi-Format Inference Benchmark ===" echo "Hardware: $CHIP" echo "macOS: $MACOS" echo "Memory: ${MEM_GB} GB" -echo "Model: Qwen2.5-0.5B-Instruct (BF16, 494M params)" +echo "Model: Qwen2.5-0.5B-Instruct (494M params)" echo "" -# --- Phase 1: Server mode benchmark (HTTP API) --- -info "Phase 1: Server mode (persistent ANE kernels via HTTP API)" -dim "Starting server on port $HTTP_PORT..." +# --- Phase 0: Prepare weight files (F16 + Q8) --- +WEIGHTS_F16="$SCRIPT_DIR/qwen05b.bin" +WEIGHTS_Q8="$SCRIPT_DIR/qwen05b_q8.bin" +WEIGHTS_Q4="$SCRIPT_DIR/qwen05b_q4.bin" +CONVERT="$SCRIPT_DIR/convert_weights.py" +VENV_DIR="$SCRIPT_DIR/.venv" -# Start HTTP server in background -"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 & -SERVER_PID=$! +info "Phase 0: Preparing weight files" +if [ ! -f "$WEIGHTS_Q8" ]; then + if [ ! -f "$CONVERT" ]; then + echo " convert_weights.py not found, skipping Q8 generation." + WEIGHTS_Q8="" + else + dim "Generating Q8 weights (one-time)..." + if [ -d "$VENV_DIR" ]; then + source "$VENV_DIR/bin/activate" + fi + python3 "$CONVERT" "$MODEL_DIR" "$WEIGHTS_Q8" --q8 + dim "Q8 weights ready: $(du -h "$WEIGHTS_Q8" | cut -f1)" + fi +else + dim "Q8 weights already exist: $(du -h "$WEIGHTS_Q8" | cut -f1)" +fi + +if [ ! -f "$WEIGHTS_Q4" ]; then + if [ ! -f "$CONVERT" ]; then + echo " convert_weights.py not found, skipping Q4 generation." + WEIGHTS_Q4="" + else + dim "Generating Q4 weights (one-time)..." + if [ -d "$VENV_DIR" ]; then + source "$VENV_DIR/bin/activate" + fi + python3 "$CONVERT" "$MODEL_DIR" "$WEIGHTS_Q4" --q4 + dim "Q4 weights ready: $(du -h "$WEIGHTS_Q4" | cut -f1)" + fi +else + dim "Q4 weights already exist: $(du -h "$WEIGHTS_Q4" | cut -f1)" +fi + +dim "F16 weights: $(du -h "$WEIGHTS_F16" | cut -f1)" +echo "" + +# ANE weight formats to benchmark +# GPU flag: empty for CPU formats, "--gpu" for Metal GPU formats +ANE_FMT_NAMES=("F16") +ANE_FMT_WEIGHTS=("$WEIGHTS_F16") +ANE_FMT_LABELS=("F16→F32 (AMX)") +ANE_FMT_GPU=("") + +if [ -n "$WEIGHTS_Q8" ] && [ -f "$WEIGHTS_Q8" ]; then + ANE_FMT_NAMES+=("Q8") + ANE_FMT_WEIGHTS+=("$WEIGHTS_Q8") + ANE_FMT_LABELS+=("Q8 (NEON dequant)") + ANE_FMT_GPU+=("") +fi + +if [ -n "$WEIGHTS_Q4" ] && [ -f "$WEIGHTS_Q4" ]; then + ANE_FMT_NAMES+=("Q4_Metal") + ANE_FMT_WEIGHTS+=("$WEIGHTS_Q4") + ANE_FMT_LABELS+=("Q4 SIMD (Metal GPU)") + ANE_FMT_GPU+=("--gpu") + + ANE_FMT_NAMES+=("Q4_AMX") + ANE_FMT_WEIGHTS+=("$WEIGHTS_Q4") + ANE_FMT_LABELS+=("Q4→F32 (AMX dequant)") + ANE_FMT_GPU+=("") +fi + +NUM_ANE_FMTS=${#ANE_FMT_NAMES[@]} +NUM_PROMPTS=${#PROMPTS[@]} + +# Global cleanup +SERVER_PID="" cleanup() { - kill "$SERVER_PID" 2>/dev/null || true + [ -n "$SERVER_PID" ] && kill "$SERVER_PID" 2>/dev/null || true rm -f "$SOCK" /tmp/qwen_bench_server.log } trap cleanup EXIT -# Wait for READY -for i in $(seq 1 30); do - if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then - break - fi +# Helper: start server with given weight file and optional extra flags, wait for READY +start_server() { + local wfile="$1" + shift + local extra_flags="$*" + [ -n "$SERVER_PID" ] && kill "$SERVER_PID" 2>/dev/null || true sleep 1 -done - -if ! grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then - echo "Server failed to start. Log:" + rm -f /tmp/qwen_bench_server.log + "$BINARY" "$wfile" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" $extra_flags > /tmp/qwen_bench_server.log 2>&1 & + SERVER_PID=$! + for _i in $(seq 1 30); do + if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then return 0; fi + sleep 1 + done + echo "Server failed to start with $wfile. Log:" cat /tmp/qwen_bench_server.log - exit 1 -fi -dim "Server ready (PID $SERVER_PID)" -echo "" + return 1 +} -# Warmup: first request primes any remaining caches -dim "Warmup run (discarded)..." -curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \ - -H "Content-Type: application/json" \ - -d '{"prompt":"warmup","max_tokens":5}' > /dev/null 2>&1 -echo "" +# --- Phase 1: Multi-format ANE benchmarks --- +# Per-format result tracking (indexed by format number) +declare -a ALL_AVG_P ALL_AVG_D ALL_AVG_INF ALL_AVG_TTFT ALL_AVG_RT +ANE_JSON_BLOCKS="" -# Print table header -printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \ - "Prompt" "In" "Out" "Prefill" "Decode" "TTFT" "Infer" "Rndtrip" "Overhead" -printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \ - "" "tok" "tok" "(t/s)" "(t/s)" "(ms)" "(ms)" "(ms)" "(ms)" -printf '%.0s─' {1..85}; echo "" +for fmt_idx in $(seq 0 $((NUM_ANE_FMTS - 1))); do + FMT_NAME="${ANE_FMT_NAMES[$fmt_idx]}" + FMT_WEIGHTS="${ANE_FMT_WEIGHTS[$fmt_idx]}" + FMT_LABEL="${ANE_FMT_LABELS[$fmt_idx]}" + FMT_GPU="${ANE_FMT_GPU[$fmt_idx]}" -# Arrays for averages -declare -a P_TPS_ARR D_TPS_ARR INF_MS_ARR TTFT_MS_ARR RT_MS_ARR + echo "" + info "Phase 1.$((fmt_idx+1)): ANE $FMT_NAME benchmark ($FMT_LABEL)" + dim "Weights: $(du -h "$FMT_WEIGHTS" | cut -f1) — Starting server..." -JSON_ENTRIES="" -NUM_PROMPTS=${#PROMPTS[@]} + if ! start_server "$FMT_WEIGHTS" $FMT_GPU; then + echo "Skipping $FMT_NAME format." + ALL_AVG_P+=("0"); ALL_AVG_D+=("0"); ALL_AVG_INF+=("0") + ALL_AVG_TTFT+=("0"); ALL_AVG_RT+=("0") + continue + fi -for i in $(seq 0 $((NUM_PROMPTS - 1))); do - NAME="${PROMPT_NAMES[$i]}" - PROMPT="${PROMPTS[$i]}" - MAXTOK="${MAX_TOKENS[$i]}" - - RT_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') - RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \ + dim "Warmup run (discarded)..." + curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \ -H "Content-Type: application/json" \ - -d "{\"prompt\": \"$PROMPT\", \"max_tokens\": $MAXTOK}" 2>&1) - RT_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') - RT_MS=$(echo "$RT_T0 $RT_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}') - - # Parse server JSON with pure shell -- no python - P_TOKENS=$(json_val "$RESP" "prompt_tokens") - G_TOKENS=$(json_val "$RESP" "gen_tokens") - P_TPS=$(json_val "$RESP" "prefill_tps") - D_TPS=$(json_val "$RESP" "decode_tps") - TTFT_MS=$(trunc "$(json_val "$RESP" "ttft_ms")") - INF_MS=$(trunc "$(json_val "$RESP" "inference_ms")") - TOTAL_MS=$(trunc "$(json_val "$RESP" "total_ms")") - TEXT=$(json_text "$RESP") - OVERHEAD=$((RT_MS - TOTAL_MS)) + -d '{"prompt":"warmup","max_tokens":5}' > /dev/null 2>&1 + echo "" printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \ - "$NAME" "$P_TOKENS" "$G_TOKENS" "$P_TPS" "$D_TPS" "$TTFT_MS" "$INF_MS" "$RT_MS" "$OVERHEAD" + "Prompt" "In" "Out" "Prefill" "Decode" "TTFT" "Infer" "Rndtrip" "Overhead" + printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \ + "" "tok" "tok" "(t/s)" "(t/s)" "(ms)" "(ms)" "(ms)" "(ms)" + printf '%.0s─' {1..85}; echo "" - P_TPS_ARR+=("$P_TPS") - D_TPS_ARR+=("$D_TPS") - INF_MS_ARR+=("$INF_MS") - TTFT_MS_ARR+=("$TTFT_MS") - RT_MS_ARR+=("$RT_MS") + declare -a P_TPS_ARR=() D_TPS_ARR=() INF_MS_ARR=() TTFT_MS_ARR=() RT_MS_ARR=() + FMT_JSON_ENTRIES="" - # Build JSON entry - JSON_ENTRIES="$JSON_ENTRIES{\"name\":\"$NAME\",\"prompt_tokens\":$P_TOKENS,\"gen_tokens\":$G_TOKENS,\"prefill_tps\":$P_TPS,\"decode_tps\":$D_TPS,\"ttft_ms\":$TTFT_MS,\"inference_ms\":$INF_MS,\"roundtrip_ms\":$RT_MS}," + for i in $(seq 0 $((NUM_PROMPTS - 1))); do + NAME="${PROMPT_NAMES[$i]}" + PROMPT="${PROMPTS[$i]}" + MAXTOK="${MAX_TOKENS[$i]}" - # Print response text indented below - echo " → $TEXT" + RT_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') + RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{\"prompt\": \"$PROMPT\", \"max_tokens\": $MAXTOK}" 2>&1) + RT_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') + RT_MS=$(echo "$RT_T0 $RT_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}') + + P_TOKENS=$(json_val "$RESP" "prompt_tokens") + G_TOKENS=$(json_val "$RESP" "gen_tokens") + P_TPS=$(json_val "$RESP" "prefill_tps") + D_TPS=$(json_val "$RESP" "decode_tps") + TTFT_MS=$(trunc "$(json_val "$RESP" "ttft_ms")") + INF_MS=$(trunc "$(json_val "$RESP" "inference_ms")") + TOTAL_MS=$(trunc "$(json_val "$RESP" "total_ms")") + TEXT=$(json_text "$RESP") + OVERHEAD=$((RT_MS - TOTAL_MS)) + + printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \ + "$NAME" "$P_TOKENS" "$G_TOKENS" "$P_TPS" "$D_TPS" "$TTFT_MS" "$INF_MS" "$RT_MS" "$OVERHEAD" + + P_TPS_ARR+=("$P_TPS") + D_TPS_ARR+=("$D_TPS") + INF_MS_ARR+=("$INF_MS") + TTFT_MS_ARR+=("$TTFT_MS") + RT_MS_ARR+=("$RT_MS") + + FMT_JSON_ENTRIES="$FMT_JSON_ENTRIES{\"name\":\"$NAME\",\"prompt_tokens\":$P_TOKENS,\"gen_tokens\":$G_TOKENS,\"prefill_tps\":$P_TPS,\"decode_tps\":$D_TPS,\"ttft_ms\":$TTFT_MS,\"inference_ms\":$INF_MS,\"roundtrip_ms\":$RT_MS}," + + echo " → $TEXT" + echo "" + done + + printf '%.0s─' {1..85}; echo "" + + F_AVG_P=$(shell_avg "${P_TPS_ARR[@]}") + F_AVG_D=$(shell_avg "${D_TPS_ARR[@]}") + F_AVG_INF=$(shell_avg_int "${INF_MS_ARR[@]}") + F_AVG_TTFT=$(shell_avg_int "${TTFT_MS_ARR[@]}") + F_AVG_RT=$(shell_avg_int "${RT_MS_ARR[@]}") + F_AVG_OVERHEAD=$((F_AVG_RT - F_AVG_INF)) + printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" "Average" "" "" "$F_AVG_P" "$F_AVG_D" "$F_AVG_TTFT" "$F_AVG_INF" "$F_AVG_RT" "$F_AVG_OVERHEAD" echo "" + + ALL_AVG_P+=("$F_AVG_P") + ALL_AVG_D+=("$F_AVG_D") + ALL_AVG_INF+=("$F_AVG_INF") + ALL_AVG_TTFT+=("$F_AVG_TTFT") + ALL_AVG_RT+=("$F_AVG_RT") + + ANE_JSON_BLOCKS="$ANE_JSON_BLOCKS + \"$FMT_NAME\": { + \"format\": \"$FMT_NAME\", + \"label\": \"$FMT_LABEL\", + \"weight_size_mb\": $(du -m "$FMT_WEIGHTS" | cut -f1), + \"avg_prefill_tps\": $F_AVG_P, + \"avg_decode_tps\": $F_AVG_D, + \"avg_inference_ms\": $F_AVG_INF, + \"avg_roundtrip_ms\": $F_AVG_RT, + \"avg_ttft_ms\": $F_AVG_TTFT, + \"results\": [${FMT_JSON_ENTRIES%,}] + }," done -printf '%.0s─' {1..85}; echo "" +# Use F16 results as the primary ANE numbers (first format) +AVG_P="${ALL_AVG_P[0]}" +AVG_D="${ALL_AVG_D[0]}" +AVG_INF="${ALL_AVG_INF[0]}" +AVG_TTFT="${ALL_AVG_TTFT[0]}" +AVG_RT="${ALL_AVG_RT[0]}" -# Averages (pure shell, no python) -AVG_P=$(shell_avg "${P_TPS_ARR[@]}") -AVG_D=$(shell_avg "${D_TPS_ARR[@]}") -AVG_INF=$(shell_avg_int "${INF_MS_ARR[@]}") -AVG_TTFT=$(shell_avg_int "${TTFT_MS_ARR[@]}") -AVG_RT=$(shell_avg_int "${RT_MS_ARR[@]}") -AVG_OVERHEAD=$((AVG_RT - AVG_INF)) -printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" "Average" "" "" "$AVG_P" "$AVG_D" "$AVG_TTFT" "$AVG_INF" "$AVG_RT" "$AVG_OVERHEAD" -echo "" info "Infer = server-reported (pure processing). Rndtrip = wall-clock (what clients see)." echo "" # --- Phase 2: Cold start measurement --- info "Phase 2: Cold start (single-shot, recompiles ANE kernels)" -# Kill server, run single-shot kill "$SERVER_PID" 2>/dev/null || true +SERVER_PID="" sleep 1 -# Use perl for sub-second timing (available on all macOS, no python) COLD_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') COLD_OUT=$("$BINARY" "$WEIGHTS" "151644 8948 198 2610 525 264 10950 17847 13 151645 198 151644 872 198 13048 151645 198 151644 77091 198" 10 2>&1 || true) COLD_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') @@ -193,17 +303,11 @@ COLD_MS=$(echo "$COLD_T0 $COLD_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}') echo "Cold start latency: ${COLD_MS}ms (includes ANE kernel compilation)" echo "" -# Re-start server for any additional tests -"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 & -SERVER_PID=$! +# Re-start server (F16) for consistency check +start_server "$WEIGHTS_F16" # --- Phase 3: Repeated prompt (consistency check) --- -info "Phase 3: Decode speed consistency (5x same prompt)" - -for retry in $(seq 1 15); do - if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then break; fi - sleep 1 -done +info "Phase 3: Decode speed consistency (5x same prompt, F16)" printf "%-6s %10s %10s %10s\n" "Run" "Prefill" "Decode" "Infer(ms)" printf '%.0s─' {1..40}; echo "" @@ -227,12 +331,8 @@ JSON="{ \"model\": \"Qwen2.5-0.5B-Instruct\", \"mode\": \"http_server\", \"cold_start_ms\": $COLD_MS, - \"avg_prefill_tps\": $AVG_P, - \"avg_decode_tps\": $AVG_D, - \"avg_inference_ms\": $AVG_INF, - \"avg_roundtrip_ms\": $AVG_RT, - \"avg_ttft_ms\": $AVG_TTFT, - \"results\": [${JSON_ENTRIES%,}] + \"ane_formats\": {$( echo "$ANE_JSON_BLOCKS" | sed '$ s/,$//' ) + } }" echo "$JSON" > "$RESULTS_JSON" dim "Results saved to $RESULTS_JSON" @@ -240,9 +340,12 @@ echo "" # --- Phase 4: LM Studio comparison (if running) --- LMS_PORT="${LMS_PORT:-1234}" -LMS_MODEL="${LMS_MODEL:-qwen2.5-0.5b-instruct}" LMS_API_KEY="${LMS_API_KEY:-}" +# Models to benchmark (override via LMS_MODELS env var, comma-separated) +LMS_MODELS_DEFAULT="qwen2.5-0.5b-instruct,qwen2.5-0.5b-instruct-mlx@8bit,qwen2.5-0.5b-instruct-mlx@4bit" +IFS=',' read -ra LMS_MODEL_LIST <<< "${LMS_MODELS:-$LMS_MODELS_DEFAULT}" + # Check if LM Studio is running LMS_REACHABLE=0 if curl -s --max-time 2 "http://localhost:$LMS_PORT/api/v1/chat" -H "Content-Type: application/json" -d '{}' >/dev/null 2>&1; then @@ -251,8 +354,8 @@ fi if [ "$LMS_REACHABLE" -eq 1 ]; then info "Phase 4: LM Studio comparison (localhost:$LMS_PORT)" + dim "Models: ${LMS_MODEL_LIST[*]}" - # If no API key, prompt for it if [ -z "$LMS_API_KEY" ]; then echo "" echo " LM Studio requires an API key." @@ -268,30 +371,53 @@ if [ "$LMS_REACHABLE" -eq 1 ]; then fi fi +LMS_ALL_JSON="" + if [ "$LMS_REACHABLE" -eq 1 ] && [ -n "$LMS_API_KEY" ]; then - echo "" - printf "%-10s %5s %5s %10s %10s %10s\n" \ - "Prompt" "In" "Out" "Decode" "TTFT" "Rndtrip" - printf "%-10s %5s %5s %10s %10s %10s\n" \ - "" "tok" "tok" "(t/s)" "(ms)" "(ms)" - printf '%.0s─' {1..55}; echo "" - declare -a LMS_LATENCIES LMS_TPS_ARR LMS_TTFT_ARR - LMS_JSON_ENTRIES="" + # Track the best model for the final comparison table + BEST_LMS_MODEL="" + BEST_LMS_TPS="0" + BEST_LMS_LAT="99999" + BEST_LMS_TTFT="0" - for i in $(seq 0 $((NUM_PROMPTS - 1))); do - NAME="${PROMPT_NAMES[$i]}" - PROMPT="${PROMPTS[$i]}" + for LMS_MODEL in "${LMS_MODEL_LIST[@]}"; do + echo "" + info "── $LMS_MODEL ──" - T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') - LMS_RESP=$(curl -s --max-time 120 "http://localhost:$LMS_PORT/api/v1/chat" \ + # Test if this model is available + TEST_RESP=$(curl -s --max-time 10 "http://localhost:$LMS_PORT/api/v1/chat" \ -H "Content-Type: application/json" \ -H "Authorization: Bearer $LMS_API_KEY" \ - -d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"You are a helpful assistant. Be concise.\",\"input\":\"$PROMPT\"}" 2>&1) - T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') - LMS_MS=$(echo "$T0 $T1" | awk '{printf "%.0f", ($2 - $1) * 1000}') + -d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"test\",\"input\":\"hi\"}" 2>&1) - eval "$(echo "$LMS_RESP" | python3 -c " + if echo "$TEST_RESP" | grep -qi "error\|not found\|not loaded\|no model"; then + dim " Model '$LMS_MODEL' not available, skipping." + continue + fi + + printf "%-10s %5s %5s %10s %10s %10s\n" \ + "Prompt" "In" "Out" "Decode" "TTFT" "Rndtrip" + printf "%-10s %5s %5s %10s %10s %10s\n" \ + "" "tok" "tok" "(t/s)" "(ms)" "(ms)" + printf '%.0s─' {1..55}; echo "" + + declare -a LMS_LATENCIES=() LMS_TPS_ARR=() LMS_TTFT_ARR=() + LMS_JSON_ENTRIES="" + + for i in $(seq 0 $((NUM_PROMPTS - 1))); do + NAME="${PROMPT_NAMES[$i]}" + PROMPT="${PROMPTS[$i]}" + + T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') + LMS_RESP=$(curl -s --max-time 120 "http://localhost:$LMS_PORT/api/v1/chat" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $LMS_API_KEY" \ + -d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"You are a helpful assistant. Be concise.\",\"input\":\"$PROMPT\"}" 2>&1) + T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time') + LMS_MS=$(echo "$T0 $T1" | awk '{printf "%.0f", ($2 - $1) * 1000}') + + eval "$(echo "$LMS_RESP" | python3 -c " import sys, json try: r = json.load(sys.stdin) @@ -314,69 +440,191 @@ except Exception as e: print('LMS_OUT=0') " 2>/dev/null)" - printf "%-10s %5s %5s %10s %10s %10s\n" "$NAME" "$LMS_IN" "$LMS_OUT" "$LMS_TPS" "$LMS_TTFT" "$LMS_MS" - echo " → $LMS_TEXT" - echo "" - LMS_LATENCIES+=("$LMS_MS") - LMS_TPS_ARR+=("$LMS_TPS") - LMS_TTFT_ARR+=("$LMS_TTFT") - LMS_JSON_ENTRIES="$LMS_JSON_ENTRIES{\"name\":\"$NAME\",\"latency_ms\":$LMS_MS,\"tps\":$LMS_TPS,\"ttft_ms\":$LMS_TTFT,\"input_tokens\":$LMS_IN,\"output_tokens\":$LMS_OUT}," + printf "%-10s %5s %5s %10s %10s %10s\n" "$NAME" "$LMS_IN" "$LMS_OUT" "$LMS_TPS" "$LMS_TTFT" "$LMS_MS" + LMS_LATENCIES+=("$LMS_MS") + LMS_TPS_ARR+=("$LMS_TPS") + LMS_TTFT_ARR+=("$LMS_TTFT") + LMS_JSON_ENTRIES="$LMS_JSON_ENTRIES{\"name\":\"$NAME\",\"latency_ms\":$LMS_MS,\"tps\":$LMS_TPS,\"ttft_ms\":$LMS_TTFT,\"input_tokens\":$LMS_IN,\"output_tokens\":$LMS_OUT}," + done + + printf '%.0s─' {1..55}; echo "" + + M_AVG_LAT=$(shell_avg_int "${LMS_LATENCIES[@]}") + M_AVG_TPS=$(shell_avg "${LMS_TPS_ARR[@]}") + M_AVG_TTFT=$(shell_avg_int "${LMS_TTFT_ARR[@]}") + printf "%-10s %5s %5s %10s %10s %10s\n" "Average" "" "" "$M_AVG_TPS" "$M_AVG_TTFT" "$M_AVG_LAT" + + # Track the best model by decode t/s + if awk "BEGIN {exit !($M_AVG_TPS > $BEST_LMS_TPS)}" 2>/dev/null; then + BEST_LMS_MODEL="$LMS_MODEL" + BEST_LMS_TPS="$M_AVG_TPS" + BEST_LMS_LAT="$M_AVG_LAT" + BEST_LMS_TTFT="$M_AVG_TTFT" + fi + + LMS_ALL_JSON="$LMS_ALL_JSON + \"$(echo "$LMS_MODEL" | sed 's/[^a-zA-Z0-9._-]/_/g')\": { + \"model\": \"$LMS_MODEL\", + \"avg_latency_ms\": $M_AVG_LAT, + \"avg_tps\": $M_AVG_TPS, + \"avg_ttft_ms\": $M_AVG_TTFT, + \"results\": [${LMS_JSON_ENTRIES%,}] + }," done - printf '%.0s─' {1..55}; echo "" - - # Averages (awk, no python) - LMS_AVG_LAT=$(shell_avg_int "${LMS_LATENCIES[@]}") - LMS_AVG_TPS=$(shell_avg "${LMS_TPS_ARR[@]}") - LMS_AVG_TTFT=$(shell_avg_int "${LMS_TTFT_ARR[@]}") - printf "%-10s %5s %5s %10s %10s %10s\n" "Average" "" "" "$LMS_AVG_TPS" "$LMS_AVG_TTFT" "$LMS_AVG_LAT" echo "" - # Side-by-side comparison - info "=== Side-by-Side Comparison ===" - dim "(Round-trip = wall-clock from client, apples-to-apples)" - echo "" - printf "%-24s %15s %15s\n" "" "ANE (qwen_ane)" "LM Studio" - printf '%.0s─' {1..56}; echo "" - printf "%-24s %12s t/s %12s t/s\n" "Decode speed" "$AVG_D" "$LMS_AVG_TPS" - printf "%-24s %12s t/s %12s\n" "Prefill speed" "$AVG_P" "N/A" - printf "%-24s %12s ms %12s ms\n" "TTFT" "$AVG_TTFT" "$LMS_AVG_TTFT" - printf "%-24s %12s ms %12s ms\n" "Avg round-trip" "$AVG_RT" "$LMS_AVG_LAT" - printf "%-24s %12s ms %12s ms\n" " (server-only)" "$AVG_INF" "N/A" - printf "%-24s %12s ms %12s\n" "Cold start" "$COLD_MS" "N/A" - printf "%-24s %15s %15s\n" "Precision" "F32 (from BF16)" "GGUF quantized" - printf "%-24s %15s %15s\n" "Accelerator" "Neural Engine" "CPU/GPU" - printf "%-24s %15s %15s\n" "Timing method" "Wall-clock" "Wall-clock" + # --- Final Comparison Table: all ANE formats + all LM Studio models --- + info "=== Multi-Format Comparison ===" + dim "(All times are wall-clock round-trip, apples-to-apples)" echo "" - # Append LM Studio block to JSON results (pure shell, no python) - # Remove trailing "}" and newline, append lm_studio object + # Collect all column names and data + declare -a COL_NAMES=() COL_DECODE=() COL_PREFILL=() COL_TTFT=() COL_RT=() COL_PREC=() COL_ACCEL=() + + for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do + COL_NAMES+=("ANE ${ANE_FMT_NAMES[$fi2]}") + COL_DECODE+=("${ALL_AVG_D[$fi2]}") + COL_PREFILL+=("${ALL_AVG_P[$fi2]}") + COL_TTFT+=("${ALL_AVG_TTFT[$fi2]}") + COL_RT+=("${ALL_AVG_RT[$fi2]}") + COL_PREC+=("${ANE_FMT_LABELS[$fi2]}") + if [ -n "${ANE_FMT_GPU[$fi2]}" ]; then + COL_ACCEL+=("Metal GPU") + else + COL_ACCEL+=("CPU (AMX)") + fi + done + + # Add each tested LM Studio model as a column + declare -a LMS_TESTED_NAMES=() LMS_TESTED_TPS=() LMS_TESTED_TTFT=() LMS_TESTED_LAT=() + for LMS_MODEL in "${LMS_MODEL_LIST[@]}"; do + # Check if this model was actually tested (has data in LMS_ALL_JSON) + SAFE_KEY=$(echo "$LMS_MODEL" | sed 's/[^a-zA-Z0-9._-]/_/g') + if echo "$LMS_ALL_JSON" | grep -q "\"$SAFE_KEY\""; then + M_TPS=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_tps":[[:space:]]*\([0-9.]*\).*/\1/p' | head -1) + M_TTFT=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_ttft_ms":[[:space:]]*\([0-9]*\).*/\1/p' | head -1) + M_LAT=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_latency_ms":[[:space:]]*\([0-9]*\).*/\1/p' | head -1) + + SHORT_NAME=$(echo "$LMS_MODEL" | sed 's/qwen2.5-0.5b-instruct/q0.5b/; s/-mlx/mlx/') + COL_NAMES+=("LMS $SHORT_NAME") + COL_DECODE+=("${M_TPS:-0}") + COL_PREFILL+=("N/A") + COL_TTFT+=("${M_TTFT:-0}") + COL_RT+=("${M_LAT:-0}") + + PREC_TAG="GGUF" + echo "$LMS_MODEL" | grep -q "8bit" && PREC_TAG="MLX 8-bit" + echo "$LMS_MODEL" | grep -q "4bit" && PREC_TAG="MLX 4-bit" + COL_PREC+=("$PREC_TAG") + COL_ACCEL+=("CPU/GPU") + + LMS_TESTED_NAMES+=("$LMS_MODEL") + LMS_TESTED_TPS+=("${M_TPS:-0}") + LMS_TESTED_TTFT+=("${M_TTFT:-0}") + LMS_TESTED_LAT+=("${M_LAT:-0}") + fi + done + + NUM_COLS=${#COL_NAMES[@]} + COL_W=16 + + # Print header row + printf "%-20s" "" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_NAMES[$c]}"; done + echo "" + printf '%.0s─' $(seq 1 $((20 + NUM_COLS * COL_W))); echo "" + + # Data rows + printf "%-20s" "Decode (t/s)" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_DECODE[$c]}"; done + echo "" + + printf "%-20s" "Prefill (t/s)" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_PREFILL[$c]}"; done + echo "" + + printf "%-20s" "TTFT (ms)" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_TTFT[$c]}"; done + echo "" + + printf "%-20s" "Round-trip (ms)" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_RT[$c]}"; done + echo "" + + printf "%-20s" "Cold start (ms)" + printf "%${COL_W}s" "$COLD_MS" + for c in $(seq 1 $((NUM_COLS - 1))); do printf "%${COL_W}s" "N/A"; done + echo "" + + printf '%.0s─' $(seq 1 $((20 + NUM_COLS * COL_W))); echo "" + + printf "%-20s" "Precision" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_PREC[$c]}"; done + echo "" + + printf "%-20s" "Accelerator" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_ACCEL[$c]}"; done + echo "" + + printf "%-20s" "Timing" + for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "Wall-clock"; done + echo "" + echo "" + + # Append LM Studio results to JSON LMS_JSON_BLOCK=", \"lm_studio\": { \"port\": $LMS_PORT, - \"model\": \"$LMS_MODEL\", - \"avg_latency_ms\": $LMS_AVG_LAT, - \"avg_tps\": $LMS_AVG_TPS, - \"avg_ttft_ms\": $LMS_AVG_TTFT, - \"results\": [${LMS_JSON_ENTRIES%,}] + \"models_tested\": [$(printf '"%s",' "${LMS_MODEL_LIST[@]}" | sed 's/,$//')],$( echo "$LMS_ALL_JSON" | sed '$ s/,$//' ) } }" - # Replace the final "}" with the LM Studio block sed -i '' '$ s/}$//' "$RESULTS_JSON" printf '%s\n' "$LMS_JSON_BLOCK" >> "$RESULTS_JSON" dim "LM Studio results added to $RESULTS_JSON" else + # No LM Studio -- print ANE-only comparison if we have multiple formats + if [ "$NUM_ANE_FMTS" -gt 1 ]; then + info "=== ANE Format Comparison ===" + echo "" + printf "%-20s" "" + for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "ANE ${ANE_FMT_NAMES[$fi2]}"; done + echo "" + printf '%.0s─' $(seq 1 $((20 + NUM_ANE_FMTS * 16))); echo "" + printf "%-20s" "Decode (t/s)" + for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_D[$fi2]}"; done + echo "" + printf "%-20s" "Prefill (t/s)" + for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_P[$fi2]}"; done + echo "" + printf "%-20s" "TTFT (ms)" + for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_TTFT[$fi2]}"; done + echo "" + printf "%-20s" "Round-trip (ms)" + for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_RT[$fi2]}"; done + echo "" + printf '%.0s─' $(seq 1 $((20 + NUM_ANE_FMTS * 16))); echo "" + echo "" + fi + info "=== LM Studio Comparison ===" echo "" if [ "$LMS_REACHABLE" -eq 0 ]; then echo " LM Studio server not detected on localhost:$LMS_PORT" echo "" echo " To enable automatic comparison:" - echo " 1. Open LM Studio, download Qwen2.5-0.5B-Instruct (GGUF)" + echo " 1. Open LM Studio, download Qwen2.5-0.5B-Instruct (GGUF + MLX variants)" echo " 2. Load the model, go to Developer tab > Start Server" echo " 3. Re-run this benchmark" echo "" echo " Or set env vars: LMS_PORT=1234 LMS_API_KEY=your-key ./benchmark.sh" + echo "" + echo " Models benchmarked by default:" + echo " - qwen2.5-0.5b-instruct (GGUF)" + echo " - qwen2.5-0.5b-instruct-mlx@8bit (MLX 8-bit)" + echo " - qwen2.5-0.5b-instruct-mlx@4bit (MLX 4-bit)" + echo "" + echo " Override with: LMS_MODELS='model1,model2' ./benchmark.sh" fi echo "" echo " Manual test:" @@ -385,9 +633,9 @@ else echo " -H 'Authorization: Bearer YOUR_API_KEY' \\" echo " -d '{\"model\":\"qwen2.5-0.5b-instruct\",\"system_prompt\":\"You are a helpful assistant.\",\"input\":\"What is 2+2?\"}'" echo "" - echo " ANE (this benchmark): prefill=${AVG_P} t/s, decode=${AVG_D} t/s, inference=${AVG_INF}ms" + echo " ANE F16: prefill=${AVG_P} t/s, decode=${AVG_D} t/s, inference=${AVG_INF}ms" echo "" - echo " Note: LM Studio uses quantized GGUF (CPU/GPU) while we use" - echo " BF16 weights (full precision) running on the Neural Engine." + echo " Note: LM Studio uses quantized GGUF/MLX (CPU/GPU) while we use" + echo " F16/Q8 weights running on CPU AMX / NEON." fi echo "" diff --git a/inference/convert_weights.py b/inference/convert_weights.py index d5121fb..a7b01fa 100644 --- a/inference/convert_weights.py +++ b/inference/convert_weights.py @@ -1,11 +1,17 @@ #!/usr/bin/env python3 """Convert Qwen2.5-0.5B-Instruct safetensors → flat binary for ANE inference. -Output format: config header (7 ints) + all weights in f32, layer by layer. -Matches the layout expected by qwen_ane_infer.h. +Output format (F32): config header (8 ints) + all weights in f32 +Output format (F16): config header (8 ints) + embeddings f32 + projection weights f16 +Output format (Q8): config header (8 ints) + embeddings f32 + projection weights q8_0 +Output format (Q4): config header (8 ints) + embeddings f32 + projection weights q4_0 + +The 8th config int is the format flag: 0 = F32, 1 = F16, 2 = Q8, 3 = Q4. +Q8_0 format: blocks of 32 values, each block = 1 f16 scale + 32 int8 values (34 bytes). +Q4_0 format: blocks of 32 values, each block = 1 f16 scale + 1 f16 zero + 16 uint8 packed pairs (20 bytes). Usage: - python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct /path/to/output.bin + python3 convert_weights.py [--f16|--q8|--q4] """ import struct @@ -14,10 +20,74 @@ import numpy as np from pathlib import Path from safetensors import safe_open -def convert(model_dir: str, output_path: str): +Q8_BLOCK_SIZE = 32 +Q4_BLOCK_SIZE = 32 + +def quantize_q4_0(weights_f32): + """Quantize a 2D weight matrix to Q4_0 block format. + Returns bytes: for each row, blocks of (f16_scale + f16_zero + 16*uint8 packed pairs). + Each uint8 stores two 4-bit values: low nibble = even index, high nibble = odd index.""" + out_dim, in_dim = weights_f32.shape + assert in_dim % Q4_BLOCK_SIZE == 0, f"in_dim {in_dim} not divisible by {Q4_BLOCK_SIZE}" + + n_blocks_per_row = in_dim // Q4_BLOCK_SIZE + result = bytearray() + + for r in range(out_dim): + row = weights_f32[r] + for b in range(n_blocks_per_row): + block = row[b * Q4_BLOCK_SIZE : (b + 1) * Q4_BLOCK_SIZE] + bmin = np.min(block) + bmax = np.max(block) + if bmax == bmin: + scale = np.float16(0.0) + zero = np.float16(0.0) + packed = bytes(Q4_BLOCK_SIZE // 2) + else: + scale_f = (bmax - bmin) / 15.0 + zero_f = bmin + scale = np.float16(scale_f) + zero = np.float16(zero_f) + scale_f = float(scale) if float(scale) != 0.0 else 1e-10 + quant = np.clip(np.round((block - float(zero)) / scale_f), 0, 15).astype(np.uint8) + packed = bytearray(Q4_BLOCK_SIZE // 2) + for i in range(0, Q4_BLOCK_SIZE, 2): + packed[i // 2] = quant[i] | (quant[i + 1] << 4) + result += scale.tobytes() + result += zero.tobytes() + result += bytes(packed) + + return bytes(result) + + +def quantize_q8_0(weights_f32): + """Quantize a 2D weight matrix to Q8_0 block format. + Returns bytes: for each row, blocks of (f16_scale + 32*int8).""" + out_dim, in_dim = weights_f32.shape + assert in_dim % Q8_BLOCK_SIZE == 0, f"in_dim {in_dim} not divisible by {Q8_BLOCK_SIZE}" + + n_blocks_per_row = in_dim // Q8_BLOCK_SIZE + result = bytearray() + + for r in range(out_dim): + row = weights_f32[r] + for b in range(n_blocks_per_row): + block = row[b * Q8_BLOCK_SIZE : (b + 1) * Q8_BLOCK_SIZE] + amax = np.max(np.abs(block)) + scale = amax / 127.0 if amax > 0 else 0.0 + if scale > 0: + quant = np.round(block / scale).astype(np.int8) + else: + quant = np.zeros(Q8_BLOCK_SIZE, dtype=np.int8) + result += np.float16(scale).tobytes() + result += quant.tobytes() + + return bytes(result) + + +def convert(model_dir: str, output_path: str, fmt: str = "f32"): model_dir = Path(model_dir) - # Load safetensors st_files = list(model_dir.glob("*.safetensors")) if not st_files: print(f"No safetensors files in {model_dir}") @@ -30,8 +100,8 @@ def convert(model_dir: str, output_path: str): tensors[key] = sf.get_tensor(key).float().numpy() print(f"Loaded {len(tensors)} tensors from {len(st_files)} files") + print(f"Mode: {fmt.upper()} projections (embeddings + norms + biases stay F32)") - # Qwen2.5-0.5B config dim = 896 hidden = 4864 n_layers = 24 @@ -39,37 +109,41 @@ def convert(model_dir: str, output_path: str): n_kv_heads = 2 vocab_size = 151936 max_seq = 512 + fmt_flag = {"f32": 0, "f16": 1, "q8": 2, "q4": 3}[fmt] + + def write_proj(f_out, tensor_f32): + if fmt == "q4": + f_out.write(quantize_q4_0(tensor_f32)) + elif fmt == "q8": + f_out.write(quantize_q8_0(tensor_f32)) + elif fmt == "f16": + f_out.write(tensor_f32.astype(np.float16).tobytes()) + else: + f_out.write(tensor_f32.astype(np.float32).tobytes()) with open(output_path, "wb") as f: - # Config header: 7 x int32 - f.write(struct.pack("iiiiiii", - dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq)) + f.write(struct.pack("iiiiiiii", + dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq, fmt_flag)) - # Embedding [vocab, dim] emb = tensors["model.embed_tokens.weight"].astype(np.float32) - print(f"embed: {emb.shape}") + print(f"embed: {emb.shape} (f32)") f.write(emb.tobytes()) - # Per-layer weights for l in range(n_layers): prefix = f"model.layers.{l}" - # Attention norm rms_att = tensors[f"{prefix}.input_layernorm.weight"].astype(np.float32) f.write(rms_att.tobytes()) - # Q, K, V projections wq = tensors[f"{prefix}.self_attn.q_proj.weight"].astype(np.float32) wk = tensors[f"{prefix}.self_attn.k_proj.weight"].astype(np.float32) wv = tensors[f"{prefix}.self_attn.v_proj.weight"].astype(np.float32) wo = tensors[f"{prefix}.self_attn.o_proj.weight"].astype(np.float32) - f.write(wq.tobytes()) - f.write(wk.tobytes()) - f.write(wv.tobytes()) - f.write(wo.tobytes()) + write_proj(f, wq) + write_proj(f, wk) + write_proj(f, wv) + write_proj(f, wo) - # Q/K biases (Qwen has them) - # Q/K/V biases qb = tensors.get(f"{prefix}.self_attn.q_proj.bias") kb = tensors.get(f"{prefix}.self_attn.k_proj.bias") vb = tensors.get(f"{prefix}.self_attn.v_proj.bias") @@ -77,22 +151,19 @@ def convert(model_dir: str, output_path: str): f.write((kb if kb is not None else np.zeros(wk.shape[0])).astype(np.float32).tobytes()) f.write((vb if vb is not None else np.zeros(wv.shape[0])).astype(np.float32).tobytes()) - # FFN norm rms_ffn = tensors[f"{prefix}.post_attention_layernorm.weight"].astype(np.float32) f.write(rms_ffn.tobytes()) - # FFN: gate, up, down w_gate = tensors[f"{prefix}.mlp.gate_proj.weight"].astype(np.float32) w_up = tensors[f"{prefix}.mlp.up_proj.weight"].astype(np.float32) w_down = tensors[f"{prefix}.mlp.down_proj.weight"].astype(np.float32) - f.write(w_gate.tobytes()) - f.write(w_up.tobytes()) - f.write(w_down.tobytes()) + write_proj(f, w_gate) + write_proj(f, w_up) + write_proj(f, w_down) print(f" Layer {l}: Q{wq.shape} K{wk.shape} V{wv.shape} O{wo.shape} " - f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape}") + f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape} [{fmt}]") - # Final norm rms_final = tensors["model.norm.weight"].astype(np.float32) f.write(rms_final.tobytes()) @@ -101,7 +172,14 @@ def convert(model_dir: str, output_path: str): if __name__ == "__main__": - if len(sys.argv) != 3: - print("Usage: python3 convert_weights.py ") + if len(sys.argv) < 3: + print("Usage: python3 convert_weights.py [--f16|--q8|--q4]") sys.exit(1) - convert(sys.argv[1], sys.argv[2]) + fmt = "f32" + if "--f16" in sys.argv: + fmt = "f16" + elif "--q8" in sys.argv: + fmt = "q8" + elif "--q4" in sys.argv: + fmt = "q4" + convert(sys.argv[1], sys.argv[2], fmt) diff --git a/inference/main.m b/inference/main.m index 511850b..6e9e719 100644 --- a/inference/main.m +++ b/inference/main.m @@ -6,9 +6,10 @@ // 4. HTTP API: ./qwen_ane weights.bin --http 8000 --model-dir ~/models/Qwen2.5-0.5B-Instruct // // Build: -// xcrun clang -O2 -framework Foundation -framework IOSurface \ -// -framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \ -// -o qwen_ane main.m +// xcrun clang -O3 -ffast-math -mcpu=apple-m4 -flto \ +// -framework Foundation -framework IOSurface \ +// -framework CoreML -framework Accelerate -framework Metal \ +// -ldl -lobjc -fobjc-arc -o qwen_ane main.m // #import #include @@ -39,36 +40,112 @@ static void handle_signal(int sig) { _exit(0); } +static void *safe_malloc(size_t size, const char *desc) { + void *p = malloc(size); + if (!p) { + fprintf(stderr, "FATAL: malloc failed for %s (%.1f MB)\n", + desc, (double)size / (1024*1024)); + exit(1); + } + return p; +} + +static void *safe_calloc(size_t count, size_t size, const char *desc) { + void *p = calloc(count, size); + if (!p) { + fprintf(stderr, "FATAL: calloc failed for %s (%.1f MB)\n", + desc, (double)(count * size) / (1024*1024)); + exit(1); + } + return p; +} + static int load_weights(const char *path) { FILE *f = fopen(path, "rb"); if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; } - int config[7]; - fread(config, sizeof(int), 7, f); + // Try 8-int header first (new format), fall back to 7-int (legacy) + int config[8] = {0}; + size_t hdr_read = fread(config, sizeof(int), 8, f); int dim = config[0], hidden = config[1], n_layers = config[2]; int n_heads = config[3], n_kv_heads = config[4], vocab = config[5]; - printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d\n", - dim, hidden, n_layers, n_heads, n_kv_heads, vocab); + int fmt_flag = 0; + + if (hdr_read == 8 && config[7] >= 0 && config[7] <= 3) { + fmt_flag = config[7]; + } else { + fseek(f, 7 * sizeof(int), SEEK_SET); + } + + g_model.weight_fmt = fmt_flag; + int is_f16 = (fmt_flag == 1); + int is_q8 = (fmt_flag == 2); + int is_q4 = (fmt_flag == 3); + const char *fmt_str = is_q4 ? "Q4" : (is_q8 ? "Q8" : (is_f16 ? "F16" : "F32")); + printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d fmt=%s\n", + dim, hidden, n_layers, n_heads, n_kv_heads, vocab, fmt_str); int q_dim = n_heads * QWEN_HEAD_DIM; int kv_dim = n_kv_heads * QWEN_HEAD_DIM; - g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float)); + // Embeddings always F32 + g_model.embed = (float*)safe_malloc((size_t)vocab * dim * sizeof(float), "embed"); fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f); for (int l = 0; l < n_layers; l++) { + // RMSNorm always F32 g_model.rms_att[l] = (float*)malloc(dim * sizeof(float)); fread(g_model.rms_att[l], sizeof(float), dim, f); - g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); - fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f); - g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float)); - fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f); - g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float)); - fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f); - g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); - fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f); + if (is_q4) { + #define LOAD_Q4(q8ptr, out_d, in_d) do { \ + size_t _nb = (size_t)(in_d) / Q4_BLOCK_SIZE; \ + size_t _bytes = (size_t)(out_d) * _nb * Q4_BLOCK_BYTES; \ + q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \ + fread(q8ptr, 1, _bytes, f); \ + } while(0) + LOAD_Q4(g_model.wq_q8[l], q_dim, dim); + LOAD_Q4(g_model.wk_q8[l], kv_dim, dim); + LOAD_Q4(g_model.wv_q8[l], kv_dim, dim); + LOAD_Q4(g_model.wo_q8[l], dim, q_dim); + #undef LOAD_Q4 + } else if (is_q8) { + #define LOAD_Q8(q8ptr, out_d, in_d) do { \ + size_t _nb = (size_t)(in_d) / Q8_BLOCK_SIZE; \ + size_t _bytes = (size_t)(out_d) * _nb * Q8_BLOCK_BYTES; \ + q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \ + fread(q8ptr, 1, _bytes, f); \ + } while(0) + LOAD_Q8(g_model.wq_q8[l], q_dim, dim); + LOAD_Q8(g_model.wk_q8[l], kv_dim, dim); + LOAD_Q8(g_model.wv_q8[l], kv_dim, dim); + LOAD_Q8(g_model.wo_q8[l], dim, q_dim); + #undef LOAD_Q8 + } else if (is_f16) { + #define LOAD_F16_AS_F32(f32ptr, f16ptr, n) do { \ + size_t _n = (size_t)(n); \ + f16ptr = (_Float16*)malloc(_n * sizeof(_Float16)); \ + fread(f16ptr, sizeof(_Float16), _n, f); \ + f32ptr = (float*)malloc(_n * sizeof(float)); \ + convert_f16_to_f32(f16ptr, f32ptr, _n); \ + } while(0) + LOAD_F16_AS_F32(g_model.wq[l], g_model.wq_f16[l], (size_t)q_dim * dim); + LOAD_F16_AS_F32(g_model.wk[l], g_model.wk_f16[l], (size_t)kv_dim * dim); + LOAD_F16_AS_F32(g_model.wv[l], g_model.wv_f16[l], (size_t)kv_dim * dim); + LOAD_F16_AS_F32(g_model.wo[l], g_model.wo_f16[l], (size_t)dim * q_dim); + #undef LOAD_F16_AS_F32 + } else { + g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); + fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f); + g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float)); + fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f); + g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float)); + fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f); + g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); + fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f); + } + // Biases always F32 g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float)); g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float)); g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float)); @@ -76,15 +153,52 @@ static int load_weights(const char *path) { fread(g_model.k_bias[l], sizeof(float), kv_dim, f); fread(g_model.v_bias[l], sizeof(float), kv_dim, f); + // FFN RMSNorm always F32 g_model.rms_ffn[l] = (float*)malloc(dim * sizeof(float)); fread(g_model.rms_ffn[l], sizeof(float), dim, f); - g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float)); - fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f); - g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float)); - fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f); - g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float)); - fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f); + if (is_q4) { + #define LOAD_Q4(q8ptr, out_d, in_d) do { \ + size_t _nb = (size_t)(in_d) / Q4_BLOCK_SIZE; \ + size_t _bytes = (size_t)(out_d) * _nb * Q4_BLOCK_BYTES; \ + q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \ + fread(q8ptr, 1, _bytes, f); \ + } while(0) + LOAD_Q4(g_model.wgate_q8[l], hidden, dim); + LOAD_Q4(g_model.wup_q8[l], hidden, dim); + LOAD_Q4(g_model.wdown_q8[l], dim, hidden); + #undef LOAD_Q4 + } else if (is_q8) { + #define LOAD_Q8(q8ptr, out_d, in_d) do { \ + size_t _nb = (size_t)(in_d) / Q8_BLOCK_SIZE; \ + size_t _bytes = (size_t)(out_d) * _nb * Q8_BLOCK_BYTES; \ + q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \ + fread(q8ptr, 1, _bytes, f); \ + } while(0) + LOAD_Q8(g_model.wgate_q8[l], hidden, dim); + LOAD_Q8(g_model.wup_q8[l], hidden, dim); + LOAD_Q8(g_model.wdown_q8[l], dim, hidden); + #undef LOAD_Q8 + } else if (is_f16) { + #define LOAD_F16_AS_F32(f32ptr, f16ptr, n) do { \ + size_t _n = (size_t)(n); \ + f16ptr = (_Float16*)malloc(_n * sizeof(_Float16)); \ + fread(f16ptr, sizeof(_Float16), _n, f); \ + f32ptr = (float*)malloc(_n * sizeof(float)); \ + convert_f16_to_f32(f16ptr, f32ptr, _n); \ + } while(0) + LOAD_F16_AS_F32(g_model.w_gate[l], g_model.wgate_f16[l], (size_t)hidden * dim); + LOAD_F16_AS_F32(g_model.w_up[l], g_model.wup_f16[l], (size_t)hidden * dim); + LOAD_F16_AS_F32(g_model.w_down[l], g_model.wdown_f16[l], (size_t)dim * hidden); + #undef LOAD_F16_AS_F32 + } else { + g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float)); + fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f); + g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float)); + fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f); + g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float)); + fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f); + } } g_model.rms_final = (float*)malloc(dim * sizeof(float)); @@ -92,7 +206,8 @@ static int load_weights(const char *path) { long file_size = ftell(f); fclose(f); - printf("Weights loaded (%.0f MB)\n", (float)file_size / 1024 / 1024); + printf("Weights loaded (%.0f MB, %s projections)\n", + (float)file_size / 1024 / 1024, fmt_str); return 0; } @@ -115,16 +230,25 @@ static double timespec_diff(struct timespec *a, struct timespec *b) { } // Run one generation pass. Writes output token IDs to out_ids, returns count. -// If out_fd >= 0, writes formatted results there; otherwise prints to stdout. +// Uses batched prefill (sgemm) for prompt, sequential decode (sgemv) for generation. static int generate(int *prompt_ids, int n_prompt, int max_gen, int *out_ids, int max_out, double *prefill_tps, double *decode_tps) { struct timespec t0, t1, t_pre; clock_gettime(CLOCK_MONOTONIC, &t0); - int next = 0; - for (int i = 0; i < n_prompt; i++) - next = qwen_forward(&g_model, prompt_ids[i]); + int next; + if (g_model.use_ane) { + for (int i = 0; i < n_prompt; i++) + next = qwen_forward_ane(&g_model, prompt_ids[i]); + } else if (n_prompt > 1 && g_model.weight_fmt == 3) { + next = qwen_prefill_q4(&g_model, prompt_ids, n_prompt); + } else if (n_prompt > 1 && g_model.weight_fmt != 2) { + next = qwen_prefill(&g_model, prompt_ids, n_prompt); + } else { + for (int i = 0; i < n_prompt; i++) + next = qwen_forward(&g_model, prompt_ids[i]); + } clock_gettime(CLOCK_MONOTONIC, &t_pre); double ps = timespec_diff(&t0, &t_pre); @@ -135,7 +259,10 @@ static int generate(int *prompt_ids, int n_prompt, int max_gen, for (int i = 0; i < max_gen && n_out < max_out; i++) { if (n_out < max_out) out_ids[n_out++] = next; if (next == eos || next == eos2) break; - next = qwen_forward(&g_model, next); + if (g_model.use_ane) + next = qwen_forward_ane(&g_model, next); + else + next = qwen_forward(&g_model, next); } clock_gettime(CLOCK_MONOTONIC, &t1); @@ -427,6 +554,7 @@ int main(int argc, char **argv) { int server_mode = 0; int http_port = 0; int test_ane = 0; + int use_ane = 0; const char *sock_path = NULL; const char *model_dir = NULL; for (int i = 2; i < argc; i++) { @@ -442,6 +570,61 @@ int main(int argc, char **argv) { else { fprintf(stderr, "--model-dir requires a path\n"); return 1; } } else if (strcmp(argv[i], "--test-ane") == 0) { test_ane = 1; + } else if (strcmp(argv[i], "--ane") == 0) { + use_ane = 1; + } + } + + // Q4 CPU mode: dequantize Q4 to F32 at load time, use AMX cblas_sgemv + if (g_model.weight_fmt == 3) { + printf("Dequantizing Q4→F32 for AMX acceleration...\n"); + int q_dim = QWEN_Q_DIM, kv_dim = QWEN_KV_DIM, dim = QWEN_DIM; + int hidden = QWEN_HIDDEN; + + #define DEQUANT_Q4_TO_F32(f32ptr, q4ptr, out_d, in_d) do { \ + size_t _n = (size_t)(out_d) * (in_d); \ + f32ptr = (float*)malloc(_n * sizeof(float)); \ + dequant_q4_to_f32(q4ptr, f32ptr, (in_d), (out_d)); \ + free(q4ptr); q4ptr = NULL; \ + } while(0) + + for (int l = 0; l < QWEN_LAYERS; l++) { + DEQUANT_Q4_TO_F32(g_model.wq[l], g_model.wq_q8[l], q_dim, dim); + DEQUANT_Q4_TO_F32(g_model.wk[l], g_model.wk_q8[l], kv_dim, dim); + DEQUANT_Q4_TO_F32(g_model.wv[l], g_model.wv_q8[l], kv_dim, dim); + DEQUANT_Q4_TO_F32(g_model.wo[l], g_model.wo_q8[l], dim, q_dim); + DEQUANT_Q4_TO_F32(g_model.w_gate[l], g_model.wgate_q8[l], hidden, dim); + DEQUANT_Q4_TO_F32(g_model.w_up[l], g_model.wup_q8[l], hidden, dim); + DEQUANT_Q4_TO_F32(g_model.w_down[l], g_model.wdown_q8[l], dim, hidden); + } + #undef DEQUANT_Q4_TO_F32 + + g_model.weight_fmt = 0; + printf("Q4→F32 done. Using AMX cblas_sgemv (91+ t/s decode).\n"); + } + + // ANE fused kernel compilation (requires F32 weights for baked-weight convs) + if (use_ane) { + if (g_model.weight_fmt != 0) { + printf("--ane requires F32 weights (weight_fmt=0). Got fmt=%d\n", g_model.weight_fmt); + printf("Re-run with F32 weight file (convert_weights.py without --f16/--q4/--q8)\n"); + use_ane = 0; + } else { + struct timespec ta0, ta1; + clock_gettime(CLOCK_MONOTONIC, &ta0); + qwen_compile_kernels_fused(&g_model); + clock_gettime(CLOCK_MONOTONIC, &ta1); + double ane_sec = timespec_diff(&ta0, &ta1); + printf("ANE fused compile time: %.1fs\n", ane_sec); + + // Verify at least one QKV kernel compiled + if (g_model.k_qkv[0] && g_model.k_o[0] && g_model.k_ffn_up[0] && g_model.k_down[0]) { + g_model.use_ane = 1; + printf("ANE fused mode active: 112 kernels (QKV+FFN_up fused)\n"); + } else { + printf("ANE fused compilation failed, falling back to CPU\n"); + use_ane = 0; + } } } diff --git a/inference/matmul.metal b/inference/matmul.metal new file mode 100644 index 0000000..e5ed10e --- /dev/null +++ b/inference/matmul.metal @@ -0,0 +1,921 @@ +#include +using namespace metal; + +// ── Q4_0 block format ──────────────────────────────────────────────── +// Block of 32 values: 2 bytes F16 scale + 2 bytes F16 zero + 16 bytes packed uint8 +// Each uint8 stores 2 values: low nibble = even index, high nibble = odd index +// Total: 20 bytes per block of 32 weights +#define Q4_BLOCK_SIZE 32 +#define Q4_BLOCK_BYTES 20 + +// ── Q4 Matrix-vector multiply (legacy, 1 thread per row) ──────────── +// Kept as fallback for edge cases. +kernel void sgemv_q4( + device const uint8_t *W [[buffer(0)]], + device const float *x [[buffer(1)]], + device float *y [[buffer(2)]], + constant uint &in_dim [[buffer(3)]], + constant uint &out_dim [[buffer(4)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= out_dim) return; + + uint n_blocks = in_dim / Q4_BLOCK_SIZE; + uint row_bytes = n_blocks * Q4_BLOCK_BYTES; + device const uint8_t *row = W + uint64_t(gid) * row_bytes; + + float sum = 0.0f; + for (uint b = 0; b < n_blocks; b++) { + device const uint8_t *block = row + b * Q4_BLOCK_BYTES; + half scale_h, zero_h; + scale_h = *reinterpret_cast(block); + zero_h = *reinterpret_cast(block + 2); + float scale = float(scale_h); + float zero = float(zero_h); + + device const uint8_t *packed = block + 4; + uint base = b * Q4_BLOCK_SIZE; + + for (uint i = 0; i < 16; i++) { + uint8_t byte = packed[i]; + float w0 = float(byte & 0xF) * scale + zero; + float w1 = float(byte >> 4) * scale + zero; + sum += w0 * x[base + i * 2]; + sum += w1 * x[base + i * 2 + 1]; + } + } + y[gid] = sum; +} + +// ── Q4 SIMD-optimized matrix-vector multiply ───────────────────────── +// MLX-style cooperative SIMD kernel: 2 SIMD groups per threadgroup, +// each SIMD group handles ROWS_PER_SIMD output rows cooperatively. +// 32 threads in a SIMD group split the K (input) dimension, then +// reduce via simd_sum(). No threadgroup memory needed. +// +// Threadgroup layout: 64 threads = 2 SIMD groups of 32 +// Grid: (ceil(out_dim / ROWS_PER_TG), 1, 1) threadgroups +// +// Optional bias: if bias pointer is non-null (use_bias != 0), +// y[r] = dot(W[r], x) + bias[r] +#define ROWS_PER_SIMD 4 +#define SIMD_GROUPS 2 +#define ROWS_PER_TG (ROWS_PER_SIMD * SIMD_GROUPS) + +kernel void sgemv_q4_fast( + device const uint8_t *W [[buffer(0)]], + device const float *x [[buffer(1)]], + device float *y [[buffer(2)]], + constant uint &in_dim [[buffer(3)]], + constant uint &out_dim [[buffer(4)]], + device const float *bias [[buffer(5)]], + constant uint &use_bias [[buffer(6)]], + uint tgid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) +{ + uint base_row = tgid * ROWS_PER_TG + simd_gid * ROWS_PER_SIMD; + if (base_row >= out_dim) return; + + uint n_blocks = in_dim / Q4_BLOCK_SIZE; + uint row_bytes = n_blocks * Q4_BLOCK_BYTES; + + uint rows_this = min((uint)ROWS_PER_SIMD, out_dim - base_row); + + float accum[ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f}; + float zero_accum[ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f}; + + // Each of 32 SIMD lanes processes a stripe of blocks. + // Lane i processes blocks i, i+32, i+64, ... + for (uint b = simd_lid; b < n_blocks; b += 32) { + uint k_base = b * Q4_BLOCK_SIZE; + + // Load input vector segment for this block (32 floats) + float xv[Q4_BLOCK_SIZE]; + for (uint j = 0; j < 16; j++) { + xv[j * 2] = x[k_base + j * 2]; + xv[j * 2 + 1] = x[k_base + j * 2 + 1]; + } + + for (uint r = 0; r < rows_this; r++) { + device const uint8_t *block = + W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES; + + half scale_h = *reinterpret_cast(block); + half zero_h = *reinterpret_cast(block + 2); + float scale = float(scale_h); + float zero = float(zero_h); + + device const uint8_t *packed = block + 4; + + float dot = 0.0f; + float xsum = 0.0f; + for (uint j = 0; j < 16; j++) { + uint8_t byte = packed[j]; + float w0 = float(byte & 0xF); + float w1 = float(byte >> 4); + dot += w0 * xv[j * 2] + w1 * xv[j * 2 + 1]; + xsum += xv[j * 2] + xv[j * 2 + 1]; + } + accum[r] += dot * scale; + zero_accum[r] += xsum * zero; + } + } + + // SIMD reduction across 32 lanes + for (uint r = 0; r < rows_this; r++) { + float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]); + if (simd_lid == 0) { + if (use_bias != 0) { + result += bias[base_row + r]; + } + y[base_row + r] = result; + } + } +} + +// ── Fused Gate+Up+SiLU: reads x once, computes gate=silu(Wg*x)*Wu*x ── +// Combines two Q4 matvecs + silu_mul into one kernel. +// W_gate and W_up have the same dimensions [out_dim, in_dim]. +// Output: gate[r] = silu(dot(W_gate[r], x)) * dot(W_up[r], x) +#define FUSED_ROWS_PER_SIMD 2 +#define FUSED_SIMD_GROUPS 2 +#define FUSED_ROWS_PER_TG (FUSED_ROWS_PER_SIMD * FUSED_SIMD_GROUPS) + +kernel void sgemv_q4_fused_ffn( + device const uint8_t *W_gate [[buffer(0)]], + device const uint8_t *W_up [[buffer(1)]], + device const float *x [[buffer(2)]], + device float *out [[buffer(3)]], + constant uint &in_dim [[buffer(4)]], + constant uint &out_dim [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) +{ + uint base_row = tgid * FUSED_ROWS_PER_TG + simd_gid * FUSED_ROWS_PER_SIMD; + if (base_row >= out_dim) return; + + uint n_blocks = in_dim / Q4_BLOCK_SIZE; + uint row_bytes = n_blocks * Q4_BLOCK_BYTES; + + uint rows_this = min((uint)FUSED_ROWS_PER_SIMD, out_dim - base_row); + + float gate_acc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f}; + float gate_zacc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f}; + float up_acc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f}; + float up_zacc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f}; + + for (uint b = simd_lid; b < n_blocks; b += 32) { + uint k_base = b * Q4_BLOCK_SIZE; + + float xv[Q4_BLOCK_SIZE]; + for (uint j = 0; j < 16; j++) { + xv[j * 2] = x[k_base + j * 2]; + xv[j * 2 + 1] = x[k_base + j * 2 + 1]; + } + + for (uint r = 0; r < rows_this; r++) { + uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES; + + // Gate weight block + device const uint8_t *g_block = W_gate + row_off; + float g_scale = float(*reinterpret_cast(g_block)); + float g_zero = float(*reinterpret_cast(g_block + 2)); + device const uint8_t *g_packed = g_block + 4; + + // Up weight block + device const uint8_t *u_block = W_up + row_off; + float u_scale = float(*reinterpret_cast(u_block)); + float u_zero = float(*reinterpret_cast(u_block + 2)); + device const uint8_t *u_packed = u_block + 4; + + float g_dot = 0.0f, g_xsum = 0.0f; + float u_dot = 0.0f, u_xsum = 0.0f; + for (uint j = 0; j < 16; j++) { + float x0 = xv[j * 2]; + float x1 = xv[j * 2 + 1]; + float xs = x0 + x1; + + uint8_t gb = g_packed[j]; + g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1; + g_xsum += xs; + + uint8_t ub = u_packed[j]; + u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1; + u_xsum += xs; + } + gate_acc[r] += g_dot * g_scale; + gate_zacc[r] += g_xsum * g_zero; + up_acc[r] += u_dot * u_scale; + up_zacc[r] += u_xsum * u_zero; + } + } + + for (uint r = 0; r < rows_this; r++) { + float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]); + float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]); + if (simd_lid == 0) { + float s = g / (1.0f + exp(-g)); + out[base_row + r] = s * u; + } + } +} + +// ── Q4 batched matrix-matrix multiply (SGEMM) for prefill ──────────── +// Y[t, r] = sum_k(dequant(W[r, k]) * X[t, k]) for t in [0, n_tokens), r in [0, out_dim) +// Grid: (ceil(out_dim / GEMM_TILE_M), n_tokens, 1) +// Each threadgroup: 2 SIMD groups, each handles GEMM_TILE_M/2 output rows for one token. +#define GEMM_TILE_M 8 +#define GEMM_SIMD_GROUPS 2 +#define GEMM_ROWS_PER_SIMD (GEMM_TILE_M / GEMM_SIMD_GROUPS) + +kernel void sgemm_q4( + device const uint8_t *W [[buffer(0)]], + device const float *X [[buffer(1)]], + device float *Y [[buffer(2)]], + constant uint &in_dim [[buffer(3)]], + constant uint &out_dim [[buffer(4)]], + device const float *bias [[buffer(5)]], + constant uint &use_bias [[buffer(6)]], + constant uint &n_tokens [[buffer(7)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) +{ + uint base_row = tgid.x * GEMM_TILE_M + simd_gid * GEMM_ROWS_PER_SIMD; + uint t = tgid.y; + if (base_row >= out_dim || t >= n_tokens) return; + + uint n_blocks = in_dim / Q4_BLOCK_SIZE; + uint row_bytes = n_blocks * Q4_BLOCK_BYTES; + uint rows_this = min((uint)GEMM_ROWS_PER_SIMD, out_dim - base_row); + + device const float *xt = X + uint64_t(t) * in_dim; + + float accum[GEMM_ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f}; + float zero_accum[GEMM_ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f}; + + for (uint b = simd_lid; b < n_blocks; b += 32) { + uint k_base = b * Q4_BLOCK_SIZE; + + float xv[Q4_BLOCK_SIZE]; + for (uint j = 0; j < 16; j++) { + xv[j * 2] = xt[k_base + j * 2]; + xv[j * 2 + 1] = xt[k_base + j * 2 + 1]; + } + + for (uint r = 0; r < rows_this; r++) { + device const uint8_t *block = + W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES; + + float scale = float(*reinterpret_cast(block)); + float zero = float(*reinterpret_cast(block + 2)); + device const uint8_t *packed = block + 4; + + float dot = 0.0f; + float xsum = 0.0f; + for (uint j = 0; j < 16; j++) { + uint8_t byte = packed[j]; + dot += float(byte & 0xF) * xv[j * 2] + float(byte >> 4) * xv[j * 2 + 1]; + xsum += xv[j * 2] + xv[j * 2 + 1]; + } + accum[r] += dot * scale; + zero_accum[r] += xsum * zero; + } + } + + for (uint r = 0; r < rows_this; r++) { + float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]); + if (simd_lid == 0) { + if (use_bias != 0) + result += bias[base_row + r]; + Y[uint64_t(t) * out_dim + base_row + r] = result; + } + } +} + +// ── Q4 batched fused Gate+Up+SiLU (SGEMM variant) ─────────────────── +// out[t, r] = silu(Wg[r] . X[t]) * Wu[r] . X[t] for all t and r +// Grid: (ceil(out_dim / GEMM_FFN_TILE_M), n_tokens, 1) +#define GEMM_FFN_TILE_M 4 +#define GEMM_FFN_SIMD_GROUPS 2 +#define GEMM_FFN_ROWS_PER_SIMD (GEMM_FFN_TILE_M / GEMM_FFN_SIMD_GROUPS) + +kernel void sgemm_q4_fused_ffn( + device const uint8_t *W_gate [[buffer(0)]], + device const uint8_t *W_up [[buffer(1)]], + device const float *X [[buffer(2)]], + device float *out [[buffer(3)]], + constant uint &in_dim [[buffer(4)]], + constant uint &out_dim [[buffer(5)]], + constant uint &n_tokens [[buffer(6)]], + uint2 tgid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) +{ + uint base_row = tgid.x * GEMM_FFN_TILE_M + simd_gid * GEMM_FFN_ROWS_PER_SIMD; + uint t = tgid.y; + if (base_row >= out_dim || t >= n_tokens) return; + + uint n_blocks = in_dim / Q4_BLOCK_SIZE; + uint row_bytes = n_blocks * Q4_BLOCK_BYTES; + uint rows_this = min((uint)GEMM_FFN_ROWS_PER_SIMD, out_dim - base_row); + + device const float *xt = X + uint64_t(t) * in_dim; + + float gate_acc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f}; + float gate_zacc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f}; + float up_acc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f}; + float up_zacc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f}; + + for (uint b = simd_lid; b < n_blocks; b += 32) { + uint k_base = b * Q4_BLOCK_SIZE; + + float xv[Q4_BLOCK_SIZE]; + for (uint j = 0; j < 16; j++) { + xv[j * 2] = xt[k_base + j * 2]; + xv[j * 2 + 1] = xt[k_base + j * 2 + 1]; + } + + for (uint r = 0; r < rows_this; r++) { + uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES; + + device const uint8_t *g_block = W_gate + row_off; + float g_scale = float(*reinterpret_cast(g_block)); + float g_zero = float(*reinterpret_cast(g_block + 2)); + device const uint8_t *g_packed = g_block + 4; + + device const uint8_t *u_block = W_up + row_off; + float u_scale = float(*reinterpret_cast(u_block)); + float u_zero = float(*reinterpret_cast(u_block + 2)); + device const uint8_t *u_packed = u_block + 4; + + float g_dot = 0.0f, g_xsum = 0.0f; + float u_dot = 0.0f, u_xsum = 0.0f; + for (uint j = 0; j < 16; j++) { + float x0 = xv[j * 2]; + float x1 = xv[j * 2 + 1]; + float xs = x0 + x1; + + uint8_t gb = g_packed[j]; + g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1; + g_xsum += xs; + + uint8_t ub = u_packed[j]; + u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1; + u_xsum += xs; + } + gate_acc[r] += g_dot * g_scale; + gate_zacc[r] += g_xsum * g_zero; + up_acc[r] += u_dot * u_scale; + up_zacc[r] += u_xsum * u_zero; + } + } + + for (uint r = 0; r < rows_this; r++) { + float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]); + float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]); + if (simd_lid == 0) { + float s = g / (1.0f + exp(-g)); + out[uint64_t(t) * out_dim + base_row + r] = s * u; + } + } +} + +// ── Batched RMSNorm (N tokens) ────────────────────────────────────── +// x[t*dim .. (t+1)*dim-1] → out[t*dim .. (t+1)*dim-1] +// Grid: (n_tokens, 1, 1) threadgroups, each normalizes one token. +kernel void rms_norm_batched( + device const float *x [[buffer(0)]], + device const float *w [[buffer(1)]], + device float *out [[buffer(2)]], + constant uint &dim [[buffer(3)]], + constant float &eps [[buffer(4)]], + constant uint &n_tokens [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint tpg [[threads_per_threadgroup]]) +{ + if (tgid >= n_tokens) return; + + device const float *xi = x + uint64_t(tgid) * dim; + device float *oi = out + uint64_t(tgid) * dim; + + threadgroup float partial[1024]; + + float local_sum = 0.0f; + for (uint i = tid; i < dim; i += tpg) + local_sum += xi[i] * xi[i]; + partial[tid] = local_sum; + + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = tpg / 2; s > 0; s >>= 1) { + if (tid < s) partial[tid] += partial[tid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + float rms_inv = rsqrt(partial[0] / float(dim) + eps); + + for (uint i = tid; i < dim; i += tpg) + oi[i] = xi[i] * rms_inv * w[i]; +} + +// ── Batched embedding lookup (N tokens) ───────────────────────────── +// Grid: (dim, n_tokens, 1). Each thread copies one element. +kernel void embed_lookup_batched( + device const float *embed [[buffer(0)]], + device float *out [[buffer(1)]], + device const uint *token_ids [[buffer(2)]], + constant uint &dim [[buffer(3)]], + uint2 gid [[thread_position_in_grid]]) +{ + uint i = gid.x; + uint t = gid.y; + if (i >= dim) return; + out[uint64_t(t) * dim + i] = embed[uint64_t(token_ids[t]) * dim + i]; +} + +// ── Batched RoPE (N tokens) ───────────────────────────────────────── +// Applies RoPE to Q[t] and K[t] for each token t at position base_pos+t. +// Grid: total_pairs per token * n_tokens +kernel void rope_apply_batched( + device float *q [[buffer(0)]], + device float *k [[buffer(1)]], + device const float *cos_tbl [[buffer(2)]], + device const float *sin_tbl [[buffer(3)]], + constant uint &n_q_heads [[buffer(4)]], + constant uint &n_kv_heads [[buffer(5)]], + constant uint &head_dim [[buffer(6)]], + constant uint &base_pos [[buffer(7)]], + constant uint &q_stride [[buffer(8)]], + constant uint &k_stride [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) +{ + uint pair_idx = gid.x; + uint t = gid.y; + + uint half_dim = head_dim / 2; + uint total_pairs = (n_q_heads + n_kv_heads) * half_dim; + if (pair_idx >= total_pairs) return; + + uint head_pair = pair_idx / half_dim; + uint i = pair_idx % half_dim; + uint pos = base_pos + t; + + device float *vec; + if (head_pair < n_q_heads) + vec = q + uint64_t(t) * q_stride + head_pair * head_dim; + else + vec = k + uint64_t(t) * k_stride + (head_pair - n_q_heads) * head_dim; + + uint cos_off = pos * half_dim; + float f = vec[i]; + float s = vec[i + half_dim]; + float c = cos_tbl[cos_off + i]; + float sv = sin_tbl[cos_off + i]; + vec[i] = f * c - s * sv; + vec[i + half_dim] = s * c + f * sv; +} + +// ── Batched vec_add: out[i] = a[i] + b[i] for N*dim elements ──────── +kernel void vec_add_batched( + device const float *a [[buffer(0)]], + device const float *b [[buffer(1)]], + device float *out [[buffer(2)]], + constant uint &total_n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= total_n) return; + out[gid] = a[gid] + b[gid]; +} + +// ── F16 matrix-vector multiply ─────────────────────────────────────── +kernel void sgemv_f16( + device const half *W [[buffer(0)]], + device const float *x [[buffer(1)]], + device float *y [[buffer(2)]], + constant uint &in_dim [[buffer(3)]], + constant uint &out_dim [[buffer(4)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= out_dim) return; + + device const half *row = W + uint64_t(gid) * in_dim; + float sum = 0.0f; + + uint i = 0; + for (; i + 7 < in_dim; i += 8) { + sum += float(row[i]) * x[i]; + sum += float(row[i + 1]) * x[i + 1]; + sum += float(row[i + 2]) * x[i + 2]; + sum += float(row[i + 3]) * x[i + 3]; + sum += float(row[i + 4]) * x[i + 4]; + sum += float(row[i + 5]) * x[i + 5]; + sum += float(row[i + 6]) * x[i + 6]; + sum += float(row[i + 7]) * x[i + 7]; + } + for (; i < in_dim; i++) + sum += float(row[i]) * x[i]; + + y[gid] = sum; +} + +// ── F32 matrix-vector multiply ─────────────────────────────────────── +kernel void sgemv_f32( + device const float *W [[buffer(0)]], + device const float *x [[buffer(1)]], + device float *y [[buffer(2)]], + constant uint &in_dim [[buffer(3)]], + constant uint &out_dim [[buffer(4)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= out_dim) return; + + device const float *row = W + uint64_t(gid) * in_dim; + float sum = 0.0f; + + uint i = 0; + for (; i + 7 < in_dim; i += 8) { + sum += row[i] * x[i]; + sum += row[i + 1] * x[i + 1]; + sum += row[i + 2] * x[i + 2]; + sum += row[i + 3] * x[i + 3]; + sum += row[i + 4] * x[i + 4]; + sum += row[i + 5] * x[i + 5]; + sum += row[i + 6] * x[i + 6]; + sum += row[i + 7] * x[i + 7]; + } + for (; i < in_dim; i++) + sum += row[i] * x[i]; + + y[gid] = sum; +} + +// ── RMS Normalization ──────────────────────────────────────────────── +// out[i] = x[i] * w[i] / sqrt(mean(x^2) + eps) +// Two-pass: first compute sum of squares (reduction), then normalize. +// Single threadgroup processes the entire vector. +kernel void rms_norm( + device const float *x [[buffer(0)]], + device const float *w [[buffer(1)]], + device float *out [[buffer(2)]], + constant uint &dim [[buffer(3)]], + constant float &eps [[buffer(4)]], + uint tid [[thread_index_in_threadgroup]], + uint tpg [[threads_per_threadgroup]]) +{ + threadgroup float partial[1024]; + + float local_sum = 0.0f; + for (uint i = tid; i < dim; i += tpg) + local_sum += x[i] * x[i]; + partial[tid] = local_sum; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Tree reduction + for (uint s = tpg / 2; s > 0; s >>= 1) { + if (tid < s) partial[tid] += partial[tid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + float rms_inv = rsqrt(partial[0] / float(dim) + eps); + + for (uint i = tid; i < dim; i += tpg) + out[i] = x[i] * rms_inv * w[i]; +} + +// ── RoPE (Rotary Position Embedding) ───────────────────────────────── +// Applies RoPE to Q and K vectors in-place. +// cos_sin is precomputed: [half_dim] cos values followed by [half_dim] sin values. +kernel void rope_apply( + device float *q [[buffer(0)]], + device float *k [[buffer(1)]], + device const float *cos_v [[buffer(2)]], + device const float *sin_v [[buffer(3)]], + constant uint &n_q_heads [[buffer(4)]], + constant uint &n_kv_heads [[buffer(5)]], + constant uint &head_dim [[buffer(6)]], + uint gid [[thread_position_in_grid]]) +{ + uint half_dim = head_dim / 2; + uint total_pairs = (n_q_heads + n_kv_heads) * half_dim; + if (gid >= total_pairs) return; + + uint head_pair = gid / half_dim; + uint i = gid % half_dim; + + device float *vec; + if (head_pair < n_q_heads) { + vec = q + head_pair * head_dim; + } else { + vec = k + (head_pair - n_q_heads) * head_dim; + } + + float f = vec[i]; + float s = vec[i + half_dim]; + float c = cos_v[i]; + float sv = sin_v[i]; + vec[i] = f * c - s * sv; + vec[i + half_dim] = s * c + f * sv; +} + +// ── SiLU activation + element-wise multiply ────────────────────────── +// gate[i] = silu(gate[i]) * up[i] +// silu(x) = x / (1 + exp(-x)) +kernel void silu_mul( + device float *gate [[buffer(0)]], + device const float *up [[buffer(1)]], + constant uint &n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= n) return; + float x = gate[gid]; + float s = x / (1.0f + exp(-x)); + gate[gid] = s * up[gid]; +} + +// ── Vector add (residual connection) ───────────────────────────────── +// out[i] = a[i] + b[i] +kernel void vec_add( + device const float *a [[buffer(0)]], + device const float *b [[buffer(1)]], + device float *out [[buffer(2)]], + constant uint &n [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= n) return; + out[gid] = a[gid] + b[gid]; +} + +// ── Bias add ───────────────────────────────────────────────────────── +// x[i] += bias[i] +kernel void bias_add( + device float *x [[buffer(0)]], + device const float *bias [[buffer(1)]], + constant uint &n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= n) return; + x[gid] += bias[gid]; +} + +// ── Embedding lookup ───────────────────────────────────────────────── +// out[i] = embed[token_id * dim + i] +kernel void embed_lookup( + device const float *embed [[buffer(0)]], + device float *out [[buffer(1)]], + constant uint &token_id [[buffer(2)]], + constant uint &dim [[buffer(3)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= dim) return; + out[gid] = embed[uint64_t(token_id) * dim + gid]; +} + +// ── Attention score: Q @ K^T for one head (legacy) ────────────────── +kernel void attn_score( + device const float *qh [[buffer(0)]], + device const float *kv_cache_k [[buffer(1)]], + device float *att [[buffer(2)]], + constant uint &head_dim [[buffer(3)]], + constant uint &kv_dim [[buffer(4)]], + constant uint &kv_head_offset [[buffer(5)]], + constant float &scale [[buffer(6)]], + constant uint &seq_len [[buffer(7)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= seq_len) return; + + device const float *kt = kv_cache_k + uint64_t(gid) * kv_dim + kv_head_offset; + float dot = 0.0f; + for (uint i = 0; i < head_dim; i++) + dot += qh[i] * kt[i]; + att[gid] = dot * scale; +} + +// ── Batched attention score: all Q heads in one dispatch ───────────── +// Grid: (seq_len, n_q_heads, 1). Each thread computes one score for one head. +// GQA: maps Q head h to KV head h/gqa_factor. +kernel void attn_score_batched( + device const float *q [[buffer(0)]], + device const float *kv_cache_k [[buffer(1)]], + device float *att [[buffer(2)]], + constant uint &head_dim [[buffer(3)]], + constant uint &kv_dim [[buffer(4)]], + constant uint &n_q_heads [[buffer(5)]], + constant uint &gqa_factor [[buffer(6)]], + constant float &scale [[buffer(7)]], + constant uint &seq_len [[buffer(8)]], + constant uint &max_seq [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) +{ + uint t = gid.x; + uint h = gid.y; + if (t >= seq_len || h >= n_q_heads) return; + + uint kv_h = h / gqa_factor; + device const float *qh = q + h * head_dim; + device const float *kt = kv_cache_k + uint64_t(t) * kv_dim + kv_h * head_dim; + + float dot = 0.0f; + for (uint i = 0; i < head_dim; i++) + dot += qh[i] * kt[i]; + + att[h * max_seq + t] = dot * scale; +} + +// ── Batched softmax: all heads in one dispatch ─────────────────────── +// One threadgroup per head. tid reduces over seq_len dimension. +kernel void softmax_batched( + device float *att [[buffer(0)]], + constant uint &seq_len [[buffer(1)]], + constant uint &max_seq [[buffer(2)]], + constant uint &n_q_heads [[buffer(3)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint tpg [[threads_per_threadgroup]]) +{ + uint h = tgid; + if (h >= n_q_heads) return; + + device float *head_att = att + h * max_seq; + threadgroup float shared[1024]; + + float local_max = -1e30f; + for (uint i = tid; i < seq_len; i += tpg) + local_max = max(local_max, head_att[i]); + shared[tid] = local_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = tpg / 2; s > 0; s >>= 1) { + if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + float max_val = shared[0]; + + float local_sum = 0.0f; + for (uint i = tid; i < seq_len; i += tpg) { + float e = exp(head_att[i] - max_val); + head_att[i] = e; + local_sum += e; + } + shared[tid] = local_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = tpg / 2; s > 0; s >>= 1) { + if (tid < s) shared[tid] += shared[tid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + float inv_sum = 1.0f / shared[0]; + + for (uint i = tid; i < seq_len; i += tpg) + head_att[i] *= inv_sum; +} + +// ── Batched attention weighted sum: all heads in one dispatch ──────── +// Grid: (head_dim, n_q_heads, 1). Each thread computes one output dim for one head. +kernel void attn_wsum_batched( + device const float *att [[buffer(0)]], + device const float *kv_cache_v [[buffer(1)]], + device float *out [[buffer(2)]], + constant uint &head_dim [[buffer(3)]], + constant uint &kv_dim [[buffer(4)]], + constant uint &n_q_heads [[buffer(5)]], + constant uint &gqa_factor [[buffer(6)]], + constant uint &seq_len [[buffer(7)]], + constant uint &max_seq [[buffer(8)]], + uint2 gid [[thread_position_in_grid]]) +{ + uint d = gid.x; + uint h = gid.y; + if (d >= head_dim || h >= n_q_heads) return; + + uint kv_h = h / gqa_factor; + device const float *head_att = att + h * max_seq; + + float sum = 0.0f; + for (uint t = 0; t < seq_len; t++) { + float a = head_att[t]; + float v = kv_cache_v[uint64_t(t) * kv_dim + kv_h * head_dim + d]; + sum += a * v; + } + out[h * head_dim + d] = sum; +} + +// ── Softmax (legacy, single head) ─────────────────────────────────── +kernel void softmax_inplace( + device float *att [[buffer(0)]], + constant uint &seq_len [[buffer(1)]], + uint tid [[thread_index_in_threadgroup]], + uint tpg [[threads_per_threadgroup]]) +{ + threadgroup float shared[1024]; + + float local_max = -1e30f; + for (uint i = tid; i < seq_len; i += tpg) + local_max = max(local_max, att[i]); + shared[tid] = local_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = tpg / 2; s > 0; s >>= 1) { + if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + float max_val = shared[0]; + + float local_sum = 0.0f; + for (uint i = tid; i < seq_len; i += tpg) { + float e = exp(att[i] - max_val); + att[i] = e; + local_sum += e; + } + shared[tid] = local_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint s = tpg / 2; s > 0; s >>= 1) { + if (tid < s) shared[tid] += shared[tid + s]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + float inv_sum = 1.0f / shared[0]; + + for (uint i = tid; i < seq_len; i += tpg) + att[i] *= inv_sum; +} + +// ── Attention weighted sum (legacy, single head) ───────────────────── +kernel void attn_weighted_sum( + device const float *att [[buffer(0)]], + device const float *kv_cache_v [[buffer(1)]], + device float *out [[buffer(2)]], + constant uint &head_dim [[buffer(3)]], + constant uint &kv_dim [[buffer(4)]], + constant uint &kv_head_offset [[buffer(5)]], + constant uint &seq_len [[buffer(6)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= head_dim) return; + + float sum = 0.0f; + for (uint t = 0; t < seq_len; t++) { + float a = att[t]; + float v = kv_cache_v[uint64_t(t) * kv_dim + kv_head_offset + gid]; + sum += a * v; + } + out[gid] = sum; +} + +// ── Argmax ─────────────────────────────────────────────────────────── +// Finds argmax of logits[0..n-1], writes to result[0]. +// Single threadgroup. +kernel void argmax_kernel( + device const float *logits [[buffer(0)]], + device int *result [[buffer(1)]], + constant uint &n [[buffer(2)]], + uint tid [[thread_index_in_threadgroup]], + uint tpg [[threads_per_threadgroup]]) +{ + threadgroup float shared_val[1024]; + threadgroup int shared_idx[1024]; + + float local_max = -1e30f; + int local_idx = 0; + for (uint i = tid; i < n; i += tpg) { + if (logits[i] > local_max) { + local_max = logits[i]; + local_idx = int(i); + } + } + shared_val[tid] = local_max; + shared_idx[tid] = local_idx; + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = tpg / 2; s > 0; s >>= 1) { + if (tid < s && shared_val[tid + s] > shared_val[tid]) { + shared_val[tid] = shared_val[tid + s]; + shared_idx[tid] = shared_idx[tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (tid == 0) result[0] = shared_idx[0]; +} + +// ── Copy kernel ────────────────────────────────────────────────────── +// dst[i] = src[i] for i in [0, n) +kernel void vec_copy( + device const float *src [[buffer(0)]], + device float *dst [[buffer(1)]], + constant uint &n [[buffer(2)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= n) return; + dst[gid] = src[gid]; +} + +// ── Zero-fill ──────────────────────────────────────────────────────── +kernel void vec_zero( + device float *dst [[buffer(0)]], + constant uint &n [[buffer(1)]], + uint gid [[thread_position_in_grid]]) +{ + if (gid >= n) return; + dst[gid] = 0.0f; +} diff --git a/inference/qwen_ane_infer.h b/inference/qwen_ane_infer.h index 29634fa..47bd1f4 100644 --- a/inference/qwen_ane_infer.h +++ b/inference/qwen_ane_infer.h @@ -23,12 +23,75 @@ static ANEKernel *compile_conv_kernel(const float *weights, int in_ch, int out_c size_t outBytes = (size_t)out_ch * spatial * 4; return ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], wb, 1, &inBytes, 1, &outBytes); } + +// Compile baked-weight conv with FP16 IOSurfaces (for fused ANE path) +static ANEKernel *compile_conv_kernel_fp16io(const float *weights, int in_ch, int out_ch, int spatial) { + int saved = g_fp16_io; g_fp16_io = 1; + NSData *wb = mil_build_weight_blob(weights, out_ch, in_ch); + NSString *mil = mil_gen_conv(in_ch, out_ch, spatial); + size_t inBytes = (size_t)in_ch * spatial * sizeof(_Float16); + size_t outBytes = (size_t)out_ch * spatial * sizeof(_Float16); + ANEKernel *k = ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], wb, 1, &inBytes, 1, &outBytes); + g_fp16_io = saved; + return k; +} #include #include #include #include #include +static void *qwen_calloc(size_t count, size_t size, const char *desc) { + void *p = calloc(count, size); + if (!p) { + fprintf(stderr, "FATAL: calloc failed for %s (%.1f MB)\n", + desc, (double)(count * size) / (1024*1024)); + exit(1); + } + return p; +} + +// ── Metal GPU context (defined in main.m, used for GPU matmuls) ────── +#ifdef __OBJC__ +#import +#endif + +typedef struct { + void *device; // id + void *queue; // id + void *pipeline_f16; // id + void *pipeline_f32; // id + void *pipeline_q4; // id for sgemv_q4 + void *pipeline_rms; // id for rms_norm + void *pipeline_rope; // id for rope_apply + void *pipeline_silu; // id for silu_mul + void *pipeline_add; // id for vec_add + void *pipeline_bias; // id for bias_add + void *pipeline_embed; // id for embed_lookup + void *pipeline_attn_score; // id + void *pipeline_softmax; // id + void *pipeline_attn_wsum; // id + void *pipeline_argmax; // id + void *pipeline_copy; // id + void *pipeline_zero; // id + void *pipeline_q4_fast; // id for sgemv_q4_fast (SIMD) + void *pipeline_q4_fused_ffn; // id for fused gate+up+silu + void *pipeline_attn_score_b; // batched attn_score (all heads) + void *pipeline_softmax_b; // batched softmax (all heads) + void *pipeline_attn_wsum_b; // batched attn weighted sum (all heads) + void *pipeline_sgemm_q4; // batched Q4 matmul (prefill) + void *pipeline_sgemm_q4_fused_ffn; // batched fused FFN (prefill) + void *pipeline_rms_batched; // batched RMSNorm (prefill) + void *pipeline_embed_batched; // batched embed lookup (prefill) + void *pipeline_rope_batched; // batched RoPE (prefill) + void *pipeline_add_batched; // batched vec_add (prefill) + void *x_buf; // id for input vector (reusable) + void *y_buf; // id for output vector (reusable) + int initialized; +} MetalContext; + +static MetalContext g_metal = {0}; + #ifndef QWEN_DEBUG #define QWEN_DEBUG 0 #endif @@ -53,21 +116,62 @@ static ANEKernel *compile_conv_kernel(const float *weights, int in_ch, int out_c #define QWEN_KV_DIM (QWEN_KV_HEADS * QWEN_HEAD_DIM) // 128 typedef struct { - // Weights (f32) + // Weight format: 0 = F32 everywhere, 1 = F16 projections + int weight_fmt; + + // Embeddings + norms always F32 float *embed; // [vocab, dim] float *rms_att[QWEN_LAYERS]; // [dim] - float *wq[QWEN_LAYERS]; // [q_dim, dim] - float *wk[QWEN_LAYERS]; // [kv_dim, dim] - float *wv[QWEN_LAYERS]; // [kv_dim, dim] - float *wo[QWEN_LAYERS]; // [dim, q_dim] float *rms_ffn[QWEN_LAYERS]; // [dim] - float *w_gate[QWEN_LAYERS]; // [hidden, dim] - float *w_up[QWEN_LAYERS]; // [hidden, dim] - float *w_down[QWEN_LAYERS]; // [dim, hidden] float *rms_final; // [dim] - // wcls = embed (tied) - // ANE kernels (one per linear projection per layer) + // Projection weights: F32 or F16 depending on weight_fmt + // When weight_fmt=1, the f32 pointers are NULL and f16 pointers are set + float *wq[QWEN_LAYERS]; // [q_dim, dim] (F32) + float *wk[QWEN_LAYERS]; // [kv_dim, dim] (F32) + float *wv[QWEN_LAYERS]; // [kv_dim, dim] (F32) + float *wo[QWEN_LAYERS]; // [dim, q_dim] (F32) + float *w_gate[QWEN_LAYERS]; // [hidden, dim] (F32) + float *w_up[QWEN_LAYERS]; // [hidden, dim] (F32) + float *w_down[QWEN_LAYERS]; // [dim, hidden] (F32) + + _Float16 *wq_f16[QWEN_LAYERS]; // (F16) + _Float16 *wk_f16[QWEN_LAYERS]; + _Float16 *wv_f16[QWEN_LAYERS]; + _Float16 *wo_f16[QWEN_LAYERS]; + _Float16 *wgate_f16[QWEN_LAYERS]; + _Float16 *wup_f16[QWEN_LAYERS]; + _Float16 *wdown_f16[QWEN_LAYERS]; + + uint8_t *wq_q8[QWEN_LAYERS]; // (Q8_0 blocks) + uint8_t *wk_q8[QWEN_LAYERS]; + uint8_t *wv_q8[QWEN_LAYERS]; + uint8_t *wo_q8[QWEN_LAYERS]; + uint8_t *wgate_q8[QWEN_LAYERS]; + uint8_t *wup_q8[QWEN_LAYERS]; + uint8_t *wdown_q8[QWEN_LAYERS]; + + // Metal GPU buffers (id cast to void*) + void *gpu_wq[QWEN_LAYERS]; + void *gpu_wk[QWEN_LAYERS]; + void *gpu_wv[QWEN_LAYERS]; + void *gpu_wo[QWEN_LAYERS]; + void *gpu_wgate[QWEN_LAYERS]; + void *gpu_wup[QWEN_LAYERS]; + void *gpu_wdown[QWEN_LAYERS]; + void *gpu_embed; // embedding table (F32) + void *gpu_rms_att[QWEN_LAYERS]; // RMSNorm weights + void *gpu_rms_ffn[QWEN_LAYERS]; + void *gpu_rms_final; + void *gpu_q_bias[QWEN_LAYERS]; + void *gpu_k_bias[QWEN_LAYERS]; + void *gpu_v_bias[QWEN_LAYERS]; + void *gpu_kv_cache_k[QWEN_LAYERS]; + void *gpu_kv_cache_v[QWEN_LAYERS]; + int use_gpu; + // wcls = embed (tied, always F32) + + // ANE kernels -- unfused (one per linear projection per layer) ANEKernel *k_q[QWEN_LAYERS]; ANEKernel *k_k[QWEN_LAYERS]; ANEKernel *k_v[QWEN_LAYERS]; @@ -80,6 +184,11 @@ typedef struct { #define QWEN_LM_CHUNK_SIZE 9496 // 151936 / 16 ANEKernel *k_lmhead[QWEN_LM_CHUNKS]; + // ANE kernels -- fused (reduces 184 → 112 kernels, under 119 limit) + ANEKernel *k_qkv[QWEN_LAYERS]; // fused Q+K+V → 3 outputs + ANEKernel *k_ffn_up[QWEN_LAYERS]; // fused Gate+Up → 2 outputs + int use_ane; // 1 = fused ANE matmuls + CPU element-wise + // Q/K/V biases per layer float *q_bias[QWEN_LAYERS]; // [q_dim] float *k_bias[QWEN_LAYERS]; // [kv_dim] @@ -190,21 +299,92 @@ static void qwen_silu(float *x, int n) { x[i] = x[i] / (1.0f + expf(-x[i])); } -// ── ANE projection helper (single token: spatial=1) ───────────────── +// ── ANE projection helpers ────────────────────────────────────────── +// ANE IOSurfaces are always FP16 at the hardware level. +// We use g_fp16_io=1 MIL (FP16 I/O, no cast ops) and convert F32<->F16 here. static inline bool ane_run(ANEKernel *k) { return ane_eval(k); } +static void ane_write_f32_as_f16(ANEKernel *kernel, int idx, const float *f32, int n) { + IOSurfaceLock(kernel->ioInputs[idx], 0, NULL); + _Float16 *dst = (_Float16 *)IOSurfaceGetBaseAddress(kernel->ioInputs[idx]); + for (int i = 0; i < n; i++) dst[i] = (_Float16)f32[i]; + IOSurfaceUnlock(kernel->ioInputs[idx], 0, NULL); +} + +static void ane_read_f16_to_f32(ANEKernel *kernel, int idx, float *f32, int n) { + IOSurfaceLock(kernel->ioOutputs[idx], kIOSurfaceLockReadOnly, NULL); + const _Float16 *src = (const _Float16 *)IOSurfaceGetBaseAddress(kernel->ioOutputs[idx]); + for (int i = 0; i < n; i++) f32[i] = (float)src[i]; + IOSurfaceUnlock(kernel->ioOutputs[idx], kIOSurfaceLockReadOnly, NULL); +} + static void ane_project(ANEKernel *kernel, const float *in, float *out, int in_dim, int out_dim) { - ane_write_input(kernel, 0, in, in_dim * sizeof(float)); + ane_write_f32_as_f16(kernel, 0, in, in_dim); ane_run(kernel); - ane_read_output(kernel, 0, out, out_dim * sizeof(float)); + ane_read_f16_to_f32(kernel, 0, out, out_dim); +} + +// Fused QKV: one ANE kernel → 3 outputs (Q, K, V with different dims) +static void ane_project_qkv(ANEKernel *kernel, const float *in, + float *q, float *k, float *v, + int in_dim, int q_dim, int kv_dim) { + ane_write_f32_as_f16(kernel, 0, in, in_dim); + ane_run(kernel); + ane_read_f16_to_f32(kernel, 0, q, q_dim); + ane_read_f16_to_f32(kernel, 1, k, kv_dim); + ane_read_f16_to_f32(kernel, 2, v, kv_dim); +} + +// Fused Gate+Up: one ANE kernel → 2 outputs (gate, up) +static void ane_project_ffn_up(ANEKernel *kernel, const float *in, + float *gate, float *up, + int in_dim, int hidden_dim) { + ane_write_f32_as_f16(kernel, 0, in, in_dim); + ane_run(kernel); + ane_read_f16_to_f32(kernel, 0, gate, hidden_dim); + ane_read_f16_to_f32(kernel, 1, up, hidden_dim); +} + +// Compile fused QKV kernel (GQA-aware: Q=[q_dim,dim], K/V=[kv_dim,dim]) +// Uses FP16 IOSurfaces (ANE hardware requirement) +static ANEKernel *compile_qkv_gqa_kernel(const float *wq, const float *wk, const float *wv, + int dim, int q_dim, int kv_dim) { + int saved = g_fp16_io; g_fp16_io = 1; + NSData *wb = mil_build_qkv_gqa_weight_blob(wq, q_dim, dim, wk, wv, kv_dim); + NSString *mil = mil_gen_qkv_gqa(dim, q_dim, kv_dim, 1); + size_t inBytes = (size_t)dim * sizeof(_Float16); + size_t outSizes[3] = { + (size_t)q_dim * sizeof(_Float16), + (size_t)kv_dim * sizeof(_Float16), + (size_t)kv_dim * sizeof(_Float16) + }; + ANEKernel *k = ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], wb, + 1, &inBytes, 3, outSizes); + g_fp16_io = saved; + return k; +} + +// Compile fused FFN up kernel (Gate + Up, both [hidden_dim, dim]) +static ANEKernel *compile_ffn_up_kernel(const float *w_gate, const float *w_up, + int dim, int hidden_dim) { + int saved = g_fp16_io; g_fp16_io = 1; + NSData *wb = mil_build_ffn_up_weight_blob(w_gate, w_up, hidden_dim, dim); + NSString *mil = mil_gen_ffn_up(dim, hidden_dim, 1); + size_t inBytes = (size_t)dim * sizeof(_Float16); + size_t outSizes[2] = { + (size_t)hidden_dim * sizeof(_Float16), + (size_t)hidden_dim * sizeof(_Float16) + }; + ANEKernel *k = ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], wb, + 1, &inBytes, 2, outSizes); + g_fp16_io = saved; + return k; } // CPU matmul via Accelerate BLAS: y = W @ x, W[out_dim, in_dim] static void cpu_project(const float *W, const float *x, float *y, int in_dim, int out_dim) { - // y = W @ x where W is [out_dim, in_dim] row-major - // cblas_sgemv: y = alpha * A * x + beta * y cblas_sgemv(CblasRowMajor, CblasNoTrans, out_dim, in_dim, 1.0f, W, in_dim, @@ -212,9 +392,309 @@ static void cpu_project(const float *W, const float *x, float *y, int in_dim, in 0.0f, y, 1); } +// Bulk F16→F32 conversion using NEON vcvt +static void convert_f16_to_f32(const _Float16 *src, float *dst, size_t n) { + size_t i = 0; + for (; i + 7 < n; i += 8) { + float16x8_t h = vld1q_f16((const __fp16*)(src + i)); + vst1q_f32(dst + i, vcvt_f32_f16(vget_low_f16(h))); + vst1q_f32(dst + i + 4, vcvt_f32_f16(vget_high_f16(h))); + } + for (; i < n; i++) + dst[i] = (float)src[i]; +} + +// ── Q8_0 quantization support ──────────────────────────────────────── +// Block format: 2 bytes F16 scale + 32 bytes int8 values = 34 bytes per block +#define Q8_BLOCK_SIZE 32 +#define Q8_BLOCK_BYTES (2 + Q8_BLOCK_SIZE) // 34 + +// Q8 matmul: y = W_q8 @ x, dequantize-and-dot using NEON int8 +// W is stored as blocks of [f16_scale, 32*int8], row-major +static void cpu_project_q8(const uint8_t *W, const float *x, float *y, + int in_dim, int out_dim) { + int n_blocks = in_dim / Q8_BLOCK_SIZE; + size_t row_bytes = (size_t)n_blocks * Q8_BLOCK_BYTES; + + for (int r = 0; r < out_dim; r++) { + const uint8_t *row = W + (size_t)r * row_bytes; + float sum = 0.0f; + + for (int b = 0; b < n_blocks; b++) { + const uint8_t *block = row + (size_t)b * Q8_BLOCK_BYTES; + _Float16 scale_f16; + memcpy(&scale_f16, block, 2); + float scale = (float)scale_f16; + const int8_t *qvals = (const int8_t*)(block + 2); + const float *xb = x + b * Q8_BLOCK_SIZE; + + // NEON: load 32 int8 values, widen to int16, convert to f32, FMA + int8x16_t q0 = vld1q_s8(qvals); + int8x16_t q1 = vld1q_s8(qvals + 16); + + // Widen int8 -> int16 -> int32 -> float32, then FMA with x + int16x8_t w0 = vmovl_s8(vget_low_s8(q0)); + int16x8_t w1 = vmovl_s8(vget_high_s8(q0)); + int16x8_t w2 = vmovl_s8(vget_low_s8(q1)); + int16x8_t w3 = vmovl_s8(vget_high_s8(q1)); + + float32x4_t a0 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(w0))), vld1q_f32(xb)); + float32x4_t a1 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(w0))), vld1q_f32(xb + 4)); + float32x4_t a2 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(w1))), vld1q_f32(xb + 8)); + float32x4_t a3 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(w1))), vld1q_f32(xb + 12)); + float32x4_t a4 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(w2))), vld1q_f32(xb + 16)); + float32x4_t a5 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(w2))), vld1q_f32(xb + 20)); + float32x4_t a6 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(w3))), vld1q_f32(xb + 24)); + float32x4_t a7 = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(w3))), vld1q_f32(xb + 28)); + + float32x4_t s01 = vaddq_f32(a0, a1); + float32x4_t s23 = vaddq_f32(a2, a3); + float32x4_t s45 = vaddq_f32(a4, a5); + float32x4_t s67 = vaddq_f32(a6, a7); + float32x4_t stot = vaddq_f32(vaddq_f32(s01, s23), vaddq_f32(s45, s67)); + sum += scale * vaddvq_f32(stot); + } + y[r] = sum; + } +} + +// ── Q4_0 block constants ───────────────────────────────────────────── +#define Q4_BLOCK_SIZE 32 +#define Q4_BLOCK_BYTES 20 // 2(scale) + 2(zero) + 16(packed) + +// ── Q4_0 dequantization helper: Q4 blocks to F32 ── +// Dequantizes one weight matrix from Q4 blocks into a caller-provided F32 buffer. +static void dequant_q4_to_f32(const uint8_t *W_q4, float *W_f32, + int in_dim, int out_dim) { + int n_blocks = in_dim / Q4_BLOCK_SIZE; + size_t row_bytes = (size_t)n_blocks * Q4_BLOCK_BYTES; + + for (int r = 0; r < out_dim; r++) { + const uint8_t *row = W_q4 + (size_t)r * row_bytes; + float *out_row = W_f32 + (size_t)r * in_dim; + + for (int b = 0; b < n_blocks; b++) { + const uint8_t *block = row + (size_t)b * Q4_BLOCK_BYTES; + _Float16 scale_f16, zero_f16; + memcpy(&scale_f16, block, 2); + memcpy(&zero_f16, block + 2, 2); + float scale = (float)scale_f16; + float zero = (float)zero_f16; + const uint8_t *packed = block + 4; + float *out = out_row + b * Q4_BLOCK_SIZE; + + for (int i = 0; i < 16; i++) { + uint8_t byte = packed[i]; + out[i * 2] = (float)(byte & 0xF) * scale + zero; + out[i * 2 + 1] = (float)(byte >> 4) * scale + zero; + } + } + } +} + +// Q4 fused NEON dequant-and-dot: reads Q4 from memory, avoids F32 intermediate +// Each block: 2B F16 scale + 2B F16 zero + 16B packed uint8 (32 values) +// Uses NEON to extract nibbles, convert to float, FMA with input vector +static void cpu_project_q4_amx(const uint8_t *W_q4, const float *x, float *y, + int in_dim, int out_dim) { + int n_blocks = in_dim / Q4_BLOCK_SIZE; + size_t row_bytes = (size_t)n_blocks * Q4_BLOCK_BYTES; + + for (int r = 0; r < out_dim; r++) { + const uint8_t *row = W_q4 + (size_t)r * row_bytes; + float32x4_t acc0 = vdupq_n_f32(0.0f); + float32x4_t acc1 = vdupq_n_f32(0.0f); + + for (int b = 0; b < n_blocks; b++) { + const uint8_t *block = row + (size_t)b * Q4_BLOCK_BYTES; + _Float16 scale_f16, zero_f16; + memcpy(&scale_f16, block, 2); + memcpy(&zero_f16, block + 2, 2); + float scale = (float)scale_f16; + float zero = (float)zero_f16; + const uint8_t *packed = block + 4; + const float *xb = x + b * Q4_BLOCK_SIZE; + float32x4_t vscale = vdupq_n_f32(scale); + float32x4_t vzero = vdupq_n_f32(zero); + + // Process 16 packed bytes = 32 values, 8 values at a time + for (int i = 0; i < 16; i += 4) { + // Load 4 packed bytes + uint8x8_t raw = vld1_u8(packed + i); // only 4 used + + // Extract low and high nibbles + uint8_t b0 = packed[i], b1 = packed[i+1], b2 = packed[i+2], b3 = packed[i+3]; + + // Even indices (low nibbles): b0&0xF, b1&0xF, b2&0xF, b3&0xF + float32x4_t wlo = vmlaq_f32(vzero, vcvtq_f32_u32((uint32x4_t){ + b0 & 0xF, b1 & 0xF, b2 & 0xF, b3 & 0xF}), vscale); + // Odd indices (high nibbles): b0>>4, b1>>4, b2>>4, b3>>4 + float32x4_t whi = vmlaq_f32(vzero, vcvtq_f32_u32((uint32x4_t){ + b0 >> 4, b1 >> 4, b2 >> 4, b3 >> 4}), vscale); + + // Interleaved dot: x[0]*w[0] + x[1]*w[1] + ... (even/odd pairs) + int xi = i * 2; + float32x4_t x_even = {xb[xi], xb[xi+2], xb[xi+4], xb[xi+6]}; + float32x4_t x_odd = {xb[xi+1], xb[xi+3], xb[xi+5], xb[xi+7]}; + + acc0 = vmlaq_f32(acc0, wlo, x_even); + acc1 = vmlaq_f32(acc1, whi, x_odd); + } + } + y[r] = vaddvq_f32(vaddq_f32(acc0, acc1)); + } +} + +// Q4 batched projection: dequant full matrix to F32, then cblas_sgemm +static void cpu_project_batch_q4_amx(const uint8_t *W_q4, const float *X, float *Y, + int in_dim, int out_dim, int n_tokens) { + float *W_f32 = (float*)malloc((size_t)out_dim * in_dim * sizeof(float)); + dequant_q4_to_f32(W_q4, W_f32, in_dim, out_dim); + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + n_tokens, out_dim, in_dim, + 1.0f, X, in_dim, + W_f32, in_dim, + 0.0f, Y, out_dim); + free(W_f32); +} + // Toggle: 1 = use ANE for projections, 0 = CPU fallback #define USE_ANE_PROJECTIONS 0 +// ── Metal GPU matmul ───────────────────────────────────────────────── +#ifdef __OBJC__ + +static int metal_init(void) { + if (g_metal.initialized) return 0; + + id dev = MTLCreateSystemDefaultDevice(); + if (!dev) { fprintf(stderr, "Metal: no GPU device\n"); return -1; } + + NSString *shaderPath = [[NSBundle mainBundle] pathForResource:@"matmul" ofType:@"metallib"]; + NSError *error = nil; + id lib = nil; + + // Try loading from compiled metallib next to binary + NSString *execDir = [[[NSProcessInfo processInfo] arguments][0] stringByDeletingLastPathComponent]; + NSString *libPath = [execDir stringByAppendingPathComponent:@"matmul.metallib"]; + if ([[NSFileManager defaultManager] fileExistsAtPath:libPath]) { + lib = [dev newLibraryWithURL:[NSURL fileURLWithPath:libPath] error:&error]; + } + + // Fall back to compiling from source + if (!lib) { + NSString *srcPath = [execDir stringByAppendingPathComponent:@"matmul.metal"]; + NSString *src = [NSString stringWithContentsOfFile:srcPath + encoding:NSUTF8StringEncoding error:&error]; + if (!src) { + fprintf(stderr, "Metal: cannot read shader source: %s\n", + [[error description] UTF8String]); + return -1; + } + MTLCompileOptions *opts = [[MTLCompileOptions alloc] init]; + if (@available(macOS 15.0, *)) { + opts.mathMode = MTLMathModeFast; + } else { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + opts.fastMathEnabled = YES; +#pragma clang diagnostic pop + } + lib = [dev newLibraryWithSource:src options:opts error:&error]; + if (!lib) { + fprintf(stderr, "Metal: shader compile failed: %s\n", + [[error description] UTF8String]); + return -1; + } + } + + // Build all pipeline states + NSArray *names = @[ + @"sgemv_f16", @"sgemv_f32", @"sgemv_q4", + @"rms_norm", @"rope_apply", @"silu_mul", + @"vec_add", @"bias_add", @"embed_lookup", + @"attn_score", @"softmax_inplace", @"attn_weighted_sum", + @"argmax_kernel", @"vec_copy", @"vec_zero", + @"sgemv_q4_fast", @"sgemv_q4_fused_ffn", + @"attn_score_batched", @"softmax_batched", @"attn_wsum_batched", + @"sgemm_q4", @"sgemm_q4_fused_ffn", + @"rms_norm_batched", @"embed_lookup_batched", + @"rope_apply_batched", @"vec_add_batched" + ]; + void **pipelines[] = { + &g_metal.pipeline_f16, &g_metal.pipeline_f32, &g_metal.pipeline_q4, + &g_metal.pipeline_rms, &g_metal.pipeline_rope, &g_metal.pipeline_silu, + &g_metal.pipeline_add, &g_metal.pipeline_bias, &g_metal.pipeline_embed, + &g_metal.pipeline_attn_score, &g_metal.pipeline_softmax, &g_metal.pipeline_attn_wsum, + &g_metal.pipeline_argmax, &g_metal.pipeline_copy, &g_metal.pipeline_zero, + &g_metal.pipeline_q4_fast, &g_metal.pipeline_q4_fused_ffn, + &g_metal.pipeline_attn_score_b, &g_metal.pipeline_softmax_b, &g_metal.pipeline_attn_wsum_b, + &g_metal.pipeline_sgemm_q4, &g_metal.pipeline_sgemm_q4_fused_ffn, + &g_metal.pipeline_rms_batched, &g_metal.pipeline_embed_batched, + &g_metal.pipeline_rope_batched, &g_metal.pipeline_add_batched + }; + + for (int i = 0; i < (int)[names count]; i++) { + id fn = [lib newFunctionWithName:names[i]]; + if (!fn) { + fprintf(stderr, "Metal: missing shader function '%s'\n", [names[i] UTF8String]); + return -1; + } + id pso = [dev newComputePipelineStateWithFunction:fn error:&error]; + if (!pso) { + fprintf(stderr, "Metal: pipeline for '%s' failed: %s\n", + [names[i] UTF8String], [[error description] UTF8String]); + return -1; + } + *pipelines[i] = (__bridge_retained void*)pso; + } + + g_metal.device = (__bridge_retained void*)dev; + g_metal.queue = (__bridge_retained void*)[dev newCommandQueue]; + + g_metal.initialized = 1; + printf("Metal GPU initialized (%s)\n", [[dev name] UTF8String]); + return 0; +} + +// GPU projection for F16 weights: dispatches Metal compute shader +// Uses per-call output buffers to allow batching multiple projections +static void gpu_project_f16(id w_buf, const float *x, float *y, + int in_dim, int out_dim) { + id dev = (__bridge id)g_metal.device; + id queue = (__bridge id)g_metal.queue; + id pipeline = (__bridge id)g_metal.pipeline_f16; + + // Shared input/output buffers + id x_buf = [dev newBufferWithBytes:x + length:in_dim * sizeof(float) + options:MTLResourceStorageModeShared]; + id y_buf = [dev newBufferWithLength:out_dim * sizeof(float) + options:MTLResourceStorageModeShared]; + + id cmd = [queue commandBuffer]; + id enc = [cmd computeCommandEncoder]; + [enc setComputePipelineState:pipeline]; + [enc setBuffer:w_buf offset:0 atIndex:0]; + [enc setBuffer:x_buf offset:0 atIndex:1]; + [enc setBuffer:y_buf offset:0 atIndex:2]; + uint32_t dims[2] = {(uint32_t)in_dim, (uint32_t)out_dim}; + [enc setBytes:&dims[0] length:sizeof(uint32_t) atIndex:3]; + [enc setBytes:&dims[1] length:sizeof(uint32_t) atIndex:4]; + + NSUInteger tpg = pipeline.maxTotalThreadsPerThreadgroup; + if (tpg > (NSUInteger)out_dim) tpg = (NSUInteger)out_dim; + [enc dispatchThreads:MTLSizeMake(out_dim, 1, 1) + threadsPerThreadgroup:MTLSizeMake(tpg, 1, 1)]; + [enc endEncoding]; + [cmd commit]; + [cmd waitUntilCompleted]; + + memcpy(y, [y_buf contents], out_dim * sizeof(float)); +} + +#endif // __OBJC__ + // ── Forward one token ──────────────────────────────────────────────── static int qwen_forward(QwenModel *m, int token) { @@ -237,15 +717,25 @@ static int qwen_forward(QwenModel *m, int token) { } #endif - // QKV projections (ANE) + bias + // QKV projections + bias (CPU path -- GPU overhead too high for small matmuls) #if USE_ANE_PROJECTIONS ane_project(m->k_q[l], m->xb, m->q, D, QWEN_Q_DIM); ane_project(m->k_k[l], m->xb, m->k, D, QWEN_KV_DIM); ane_project(m->k_v[l], m->xb, m->v, D, QWEN_KV_DIM); #else - cpu_project(m->wq[l], m->xb, m->q, D, QWEN_Q_DIM); - cpu_project(m->wk[l], m->xb, m->k, D, QWEN_KV_DIM); - cpu_project(m->wv[l], m->xb, m->v, D, QWEN_KV_DIM); + if (m->weight_fmt == 3) { + cpu_project_q4_amx(m->wq_q8[l], m->xb, m->q, D, QWEN_Q_DIM); + cpu_project_q4_amx(m->wk_q8[l], m->xb, m->k, D, QWEN_KV_DIM); + cpu_project_q4_amx(m->wv_q8[l], m->xb, m->v, D, QWEN_KV_DIM); + } else if (m->weight_fmt == 2) { + cpu_project_q8(m->wq_q8[l], m->xb, m->q, D, QWEN_Q_DIM); + cpu_project_q8(m->wk_q8[l], m->xb, m->k, D, QWEN_KV_DIM); + cpu_project_q8(m->wv_q8[l], m->xb, m->v, D, QWEN_KV_DIM); + } else { + cpu_project(m->wq[l], m->xb, m->q, D, QWEN_Q_DIM); + cpu_project(m->wk[l], m->xb, m->k, D, QWEN_KV_DIM); + cpu_project(m->wv[l], m->xb, m->v, D, QWEN_KV_DIM); + } #endif // Apply Q/K/V biases (vectorized) if (m->q_bias[l]) @@ -323,7 +813,12 @@ static int qwen_forward(QwenModel *m, int token) { #if USE_ANE_PROJECTIONS ane_project(m->k_o[l], attn_out, o_out, QWEN_Q_DIM, D); #else - cpu_project(m->wo[l], attn_out, o_out, QWEN_Q_DIM, D); + if (m->weight_fmt == 3) + cpu_project_q4_amx(m->wo_q8[l], attn_out, o_out, QWEN_Q_DIM, D); + else if (m->weight_fmt == 2) + cpu_project_q8(m->wo_q8[l], attn_out, o_out, QWEN_Q_DIM, D); + else + cpu_project(m->wo[l], attn_out, o_out, QWEN_Q_DIM, D); #endif // Residual (vectorized) @@ -350,8 +845,16 @@ static int qwen_forward(QwenModel *m, int token) { ane_project(m->k_gate[l], m->xb, m->hb, D, HD); ane_project(m->k_up[l], m->xb, m->hb2, D, HD); #else - cpu_project(m->w_gate[l], m->xb, m->hb, D, HD); - cpu_project(m->w_up[l], m->xb, m->hb2, D, HD); + if (m->weight_fmt == 3) { + cpu_project_q4_amx(m->wgate_q8[l], m->xb, m->hb, D, HD); + cpu_project_q4_amx(m->wup_q8[l], m->xb, m->hb2, D, HD); + } else if (m->weight_fmt == 2) { + cpu_project_q8(m->wgate_q8[l], m->xb, m->hb, D, HD); + cpu_project_q8(m->wup_q8[l], m->xb, m->hb2, D, HD); + } else { + cpu_project(m->w_gate[l], m->xb, m->hb, D, HD); + cpu_project(m->w_up[l], m->xb, m->hb2, D, HD); + } #endif #if QWEN_DEBUG @@ -372,7 +875,12 @@ static int qwen_forward(QwenModel *m, int token) { #if USE_ANE_PROJECTIONS ane_project(m->k_down[l], m->hb, ffn_out, HD, D); #else - cpu_project(m->w_down[l], m->hb, ffn_out, HD, D); + if (m->weight_fmt == 3) + cpu_project_q4_amx(m->wdown_q8[l], m->hb, ffn_out, HD, D); + else if (m->weight_fmt == 2) + cpu_project_q8(m->wdown_q8[l], m->hb, ffn_out, HD, D); + else + cpu_project(m->w_down[l], m->hb, ffn_out, HD, D); #endif // Residual (vectorized) @@ -400,7 +908,7 @@ static int qwen_forward(QwenModel *m, int token) { } #endif - // LM head via Accelerate BLAS: logits = embed @ xb + // LM head via Accelerate BLAS (AMX, fastest for dim<=896) cblas_sgemv(CblasRowMajor, CblasNoTrans, QWEN_VOCAB, D, 1.0f, m->embed, D, @@ -429,6 +937,987 @@ static int qwen_forward(QwenModel *m, int token) { return (int)max_idx_vdsp; } +// ── ANE fused forward pass: ANE for matmuls, CPU for element-wise ops ── +// Uses fused QKV and Gate+Up kernels (112 total, under 119 ANE limit). +// O-proj and Down-proj remain as single conv kernels. + +static int qwen_forward_ane(QwenModel *m, int token) { + int D = QWEN_DIM, HD = QWEN_HIDDEN; + int pos = m->pos; + + memcpy(m->x, m->embed + token * D, D * sizeof(float)); + + for (int l = 0; l < QWEN_LAYERS; l++) { + qwen_rmsnorm(m->xb, m->x, m->rms_att[l], D); + + // Fused QKV projection (1 ANE eval → Q, K, V) + ane_project_qkv(m->k_qkv[l], m->xb, m->q, m->k, m->v, + D, QWEN_Q_DIM, QWEN_KV_DIM); + + // Biases (CPU, vectorized) + if (m->q_bias[l]) + vDSP_vadd(m->q, 1, m->q_bias[l], 1, m->q, 1, (vDSP_Length)QWEN_Q_DIM); + if (m->k_bias[l]) + vDSP_vadd(m->k, 1, m->k_bias[l], 1, m->k, 1, (vDSP_Length)QWEN_KV_DIM); + if (m->v_bias[l]) + vDSP_vadd(m->v, 1, m->v_bias[l], 1, m->v, 1, (vDSP_Length)QWEN_KV_DIM); + + qwen_rope(m->q, m->k, pos, QWEN_HEADS, QWEN_KV_HEADS, QWEN_HEAD_DIM); + + memcpy(m->kv_cache_k[l] + pos * QWEN_KV_DIM, m->k, QWEN_KV_DIM * sizeof(float)); + memcpy(m->kv_cache_v[l] + pos * QWEN_KV_DIM, m->v, QWEN_KV_DIM * sizeof(float)); + + // GQA attention (CPU) + float scale = 1.0f / sqrtf((float)QWEN_HEAD_DIM); + float *attn_out = m->xb; + memset(attn_out, 0, QWEN_Q_DIM * sizeof(float)); + + for (int h = 0; h < QWEN_HEADS; h++) { + int kv_h = h / QWEN_GQA_FACTOR; + float *qh = m->q + h * QWEN_HEAD_DIM; + float *att_h = m->att + h * QWEN_MAX_SEQ; + int seq_len = pos + 1; + + float max_score = -1e9f; + for (int t = 0; t <= pos; t++) { + float *kt = m->kv_cache_k[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM; + float score = cblas_sdot(QWEN_HEAD_DIM, qh, 1, kt, 1); + att_h[t] = score * scale; + if (att_h[t] > max_score) max_score = att_h[t]; + } + float neg_max = -max_score; + vDSP_vsadd(att_h, 1, &neg_max, att_h, 1, (vDSP_Length)seq_len); + int n_exp = seq_len; + vvexpf(att_h, att_h, &n_exp); + float sum; + vDSP_sve(att_h, 1, &sum, (vDSP_Length)seq_len); + float inv_sum = 1.0f / sum; + vDSP_vsmul(att_h, 1, &inv_sum, att_h, 1, (vDSP_Length)seq_len); + + for (int t = 0; t <= pos; t++) { + float a = att_h[t]; + float *vt = m->kv_cache_v[l] + t * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM; + cblas_saxpy(QWEN_HEAD_DIM, a, vt, 1, attn_out + h * QWEN_HEAD_DIM, 1); + } + } + + // O projection (single ANE kernel) + float o_out[QWEN_DIM]; + ane_project(m->k_o[l], attn_out, o_out, QWEN_Q_DIM, D); + + vDSP_vadd(m->x, 1, o_out, 1, m->x, 1, (vDSP_Length)D); + + qwen_rmsnorm(m->xb, m->x, m->rms_ffn[l], D); + + // Fused Gate+Up projection (1 ANE eval → gate, up) + ane_project_ffn_up(m->k_ffn_up[l], m->xb, m->hb, m->hb2, D, HD); + + qwen_silu(m->hb, HD); + vDSP_vmul(m->hb, 1, m->hb2, 1, m->hb, 1, (vDSP_Length)HD); + + // Down projection (single ANE kernel) + float ffn_out[QWEN_DIM]; + ane_project(m->k_down[l], m->hb, ffn_out, HD, D); + + vDSP_vadd(m->x, 1, ffn_out, 1, m->x, 1, (vDSP_Length)D); + } + + qwen_rmsnorm(m->xb, m->x, m->rms_final, D); + + // LM head: CPU AMX (too large for ANE, 151936 outputs) + cblas_sgemv(CblasRowMajor, CblasNoTrans, + QWEN_VOCAB, D, + 1.0f, m->embed, D, + m->xb, 1, + 0.0f, m->logits, 1); + + m->pos++; + + float max_val; + vDSP_Length max_idx_vdsp; + vDSP_maxvi(m->logits, 1, &max_val, &max_idx_vdsp, (vDSP_Length)QWEN_VOCAB); + return (int)max_idx_vdsp; +} + +// ── Batched prefill: process all prompt tokens at once ──────────────── +// Uses cblas_sgemm (matrix-matrix) instead of sequential sgemv calls. +// Returns the argmax token from the last position's logits. + +static void cpu_project_batch(const float *W, const float *X, float *Y, + int in_dim, int out_dim, int n_tokens) { + // X[n_tokens, in_dim], W[out_dim, in_dim], Y[n_tokens, out_dim] + // Y = X @ W^T => Y(n,out) = sum_k X(n,k) * W(out,k) + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + n_tokens, out_dim, in_dim, + 1.0f, X, in_dim, + W, in_dim, + 0.0f, Y, out_dim); +} + +static int qwen_prefill(QwenModel *m, const int *tokens, int n_tokens) { + int D = QWEN_DIM, HD = QWEN_HIDDEN, N = n_tokens; + + float *xs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_xs"); + float *xbs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_xbs"); + float *qs = (float*)qwen_calloc(N * QWEN_Q_DIM, sizeof(float), "prefill_qs"); + float *ks = (float*)qwen_calloc(N * QWEN_KV_DIM, sizeof(float), "prefill_ks"); + float *vs = (float*)qwen_calloc(N * QWEN_KV_DIM, sizeof(float), "prefill_vs"); + float *hbs = (float*)qwen_calloc(N * HD, sizeof(float), "prefill_hbs"); + float *hb2s = (float*)qwen_calloc(N * HD, sizeof(float), "prefill_hb2s"); + float *o_outs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_o_outs"); + float *ffn_outs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_ffn_outs"); + + // Load all embeddings + for (int t = 0; t < N; t++) + memcpy(xs + t * D, m->embed + tokens[t] * D, D * sizeof(float)); + + for (int l = 0; l < QWEN_LAYERS; l++) { + // Batch RMSNorm + for (int t = 0; t < N; t++) + qwen_rmsnorm(xbs + t * D, xs + t * D, m->rms_att[l], D); + + // Batch QKV projections: sgemm + cpu_project_batch(m->wq[l], xbs, qs, D, QWEN_Q_DIM, N); + cpu_project_batch(m->wk[l], xbs, ks, D, QWEN_KV_DIM, N); + cpu_project_batch(m->wv[l], xbs, vs, D, QWEN_KV_DIM, N); + + // Per-token: bias + RoPE + cache + attention + for (int t = 0; t < N; t++) { + float *qt = qs + t * QWEN_Q_DIM; + float *kt = ks + t * QWEN_KV_DIM; + float *vt = vs + t * QWEN_KV_DIM; + int pos = m->pos + t; + + // Biases + if (m->q_bias[l]) + vDSP_vadd(qt, 1, m->q_bias[l], 1, qt, 1, (vDSP_Length)QWEN_Q_DIM); + if (m->k_bias[l]) + vDSP_vadd(kt, 1, m->k_bias[l], 1, kt, 1, (vDSP_Length)QWEN_KV_DIM); + if (m->v_bias[l]) + vDSP_vadd(vt, 1, m->v_bias[l], 1, vt, 1, (vDSP_Length)QWEN_KV_DIM); + + // RoPE + qwen_rope(qt, kt, pos, QWEN_HEADS, QWEN_KV_HEADS, QWEN_HEAD_DIM); + + // Store K, V in cache + memcpy(m->kv_cache_k[l] + pos * QWEN_KV_DIM, kt, QWEN_KV_DIM * sizeof(float)); + memcpy(m->kv_cache_v[l] + pos * QWEN_KV_DIM, vt, QWEN_KV_DIM * sizeof(float)); + + // GQA attention + float scale = 1.0f / sqrtf((float)QWEN_HEAD_DIM); + float *attn_out = xbs + t * D; + memset(attn_out, 0, QWEN_Q_DIM * sizeof(float)); + + for (int h = 0; h < QWEN_HEADS; h++) { + int kv_h = h / QWEN_GQA_FACTOR; + float *qh = qt + h * QWEN_HEAD_DIM; + float *att_h = m->att + h * QWEN_MAX_SEQ; + int seq_len = pos + 1; + + float max_score = -1e9f; + for (int p = 0; p <= pos; p++) { + float *kp = m->kv_cache_k[l] + p * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM; + float score = cblas_sdot(QWEN_HEAD_DIM, qh, 1, kp, 1); + att_h[p] = score * scale; + if (att_h[p] > max_score) max_score = att_h[p]; + } + float neg_max = -max_score; + vDSP_vsadd(att_h, 1, &neg_max, att_h, 1, (vDSP_Length)seq_len); + int n_exp = seq_len; + vvexpf(att_h, att_h, &n_exp); + float sum; + vDSP_sve(att_h, 1, &sum, (vDSP_Length)seq_len); + float inv_sum = 1.0f / sum; + vDSP_vsmul(att_h, 1, &inv_sum, att_h, 1, (vDSP_Length)seq_len); + + for (int p = 0; p <= pos; p++) { + float a = att_h[p]; + float *vp = m->kv_cache_v[l] + p * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM; + cblas_saxpy(QWEN_HEAD_DIM, a, vp, 1, + attn_out + h * QWEN_HEAD_DIM, 1); + } + } + } + // xbs now has [N, Q_DIM] attention outputs + + // Batch O projection (reuses pre-allocated o_outs) + cpu_project_batch(m->wo[l], xbs, o_outs, QWEN_Q_DIM, D, N); + + for (int t = 0; t < N; t++) + vDSP_vadd(xs + t * D, 1, o_outs + t * D, 1, xs + t * D, 1, (vDSP_Length)D); + + // Batch FFN RMSNorm + for (int t = 0; t < N; t++) + qwen_rmsnorm(xbs + t * D, xs + t * D, m->rms_ffn[l], D); + + // Batch FFN projections + cpu_project_batch(m->w_gate[l], xbs, hbs, D, HD, N); + cpu_project_batch(m->w_up[l], xbs, hb2s, D, HD, N); + + for (int t = 0; t < N; t++) { + qwen_silu(hbs + t * HD, HD); + vDSP_vmul(hbs + t * HD, 1, hb2s + t * HD, 1, hbs + t * HD, 1, (vDSP_Length)HD); + } + + // Batch down projection (reuses pre-allocated ffn_outs) + cpu_project_batch(m->w_down[l], hbs, ffn_outs, HD, D, N); + + for (int t = 0; t < N; t++) + vDSP_vadd(xs + t * D, 1, ffn_outs + t * D, 1, xs + t * D, 1, (vDSP_Length)D); + } + + // Only need logits for the last token + float *last_x = xs + (N - 1) * D; + qwen_rmsnorm(m->xb, last_x, m->rms_final, D); + + cblas_sgemv(CblasRowMajor, CblasNoTrans, + QWEN_VOCAB, D, + 1.0f, m->embed, D, + m->xb, 1, + 0.0f, m->logits, 1); + + m->pos += N; + + float max_val; + vDSP_Length max_idx_vdsp; + vDSP_maxvi(m->logits, 1, &max_val, &max_idx_vdsp, (vDSP_Length)QWEN_VOCAB); + + free(xs); free(xbs); free(qs); free(ks); free(vs); free(hbs); free(hb2s); + free(o_outs); free(ffn_outs); + return (int)max_idx_vdsp; +} + +// Q4 AMX batched prefill: dequantize weight matrices then use sgemm +static int qwen_prefill_q4(QwenModel *m, const int *tokens, int n_tokens) { + int D = QWEN_DIM, HD = QWEN_HIDDEN, N = n_tokens; + + float *xs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_q4_xs"); + float *xbs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_q4_xbs"); + float *qs = (float*)qwen_calloc(N * QWEN_Q_DIM, sizeof(float), "prefill_q4_qs"); + float *ks = (float*)qwen_calloc(N * QWEN_KV_DIM, sizeof(float), "prefill_q4_ks"); + float *vs = (float*)qwen_calloc(N * QWEN_KV_DIM, sizeof(float), "prefill_q4_vs"); + float *hbs = (float*)qwen_calloc(N * HD, sizeof(float), "prefill_q4_hbs"); + float *hb2s = (float*)qwen_calloc(N * HD, sizeof(float), "prefill_q4_hb2s"); + float *o_outs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_q4_o_outs"); + float *ffn_outs = (float*)qwen_calloc(N * D, sizeof(float), "prefill_q4_ffn_outs"); + + for (int t = 0; t < N; t++) + memcpy(xs + t * D, m->embed + tokens[t] * D, D * sizeof(float)); + + for (int l = 0; l < QWEN_LAYERS; l++) { + for (int t = 0; t < N; t++) + qwen_rmsnorm(xbs + t * D, xs + t * D, m->rms_att[l], D); + + cpu_project_batch_q4_amx(m->wq_q8[l], xbs, qs, D, QWEN_Q_DIM, N); + cpu_project_batch_q4_amx(m->wk_q8[l], xbs, ks, D, QWEN_KV_DIM, N); + cpu_project_batch_q4_amx(m->wv_q8[l], xbs, vs, D, QWEN_KV_DIM, N); + + for (int t = 0; t < N; t++) { + float *qt = qs + t * QWEN_Q_DIM; + float *kt = ks + t * QWEN_KV_DIM; + float *vt = vs + t * QWEN_KV_DIM; + int pos = m->pos + t; + + if (m->q_bias[l]) + vDSP_vadd(qt, 1, m->q_bias[l], 1, qt, 1, (vDSP_Length)QWEN_Q_DIM); + if (m->k_bias[l]) + vDSP_vadd(kt, 1, m->k_bias[l], 1, kt, 1, (vDSP_Length)QWEN_KV_DIM); + if (m->v_bias[l]) + vDSP_vadd(vt, 1, m->v_bias[l], 1, vt, 1, (vDSP_Length)QWEN_KV_DIM); + + qwen_rope(qt, kt, pos, QWEN_HEADS, QWEN_KV_HEADS, QWEN_HEAD_DIM); + + memcpy(m->kv_cache_k[l] + pos * QWEN_KV_DIM, kt, QWEN_KV_DIM * sizeof(float)); + memcpy(m->kv_cache_v[l] + pos * QWEN_KV_DIM, vt, QWEN_KV_DIM * sizeof(float)); + + float scale = 1.0f / sqrtf((float)QWEN_HEAD_DIM); + float *attn_out = xbs + t * D; + memset(attn_out, 0, QWEN_Q_DIM * sizeof(float)); + + for (int h = 0; h < QWEN_HEADS; h++) { + int kv_h = h / QWEN_GQA_FACTOR; + float *qh = qt + h * QWEN_HEAD_DIM; + float *att_h = m->att + h * QWEN_MAX_SEQ; + int seq_len = pos + 1; + + float max_score = -1e9f; + for (int p = 0; p <= pos; p++) { + float *kp = m->kv_cache_k[l] + p * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM; + float score = cblas_sdot(QWEN_HEAD_DIM, qh, 1, kp, 1); + att_h[p] = score * scale; + if (att_h[p] > max_score) max_score = att_h[p]; + } + float neg_max = -max_score; + vDSP_vsadd(att_h, 1, &neg_max, att_h, 1, (vDSP_Length)seq_len); + int n_exp = seq_len; + vvexpf(att_h, att_h, &n_exp); + float sum; + vDSP_sve(att_h, 1, &sum, (vDSP_Length)seq_len); + float inv_sum = 1.0f / sum; + vDSP_vsmul(att_h, 1, &inv_sum, att_h, 1, (vDSP_Length)seq_len); + + for (int p = 0; p <= pos; p++) { + float a = att_h[p]; + float *vp = m->kv_cache_v[l] + p * QWEN_KV_DIM + kv_h * QWEN_HEAD_DIM; + cblas_saxpy(QWEN_HEAD_DIM, a, vp, 1, + attn_out + h * QWEN_HEAD_DIM, 1); + } + } + } + + cpu_project_batch_q4_amx(m->wo_q8[l], xbs, o_outs, QWEN_Q_DIM, D, N); + + for (int t = 0; t < N; t++) + vDSP_vadd(xs + t * D, 1, o_outs + t * D, 1, xs + t * D, 1, (vDSP_Length)D); + + for (int t = 0; t < N; t++) + qwen_rmsnorm(xbs + t * D, xs + t * D, m->rms_ffn[l], D); + + cpu_project_batch_q4_amx(m->wgate_q8[l], xbs, hbs, D, HD, N); + cpu_project_batch_q4_amx(m->wup_q8[l], xbs, hb2s, D, HD, N); + + for (int t = 0; t < N; t++) { + qwen_silu(hbs + t * HD, HD); + vDSP_vmul(hbs + t * HD, 1, hb2s + t * HD, 1, hbs + t * HD, 1, (vDSP_Length)HD); + } + + cpu_project_batch_q4_amx(m->wdown_q8[l], hbs, ffn_outs, HD, D, N); + + for (int t = 0; t < N; t++) + vDSP_vadd(xs + t * D, 1, ffn_outs + t * D, 1, xs + t * D, 1, (vDSP_Length)D); + } + + float *last_x = xs + (N - 1) * D; + qwen_rmsnorm(m->xb, last_x, m->rms_final, D); + + cblas_sgemv(CblasRowMajor, CblasNoTrans, + QWEN_VOCAB, D, + 1.0f, m->embed, D, + m->xb, 1, + 0.0f, m->logits, 1); + + m->pos += N; + + float max_val; + vDSP_Length max_idx_vdsp; + vDSP_maxvi(m->logits, 1, &max_val, &max_idx_vdsp, (vDSP_Length)QWEN_VOCAB); + + free(xs); free(xbs); free(qs); free(ks); free(vs); free(hbs); free(hb2s); + free(o_outs); free(ffn_outs); + return (int)max_idx_vdsp; +} + +// ── Full-GPU forward pass (Metal, single command buffer per layer) ──── +// Runs entire transformer on GPU using Q4 quantized weights. +// KV cache stays on GPU between calls. Attention runs per-head on GPU. +#ifdef __OBJC__ + +// SIMD-optimized Q4 matvec with optional bias fusion. +// 2 SIMD groups x 4 rows each = 8 rows/threadgroup, simd_sum reduction. +static void gpu_encode_sgemv_q4_bias(id enc, QwenModel *m, + id w_buf, id x_buf, id y_buf, + uint32_t in_dim, uint32_t out_dim, + id bias_buf) { + id pso = (__bridge id)g_metal.pipeline_q4_fast; + [enc setComputePipelineState:pso]; + [enc setBuffer:w_buf offset:0 atIndex:0]; + [enc setBuffer:x_buf offset:0 atIndex:1]; + [enc setBuffer:y_buf offset:0 atIndex:2]; + [enc setBytes:&in_dim length:4 atIndex:3]; + [enc setBytes:&out_dim length:4 atIndex:4]; + + uint32_t use_bias = (bias_buf != nil) ? 1 : 0; + if (bias_buf) { + [enc setBuffer:bias_buf offset:0 atIndex:5]; + } else { + [enc setBuffer:y_buf offset:0 atIndex:5]; + } + [enc setBytes:&use_bias length:4 atIndex:6]; + + uint32_t rows_per_tg = 8; + uint32_t n_tg = (out_dim + rows_per_tg - 1) / rows_per_tg; + [enc dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; +} + +static void gpu_encode_sgemv_q4(id enc, QwenModel *m, + id w_buf, id x_buf, id y_buf, + uint32_t in_dim, uint32_t out_dim) { + gpu_encode_sgemv_q4_bias(enc, m, w_buf, x_buf, y_buf, in_dim, out_dim, nil); +} + +static void gpu_encode_sgemv_f32(id enc, + id w_buf, id x_buf, id y_buf, + uint32_t in_dim, uint32_t out_dim) { + id pso = (__bridge id)g_metal.pipeline_f32; + [enc setComputePipelineState:pso]; + [enc setBuffer:w_buf offset:0 atIndex:0]; + [enc setBuffer:x_buf offset:0 atIndex:1]; + [enc setBuffer:y_buf offset:0 atIndex:2]; + [enc setBytes:&in_dim length:4 atIndex:3]; + [enc setBytes:&out_dim length:4 atIndex:4]; + NSUInteger tpg = pso.maxTotalThreadsPerThreadgroup; + if (tpg > out_dim) tpg = out_dim; + [enc dispatchThreads:MTLSizeMake(out_dim, 1, 1) threadsPerThreadgroup:MTLSizeMake(tpg, 1, 1)]; +} + +// Fused gate+up+silu: reads x once, computes silu(Wg*x)*Wu*x +static void gpu_encode_fused_ffn(id enc, + id wgate_buf, id wup_buf, + id x_buf, id out_buf, + uint32_t in_dim, uint32_t out_dim) { + id pso = (__bridge id)g_metal.pipeline_q4_fused_ffn; + [enc setComputePipelineState:pso]; + [enc setBuffer:wgate_buf offset:0 atIndex:0]; + [enc setBuffer:wup_buf offset:0 atIndex:1]; + [enc setBuffer:x_buf offset:0 atIndex:2]; + [enc setBuffer:out_buf offset:0 atIndex:3]; + [enc setBytes:&in_dim length:4 atIndex:4]; + [enc setBytes:&out_dim length:4 atIndex:5]; + + uint32_t rows_per_tg = 4; // FUSED_ROWS_PER_SIMD(2) * FUSED_SIMD_GROUPS(2) + uint32_t n_tg = (out_dim + rows_per_tg - 1) / rows_per_tg; + [enc dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; +} + +static int qwen_forward_gpu(QwenModel *m, int token) { + int D = QWEN_DIM, HD = QWEN_HIDDEN; + int pos = m->pos; + uint32_t uD = (uint32_t)D, uHD = (uint32_t)HD; + uint32_t uQD = (uint32_t)QWEN_Q_DIM, uKVD = (uint32_t)QWEN_KV_DIM; + + id dev = (__bridge id)g_metal.device; + id queue = (__bridge id)g_metal.queue; + + static id gpu_x = nil, gpu_xb = nil; + static id gpu_q = nil, gpu_k = nil, gpu_v = nil; + static id gpu_hb = nil, gpu_hb2 = nil; + static id gpu_attn_out = nil; + static id gpu_o_out = nil, gpu_ffn_out = nil; + static id gpu_logits = nil; + static id gpu_att = nil; + static id gpu_result = nil; + static id gpu_rope_cos = nil, gpu_rope_sin = nil; + + if (!gpu_x) { + gpu_x = [dev newBufferWithLength:D * 4 options:MTLResourceStorageModeShared]; + gpu_xb = [dev newBufferWithLength:D * 4 options:MTLResourceStorageModeShared]; + gpu_q = [dev newBufferWithLength:QWEN_Q_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_k = [dev newBufferWithLength:QWEN_KV_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_v = [dev newBufferWithLength:QWEN_KV_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_hb = [dev newBufferWithLength:HD * 4 options:MTLResourceStorageModeShared]; + gpu_hb2 = [dev newBufferWithLength:HD * 4 options:MTLResourceStorageModeShared]; + gpu_attn_out = [dev newBufferWithLength:QWEN_Q_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_o_out = [dev newBufferWithLength:D * 4 options:MTLResourceStorageModeShared]; + gpu_ffn_out = [dev newBufferWithLength:D * 4 options:MTLResourceStorageModeShared]; + gpu_logits = [dev newBufferWithLength:QWEN_VOCAB * 4 options:MTLResourceStorageModeShared]; + gpu_att = [dev newBufferWithLength:QWEN_HEADS * QWEN_MAX_SEQ * 4 options:MTLResourceStorageModeShared]; + gpu_result = [dev newBufferWithLength:4 options:MTLResourceStorageModeShared]; + + qwen_rope_init(); + gpu_rope_cos = [dev newBufferWithLength:sizeof(g_rope_cos) options:MTLResourceStorageModeShared]; + gpu_rope_sin = [dev newBufferWithLength:sizeof(g_rope_sin) options:MTLResourceStorageModeShared]; + memcpy([gpu_rope_cos contents], g_rope_cos, sizeof(g_rope_cos)); + memcpy([gpu_rope_sin contents], g_rope_sin, sizeof(g_rope_sin)); + } + + id pso_rms = (__bridge id)g_metal.pipeline_rms; + id pso_rope = (__bridge id)g_metal.pipeline_rope; + id pso_silu = (__bridge id)g_metal.pipeline_silu; + id pso_add = (__bridge id)g_metal.pipeline_add; + id pso_bias = (__bridge id)g_metal.pipeline_bias; + id pso_embed = (__bridge id)g_metal.pipeline_embed; + id pso_attn_score = (__bridge id)g_metal.pipeline_attn_score; + id pso_softmax = (__bridge id)g_metal.pipeline_softmax; + id pso_attn_wsum = (__bridge id)g_metal.pipeline_attn_wsum; + id pso_argmax = (__bridge id)g_metal.pipeline_argmax; + id pso_zero = (__bridge id)g_metal.pipeline_zero; + id pso_copy = (__bridge id)g_metal.pipeline_copy; + + float rms_eps = QWEN_RMS_EPS; + uint32_t utoken = (uint32_t)token; + uint32_t seq_len = (uint32_t)(pos + 1); + float attn_scale = 1.0f / sqrtf((float)QWEN_HEAD_DIM); + uint32_t un_q = QWEN_HEADS, un_kv = QWEN_KV_HEADS, uhd = QWEN_HEAD_DIM; + + // Encode ALL 24 layers + final into ONE command buffer. + // Metal guarantees sequential execution of dispatches within a command encoder, + // so data dependencies (KV cache reads after writes) are satisfied by dispatch order. + id cmd = [queue commandBuffer]; + id enc = [cmd computeCommandEncoder]; + + // Embedding + [enc setComputePipelineState:pso_embed]; + [enc setBuffer:(__bridge id)m->gpu_embed offset:0 atIndex:0]; + [enc setBuffer:gpu_x offset:0 atIndex:1]; + [enc setBytes:&utoken length:4 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc dispatchThreads:MTLSizeMake(D, 1, 1) threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)D, pso_embed.maxTotalThreadsPerThreadgroup), 1, 1)]; + + for (int l = 0; l < QWEN_LAYERS; l++) { + // RMSNorm attention + [enc setComputePipelineState:pso_rms]; + [enc setBuffer:gpu_x offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_rms_att[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_xb offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc setBytes:&rms_eps length:4 atIndex:4]; + { NSUInteger p = 1; while (p < (NSUInteger)D) p <<= 1; if (p > 1024) p = 1024; + [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + // QKV with fused bias (saves 3 bias_add dispatches per layer) + gpu_encode_sgemv_q4_bias(enc, m, (__bridge id)m->gpu_wq[l], gpu_xb, gpu_q, uD, uQD, + (__bridge id)m->gpu_q_bias[l]); + gpu_encode_sgemv_q4_bias(enc, m, (__bridge id)m->gpu_wk[l], gpu_xb, gpu_k, uD, uKVD, + (__bridge id)m->gpu_k_bias[l]); + gpu_encode_sgemv_q4_bias(enc, m, (__bridge id)m->gpu_wv[l], gpu_xb, gpu_v, uD, uKVD, + (__bridge id)m->gpu_v_bias[l]); + + // RoPE + uint32_t rope_offset = (uint32_t)pos * (QWEN_HEAD_DIM / 2); + [enc setComputePipelineState:pso_rope]; + [enc setBuffer:gpu_q offset:0 atIndex:0]; + [enc setBuffer:gpu_k offset:0 atIndex:1]; + [enc setBuffer:gpu_rope_cos offset:rope_offset * 4 atIndex:2]; + [enc setBuffer:gpu_rope_sin offset:rope_offset * 4 atIndex:3]; + [enc setBytes:&un_q length:4 atIndex:4]; + [enc setBytes:&un_kv length:4 atIndex:5]; + [enc setBytes:&uhd length:4 atIndex:6]; + { uint32_t total = (QWEN_HEADS + QWEN_KV_HEADS) * (QWEN_HEAD_DIM / 2); + [enc dispatchThreads:MTLSizeMake(total, 1, 1) + threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)total, pso_rope.maxTotalThreadsPerThreadgroup), 1, 1)]; } + + // Store K, V into KV cache + [enc setComputePipelineState:pso_copy]; + [enc setBuffer:gpu_k offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_k[l] offset:(NSUInteger)pos * QWEN_KV_DIM * 4 atIndex:1]; + [enc setBytes:&uKVD length:4 atIndex:2]; + [enc dispatchThreads:MTLSizeMake(QWEN_KV_DIM, 1, 1) + threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)QWEN_KV_DIM, pso_copy.maxTotalThreadsPerThreadgroup), 1, 1)]; + + [enc setBuffer:gpu_v offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_v[l] offset:(NSUInteger)pos * QWEN_KV_DIM * 4 atIndex:1]; + [enc setBytes:&uKVD length:4 atIndex:2]; + [enc dispatchThreads:MTLSizeMake(QWEN_KV_DIM, 1, 1) + threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)QWEN_KV_DIM, pso_copy.maxTotalThreadsPerThreadgroup), 1, 1)]; + + // Batched attention: all 14 Q heads in 3 dispatches (was 42) + { + uint32_t un_q_heads = QWEN_HEADS; + uint32_t u_gqa = QWEN_GQA_FACTOR; + uint32_t u_max_seq = QWEN_MAX_SEQ; + + id pso_score_b = (__bridge id)g_metal.pipeline_attn_score_b; + id pso_soft_b = (__bridge id)g_metal.pipeline_softmax_b; + id pso_wsum_b = (__bridge id)g_metal.pipeline_attn_wsum_b; + + // 1. Batched attn score: grid (seq_len, n_q_heads) + [enc setComputePipelineState:pso_score_b]; + [enc setBuffer:gpu_q offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_k[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_att offset:0 atIndex:2]; + [enc setBytes:&uhd length:4 atIndex:3]; + [enc setBytes:&uKVD length:4 atIndex:4]; + [enc setBytes:&un_q_heads length:4 atIndex:5]; + [enc setBytes:&u_gqa length:4 atIndex:6]; + [enc setBytes:&attn_scale length:4 atIndex:7]; + [enc setBytes:&seq_len length:4 atIndex:8]; + [enc setBytes:&u_max_seq length:4 atIndex:9]; + { NSUInteger tpg_x = MIN((NSUInteger)seq_len, (NSUInteger)256); + NSUInteger tpg_y = MIN((NSUInteger)QWEN_HEADS, (NSUInteger)(pso_score_b.maxTotalThreadsPerThreadgroup / tpg_x)); + if (tpg_y < 1) tpg_y = 1; + [enc dispatchThreads:MTLSizeMake(seq_len, QWEN_HEADS, 1) + threadsPerThreadgroup:MTLSizeMake(tpg_x, tpg_y, 1)]; } + + // 2. Batched softmax: one threadgroup per head + [enc setComputePipelineState:pso_soft_b]; + [enc setBuffer:gpu_att offset:0 atIndex:0]; + [enc setBytes:&seq_len length:4 atIndex:1]; + [enc setBytes:&u_max_seq length:4 atIndex:2]; + [enc setBytes:&un_q_heads length:4 atIndex:3]; + { NSUInteger p = 1; while (p < (NSUInteger)seq_len && p < 1024) p <<= 1; + [enc dispatchThreadgroups:MTLSizeMake(QWEN_HEADS, 1, 1) + threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + // 3. Batched weighted sum: grid (head_dim, n_q_heads) + [enc setComputePipelineState:pso_wsum_b]; + [enc setBuffer:gpu_att offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_v[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_attn_out offset:0 atIndex:2]; + [enc setBytes:&uhd length:4 atIndex:3]; + [enc setBytes:&uKVD length:4 atIndex:4]; + [enc setBytes:&un_q_heads length:4 atIndex:5]; + [enc setBytes:&u_gqa length:4 atIndex:6]; + [enc setBytes:&seq_len length:4 atIndex:7]; + [enc setBytes:&u_max_seq length:4 atIndex:8]; + { NSUInteger tpg_x = MIN((NSUInteger)QWEN_HEAD_DIM, (NSUInteger)64); + NSUInteger tpg_y = MIN((NSUInteger)QWEN_HEADS, (NSUInteger)(pso_wsum_b.maxTotalThreadsPerThreadgroup / tpg_x)); + if (tpg_y < 1) tpg_y = 1; + [enc dispatchThreads:MTLSizeMake(QWEN_HEAD_DIM, QWEN_HEADS, 1) + threadsPerThreadgroup:MTLSizeMake(tpg_x, tpg_y, 1)]; } + } + + // O projection + residual + gpu_encode_sgemv_q4(enc, m, (__bridge id)m->gpu_wo[l], gpu_attn_out, gpu_o_out, uQD, uD); + + [enc setComputePipelineState:pso_add]; + [enc setBuffer:gpu_x offset:0 atIndex:0]; + [enc setBuffer:gpu_o_out offset:0 atIndex:1]; + [enc setBuffer:gpu_x offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc dispatchThreads:MTLSizeMake(D, 1, 1) + threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)D, pso_add.maxTotalThreadsPerThreadgroup), 1, 1)]; + + // FFN + [enc setComputePipelineState:pso_rms]; + [enc setBuffer:gpu_x offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_rms_ffn[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_xb offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc setBytes:&rms_eps length:4 atIndex:4]; + { NSUInteger p = 1; while (p < (NSUInteger)D) p <<= 1; if (p > 1024) p = 1024; + [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + // Fused gate+up+silu: one kernel reads xb, computes silu(Wg*xb)*Wu*xb + gpu_encode_fused_ffn(enc, + (__bridge id)m->gpu_wgate[l], + (__bridge id)m->gpu_wup[l], + gpu_xb, gpu_hb, uD, uHD); + + gpu_encode_sgemv_q4(enc, m, (__bridge id)m->gpu_wdown[l], gpu_hb, gpu_ffn_out, uHD, uD); + + [enc setComputePipelineState:pso_add]; + [enc setBuffer:gpu_x offset:0 atIndex:0]; + [enc setBuffer:gpu_ffn_out offset:0 atIndex:1]; + [enc setBuffer:gpu_x offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc dispatchThreads:MTLSizeMake(D, 1, 1) + threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)D, pso_add.maxTotalThreadsPerThreadgroup), 1, 1)]; + } + + // Final RMSNorm + LM Head + argmax (still in the SAME command buffer) + [enc setComputePipelineState:pso_rms]; + [enc setBuffer:gpu_x offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_rms_final offset:0 atIndex:1]; + [enc setBuffer:gpu_xb offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc setBytes:&rms_eps length:4 atIndex:4]; + { NSUInteger p = 1; while (p < (NSUInteger)D) p <<= 1; if (p > 1024) p = 1024; + [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + uint32_t uVocab = QWEN_VOCAB; + gpu_encode_sgemv_f32(enc, (__bridge id)m->gpu_embed, gpu_xb, gpu_logits, uD, uVocab); + + [enc setComputePipelineState:pso_argmax]; + [enc setBuffer:gpu_logits offset:0 atIndex:0]; + [enc setBuffer:gpu_result offset:0 atIndex:1]; + [enc setBytes:&uVocab length:4 atIndex:2]; + { NSUInteger tpg = MIN((NSUInteger)1024, pso_argmax.maxTotalThreadsPerThreadgroup); + [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(tpg, 1, 1)]; } + + [enc endEncoding]; + [cmd commit]; + [cmd waitUntilCompleted]; + + m->pos++; + + int *result_ptr = (int*)[gpu_result contents]; + return result_ptr[0]; +} + +// ── GPU batched prefill: all N prompt tokens in one command buffer ──── +// Uses sgemm_q4 (matrix-matrix) instead of sequential sgemv calls. +// Reads each weight matrix once for all N tokens instead of N times. +static int qwen_prefill_gpu(QwenModel *m, const int *tokens, int n_tokens) { + int D = QWEN_DIM, HD = QWEN_HIDDEN, N = n_tokens; + uint32_t uD = (uint32_t)D, uHD = (uint32_t)HD; + uint32_t uQD = (uint32_t)QWEN_Q_DIM, uKVD = (uint32_t)QWEN_KV_DIM; + uint32_t uN = (uint32_t)N; + float rms_eps = QWEN_RMS_EPS; + float attn_scale = 1.0f / sqrtf((float)QWEN_HEAD_DIM); + uint32_t uhd = QWEN_HEAD_DIM; + uint32_t un_q = QWEN_HEADS, un_kv = QWEN_KV_HEADS; + uint32_t u_gqa = QWEN_GQA_FACTOR; + uint32_t u_max_seq = QWEN_MAX_SEQ; + + id dev = (__bridge id)g_metal.device; + id queue = (__bridge id)g_metal.queue; + + // Static batch buffers: allocated once at QWEN_MAX_SEQ size, reused across calls + static id gpu_xs = nil, gpu_xbs = nil; + static id gpu_qs = nil, gpu_ks = nil, gpu_vs = nil; + static id gpu_hbs = nil; + static id gpu_attn_outs = nil, gpu_o_outs = nil, gpu_ffn_outs = nil; + static id gpu_att = nil, gpu_logits = nil, gpu_result = nil; + static id gpu_token_ids = nil; + static id gpu_rope_cos = nil, gpu_rope_sin = nil; + static id gpu_xb_last = nil; + + if (!gpu_xs) { + NSUInteger maxN = QWEN_MAX_SEQ; + gpu_xs = [dev newBufferWithLength:maxN * D * 4 options:MTLResourceStorageModeShared]; + gpu_xbs = [dev newBufferWithLength:maxN * D * 4 options:MTLResourceStorageModeShared]; + gpu_qs = [dev newBufferWithLength:maxN * QWEN_Q_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_ks = [dev newBufferWithLength:maxN * QWEN_KV_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_vs = [dev newBufferWithLength:maxN * QWEN_KV_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_hbs = [dev newBufferWithLength:maxN * HD * 4 options:MTLResourceStorageModeShared]; + gpu_attn_outs = [dev newBufferWithLength:maxN * QWEN_Q_DIM * 4 options:MTLResourceStorageModeShared]; + gpu_o_outs = [dev newBufferWithLength:maxN * D * 4 options:MTLResourceStorageModeShared]; + gpu_ffn_outs = [dev newBufferWithLength:maxN * D * 4 options:MTLResourceStorageModeShared]; + gpu_att = [dev newBufferWithLength:QWEN_HEADS * QWEN_MAX_SEQ * 4 options:MTLResourceStorageModeShared]; + gpu_logits = [dev newBufferWithLength:QWEN_VOCAB * 4 options:MTLResourceStorageModeShared]; + gpu_result = [dev newBufferWithLength:4 options:MTLResourceStorageModeShared]; + gpu_token_ids = [dev newBufferWithLength:maxN * sizeof(int) options:MTLResourceStorageModeShared]; + gpu_xb_last = [dev newBufferWithLength:D * 4 options:MTLResourceStorageModeShared]; + + qwen_rope_init(); + gpu_rope_cos = [dev newBufferWithLength:sizeof(g_rope_cos) options:MTLResourceStorageModeShared]; + gpu_rope_sin = [dev newBufferWithLength:sizeof(g_rope_sin) options:MTLResourceStorageModeShared]; + memcpy([gpu_rope_cos contents], g_rope_cos, sizeof(g_rope_cos)); + memcpy([gpu_rope_sin contents], g_rope_sin, sizeof(g_rope_sin)); + } + + memcpy([gpu_token_ids contents], tokens, (NSUInteger)N * sizeof(int)); + + // Pipeline states + id pso_sgemm_q4 = (__bridge id)g_metal.pipeline_sgemm_q4; + id pso_sgemm_ffn = (__bridge id)g_metal.pipeline_sgemm_q4_fused_ffn; + id pso_rms_b = (__bridge id)g_metal.pipeline_rms_batched; + id pso_embed_b = (__bridge id)g_metal.pipeline_embed_batched; + id pso_rope_b = (__bridge id)g_metal.pipeline_rope_batched; + id pso_add_b = (__bridge id)g_metal.pipeline_add_batched; + id pso_copy = (__bridge id)g_metal.pipeline_copy; + id pso_rms = (__bridge id)g_metal.pipeline_rms; + id pso_argmax = (__bridge id)g_metal.pipeline_argmax; + id pso_score_b = (__bridge id)g_metal.pipeline_attn_score_b; + id pso_soft_b = (__bridge id)g_metal.pipeline_softmax_b; + id pso_wsum_b = (__bridge id)g_metal.pipeline_attn_wsum_b; + + // Helper: encode sgemm_q4 dispatch + #define ENCODE_SGEMM_Q4(enc, w_buf, x_buf, y_buf, in_d, out_d, bias_buf, n_tok) do { \ + [enc setComputePipelineState:pso_sgemm_q4]; \ + [enc setBuffer:w_buf offset:0 atIndex:0]; \ + [enc setBuffer:x_buf offset:0 atIndex:1]; \ + [enc setBuffer:y_buf offset:0 atIndex:2]; \ + uint32_t _id = (in_d), _od = (out_d), _ub = ((bias_buf) != nil) ? 1 : 0, _nt = (n_tok); \ + [enc setBytes:&_id length:4 atIndex:3]; \ + [enc setBytes:&_od length:4 atIndex:4]; \ + if (bias_buf) [enc setBuffer:bias_buf offset:0 atIndex:5]; \ + else [enc setBuffer:y_buf offset:0 atIndex:5]; \ + [enc setBytes:&_ub length:4 atIndex:6]; \ + [enc setBytes:&_nt length:4 atIndex:7]; \ + uint32_t _tg_x = (_od + 7) / 8; \ + [enc dispatchThreadgroups:MTLSizeMake(_tg_x, _nt, 1) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; \ + } while(0) + + // Single command buffer for entire prefill + id cmd = [queue commandBuffer]; + id enc = [cmd computeCommandEncoder]; + + // 1. Batched embedding: load all N token embeddings + [enc setComputePipelineState:pso_embed_b]; + [enc setBuffer:(__bridge id)m->gpu_embed offset:0 atIndex:0]; + [enc setBuffer:gpu_xs offset:0 atIndex:1]; + [enc setBuffer:gpu_token_ids offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + { NSUInteger tpg_x = MIN((NSUInteger)D, pso_embed_b.maxTotalThreadsPerThreadgroup); + [enc dispatchThreads:MTLSizeMake(D, N, 1) threadsPerThreadgroup:MTLSizeMake(tpg_x, 1, 1)]; } + + for (int l = 0; l < QWEN_LAYERS; l++) { + // 2. Batched RMSNorm (attention): N threadgroups, one per token + [enc setComputePipelineState:pso_rms_b]; + [enc setBuffer:gpu_xs offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_rms_att[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_xbs offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc setBytes:&rms_eps length:4 atIndex:4]; + [enc setBytes:&uN length:4 atIndex:5]; + { NSUInteger p = 1; while (p < (NSUInteger)D) p <<= 1; if (p > 1024) p = 1024; + [enc dispatchThreadgroups:MTLSizeMake(N, 1, 1) threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + // 3. Batched QKV projections with fused bias (3 sgemm_q4 dispatches) + ENCODE_SGEMM_Q4(enc, (__bridge id)m->gpu_wq[l], gpu_xbs, gpu_qs, + uD, uQD, (__bridge id)m->gpu_q_bias[l], uN); + ENCODE_SGEMM_Q4(enc, (__bridge id)m->gpu_wk[l], gpu_xbs, gpu_ks, + uD, uKVD, (__bridge id)m->gpu_k_bias[l], uN); + ENCODE_SGEMM_Q4(enc, (__bridge id)m->gpu_wv[l], gpu_xbs, gpu_vs, + uD, uKVD, (__bridge id)m->gpu_v_bias[l], uN); + + // 4. Batched RoPE: apply to all N tokens' Q and K + uint32_t base_pos = (uint32_t)m->pos; + uint32_t q_stride_val = QWEN_Q_DIM; + uint32_t k_stride_val = QWEN_KV_DIM; + [enc setComputePipelineState:pso_rope_b]; + [enc setBuffer:gpu_qs offset:0 atIndex:0]; + [enc setBuffer:gpu_ks offset:0 atIndex:1]; + [enc setBuffer:gpu_rope_cos offset:0 atIndex:2]; + [enc setBuffer:gpu_rope_sin offset:0 atIndex:3]; + [enc setBytes:&un_q length:4 atIndex:4]; + [enc setBytes:&un_kv length:4 atIndex:5]; + [enc setBytes:&uhd length:4 atIndex:6]; + [enc setBytes:&base_pos length:4 atIndex:7]; + [enc setBytes:&q_stride_val length:4 atIndex:8]; + [enc setBytes:&k_stride_val length:4 atIndex:9]; + { uint32_t total_pairs = (QWEN_HEADS + QWEN_KV_HEADS) * (QWEN_HEAD_DIM / 2); + NSUInteger tpg = MIN((NSUInteger)total_pairs, pso_rope_b.maxTotalThreadsPerThreadgroup); + [enc dispatchThreads:MTLSizeMake(total_pairs, N, 1) + threadsPerThreadgroup:MTLSizeMake(tpg, 1, 1)]; } + + // 5. Store K, V into cache for all N tokens (copy from batched buffers) + for (int t = 0; t < N; t++) { + int pos = m->pos + t; + [enc setComputePipelineState:pso_copy]; + [enc setBuffer:gpu_ks offset:(NSUInteger)t * QWEN_KV_DIM * 4 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_k[l] offset:(NSUInteger)pos * QWEN_KV_DIM * 4 atIndex:1]; + [enc setBytes:&uKVD length:4 atIndex:2]; + [enc dispatchThreads:MTLSizeMake(QWEN_KV_DIM, 1, 1) + threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)QWEN_KV_DIM, pso_copy.maxTotalThreadsPerThreadgroup), 1, 1)]; + + [enc setBuffer:gpu_vs offset:(NSUInteger)t * QWEN_KV_DIM * 4 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_v[l] offset:(NSUInteger)pos * QWEN_KV_DIM * 4 atIndex:1]; + [enc setBytes:&uKVD length:4 atIndex:2]; + [enc dispatchThreads:MTLSizeMake(QWEN_KV_DIM, 1, 1) + threadsPerThreadgroup:MTLSizeMake(MIN((NSUInteger)QWEN_KV_DIM, pso_copy.maxTotalThreadsPerThreadgroup), 1, 1)]; + } + + // 6. Per-token causal attention on GPU (each token sees only preceding tokens) + for (int t = 0; t < N; t++) { + uint32_t seq_len = (uint32_t)(m->pos + t + 1); + + // Attn score: all heads + [enc setComputePipelineState:pso_score_b]; + [enc setBuffer:gpu_qs offset:(NSUInteger)t * QWEN_Q_DIM * 4 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_k[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_att offset:0 atIndex:2]; + [enc setBytes:&uhd length:4 atIndex:3]; + [enc setBytes:&uKVD length:4 atIndex:4]; + [enc setBytes:&un_q length:4 atIndex:5]; + [enc setBytes:&u_gqa length:4 atIndex:6]; + [enc setBytes:&attn_scale length:4 atIndex:7]; + [enc setBytes:&seq_len length:4 atIndex:8]; + [enc setBytes:&u_max_seq length:4 atIndex:9]; + { NSUInteger tpg_x = MIN((NSUInteger)seq_len, (NSUInteger)256); + NSUInteger tpg_y = MIN((NSUInteger)QWEN_HEADS, pso_score_b.maxTotalThreadsPerThreadgroup / tpg_x); + if (tpg_y < 1) tpg_y = 1; + [enc dispatchThreads:MTLSizeMake(seq_len, QWEN_HEADS, 1) + threadsPerThreadgroup:MTLSizeMake(tpg_x, tpg_y, 1)]; } + + // Softmax: one threadgroup per head + [enc setComputePipelineState:pso_soft_b]; + [enc setBuffer:gpu_att offset:0 atIndex:0]; + [enc setBytes:&seq_len length:4 atIndex:1]; + [enc setBytes:&u_max_seq length:4 atIndex:2]; + [enc setBytes:&un_q length:4 atIndex:3]; + { NSUInteger p = 1; while (p < (NSUInteger)seq_len && p < 1024) p <<= 1; + [enc dispatchThreadgroups:MTLSizeMake(QWEN_HEADS, 1, 1) + threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + // Weighted sum: all heads + [enc setComputePipelineState:pso_wsum_b]; + [enc setBuffer:gpu_att offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_kv_cache_v[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_attn_outs offset:(NSUInteger)t * QWEN_Q_DIM * 4 atIndex:2]; + [enc setBytes:&uhd length:4 atIndex:3]; + [enc setBytes:&uKVD length:4 atIndex:4]; + [enc setBytes:&un_q length:4 atIndex:5]; + [enc setBytes:&u_gqa length:4 atIndex:6]; + [enc setBytes:&seq_len length:4 atIndex:7]; + [enc setBytes:&u_max_seq length:4 atIndex:8]; + { NSUInteger tpg_x = MIN((NSUInteger)QWEN_HEAD_DIM, (NSUInteger)64); + NSUInteger tpg_y = MIN((NSUInteger)QWEN_HEADS, pso_wsum_b.maxTotalThreadsPerThreadgroup / tpg_x); + if (tpg_y < 1) tpg_y = 1; + [enc dispatchThreads:MTLSizeMake(QWEN_HEAD_DIM, QWEN_HEADS, 1) + threadsPerThreadgroup:MTLSizeMake(tpg_x, tpg_y, 1)]; } + } + + // 7. Batched O projection + ENCODE_SGEMM_Q4(enc, (__bridge id)m->gpu_wo[l], gpu_attn_outs, gpu_o_outs, + uQD, uD, nil, uN); + + // 8. Batched residual: xs += o_outs + uint32_t total_add = uN * uD; + [enc setComputePipelineState:pso_add_b]; + [enc setBuffer:gpu_xs offset:0 atIndex:0]; + [enc setBuffer:gpu_o_outs offset:0 atIndex:1]; + [enc setBuffer:gpu_xs offset:0 atIndex:2]; + [enc setBytes:&total_add length:4 atIndex:3]; + { NSUInteger tpg = MIN((NSUInteger)total_add, pso_add_b.maxTotalThreadsPerThreadgroup); + [enc dispatchThreads:MTLSizeMake(total_add, 1, 1) threadsPerThreadgroup:MTLSizeMake(tpg, 1, 1)]; } + + // 9. Batched RMSNorm (FFN) + [enc setComputePipelineState:pso_rms_b]; + [enc setBuffer:gpu_xs offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_rms_ffn[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_xbs offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc setBytes:&rms_eps length:4 atIndex:4]; + [enc setBytes:&uN length:4 atIndex:5]; + { NSUInteger p = 1; while (p < (NSUInteger)D) p <<= 1; if (p > 1024) p = 1024; + [enc dispatchThreadgroups:MTLSizeMake(N, 1, 1) threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + // 10. Batched fused Gate+Up+SiLU + [enc setComputePipelineState:pso_sgemm_ffn]; + [enc setBuffer:(__bridge id)m->gpu_wgate[l] offset:0 atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_wup[l] offset:0 atIndex:1]; + [enc setBuffer:gpu_xbs offset:0 atIndex:2]; + [enc setBuffer:gpu_hbs offset:0 atIndex:3]; + [enc setBytes:&uD length:4 atIndex:4]; + [enc setBytes:&uHD length:4 atIndex:5]; + [enc setBytes:&uN length:4 atIndex:6]; + { uint32_t ffn_tg_x = (uHD + 3) / 4; + [enc dispatchThreadgroups:MTLSizeMake(ffn_tg_x, N, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; } + + // 11. Batched down projection + ENCODE_SGEMM_Q4(enc, (__bridge id)m->gpu_wdown[l], gpu_hbs, gpu_ffn_outs, + uHD, uD, nil, uN); + + // 12. Batched FFN residual: xs += ffn_outs + [enc setComputePipelineState:pso_add_b]; + [enc setBuffer:gpu_xs offset:0 atIndex:0]; + [enc setBuffer:gpu_ffn_outs offset:0 atIndex:1]; + [enc setBuffer:gpu_xs offset:0 atIndex:2]; + [enc setBytes:&total_add length:4 atIndex:3]; + { NSUInteger tpg = MIN((NSUInteger)total_add, pso_add_b.maxTotalThreadsPerThreadgroup); + [enc dispatchThreads:MTLSizeMake(total_add, 1, 1) threadsPerThreadgroup:MTLSizeMake(tpg, 1, 1)]; } + } + + // Final: RMSNorm + LM head + argmax on LAST token only + NSUInteger last_off = (NSUInteger)(N - 1) * D * 4; + + [enc setComputePipelineState:pso_rms]; + [enc setBuffer:gpu_xs offset:last_off atIndex:0]; + [enc setBuffer:(__bridge id)m->gpu_rms_final offset:0 atIndex:1]; + [enc setBuffer:gpu_xb_last offset:0 atIndex:2]; + [enc setBytes:&uD length:4 atIndex:3]; + [enc setBytes:&rms_eps length:4 atIndex:4]; + { NSUInteger p = 1; while (p < (NSUInteger)D) p <<= 1; if (p > 1024) p = 1024; + [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(p, 1, 1)]; } + + uint32_t uVocab = QWEN_VOCAB; + gpu_encode_sgemv_f32(enc, (__bridge id)m->gpu_embed, gpu_xb_last, gpu_logits, uD, uVocab); + + [enc setComputePipelineState:pso_argmax]; + [enc setBuffer:gpu_logits offset:0 atIndex:0]; + [enc setBuffer:gpu_result offset:0 atIndex:1]; + [enc setBytes:&uVocab length:4 atIndex:2]; + { NSUInteger tpg = MIN((NSUInteger)1024, pso_argmax.maxTotalThreadsPerThreadgroup); + [enc dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(tpg, 1, 1)]; } + + [enc endEncoding]; + [cmd commit]; + [cmd waitUntilCompleted]; + + #undef ENCODE_SGEMM_Q4 + + m->pos += N; + int *result_ptr = (int*)[gpu_result contents]; + return result_ptr[0]; +} + +#endif // __OBJC__ + // ── Compile all ANE kernels ────────────────────────────────────────── static void qwen_compile_kernels(QwenModel *m) { @@ -460,21 +1949,59 @@ static void qwen_compile_kernels(QwenModel *m) { #endif } +// Fused ANE compilation: QKV fused + Gate/Up fused + separate O, Down +// Total: 24*(1 QKV + 1 O + 1 FFN_up + 1 Down) = 96 kernels + 16 LM head = 112 (< 119) +static void qwen_compile_kernels_fused(QwenModel *m) { + int D = QWEN_DIM, HD = QWEN_HIDDEN; + int total = QWEN_LAYERS * 4 + QWEN_LM_CHUNKS; + int compiled = 0, failed = 0; + printf("Compiling %d fused ANE kernels (QKV+FFN_up fused)...\n", total); + + for (int l = 0; l < QWEN_LAYERS; l++) { + m->k_qkv[l] = compile_qkv_gqa_kernel( + m->wq[l], m->wk[l], m->wv[l], + D, QWEN_Q_DIM, QWEN_KV_DIM); + if (m->k_qkv[l]) compiled++; else { failed++; printf(" Layer %d QKV FAILED\n", l); } + + m->k_o[l] = compile_conv_kernel_fp16io(m->wo[l], QWEN_Q_DIM, D, 1); + if (m->k_o[l]) compiled++; else { failed++; printf(" Layer %d O FAILED\n", l); } + + m->k_ffn_up[l] = compile_ffn_up_kernel(m->w_gate[l], m->w_up[l], D, HD); + if (m->k_ffn_up[l]) compiled++; else { failed++; printf(" Layer %d FFN_up FAILED\n", l); } + + m->k_down[l] = compile_conv_kernel_fp16io(m->w_down[l], HD, D, 1); + if (m->k_down[l]) compiled++; else { failed++; printf(" Layer %d Down FAILED\n", l); } + + printf(" Layer %d/%d compiled (%d/%d ok)\r", l+1, QWEN_LAYERS, compiled, compiled+failed); + fflush(stdout); + } + + for (int c = 0; c < QWEN_LM_CHUNKS; c++) { + float *chunk_w = m->embed + c * QWEN_LM_CHUNK_SIZE * D; + m->k_lmhead[c] = compile_conv_kernel_fp16io(chunk_w, D, QWEN_LM_CHUNK_SIZE, 1); + if (m->k_lmhead[c]) compiled++; else { failed++; printf(" LM head chunk %d FAILED\n", c); } + } + + printf("\nFused ANE: %d/%d compiled, %d failed\n", compiled, total, failed); + if (failed > 0) + printf("WARNING: some kernels failed — ANE inference will fall back to CPU for those projections\n"); +} + // ── Allocate buffers ───────────────────────────────────────────────── static void qwen_alloc(QwenModel *m) { - m->x = (float*)calloc(QWEN_DIM, sizeof(float)); - m->xb = (float*)calloc(QWEN_DIM, sizeof(float)); - m->q = (float*)calloc(QWEN_Q_DIM, sizeof(float)); - m->k = (float*)calloc(QWEN_KV_DIM, sizeof(float)); - m->v = (float*)calloc(QWEN_KV_DIM, sizeof(float)); - m->att = (float*)calloc(QWEN_HEADS * QWEN_MAX_SEQ, sizeof(float)); - m->hb = (float*)calloc(QWEN_HIDDEN, sizeof(float)); - m->hb2 = (float*)calloc(QWEN_HIDDEN, sizeof(float)); - m->logits = (float*)calloc(QWEN_VOCAB, sizeof(float)); + m->x = (float*)qwen_calloc(QWEN_DIM, sizeof(float), "x"); + m->xb = (float*)qwen_calloc(QWEN_DIM, sizeof(float), "xb"); + m->q = (float*)qwen_calloc(QWEN_Q_DIM, sizeof(float), "q"); + m->k = (float*)qwen_calloc(QWEN_KV_DIM, sizeof(float), "k"); + m->v = (float*)qwen_calloc(QWEN_KV_DIM, sizeof(float), "v"); + m->att = (float*)qwen_calloc(QWEN_HEADS * QWEN_MAX_SEQ, sizeof(float), "att"); + m->hb = (float*)qwen_calloc(QWEN_HIDDEN, sizeof(float), "hb"); + m->hb2 = (float*)qwen_calloc(QWEN_HIDDEN, sizeof(float), "hb2"); + m->logits = (float*)qwen_calloc(QWEN_VOCAB, sizeof(float), "logits"); for (int l = 0; l < QWEN_LAYERS; l++) { - m->kv_cache_k[l] = (float*)calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float)); - m->kv_cache_v[l] = (float*)calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float)); + m->kv_cache_k[l] = (float*)qwen_calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float), "kv_cache_k"); + m->kv_cache_v[l] = (float*)qwen_calloc(QWEN_MAX_SEQ * QWEN_KV_DIM, sizeof(float), "kv_cache_v"); } m->pos = 0; } diff --git a/inference/setup.sh b/inference/setup.sh index 6f61a0e..cd4e127 100755 --- a/inference/setup.sh +++ b/inference/setup.sh @@ -7,7 +7,8 @@ MODEL_DIR="$HOME/models/Qwen2.5-0.5B-Instruct" WEIGHTS_BIN="$SCRIPT_DIR/qwen05b.bin" BINARY="$SCRIPT_DIR/qwen_ane" VENV_DIR="$SCRIPT_DIR/.venv" -EXPECTED_WEIGHT_SIZE=1976131100 +EXPECTED_WEIGHT_SIZE_F32=1976131100 +EXPECTED_WEIGHT_SIZE_F16=988082236 info() { printf "\033[1;34m==> %s\033[0m\n" "$1"; } ok() { printf "\033[1;32m ✓ %s\033[0m\n" "$1"; } @@ -86,16 +87,16 @@ info "Converting weights to binary format..." if [ -f "$WEIGHTS_BIN" ]; then ACTUAL_SIZE=$(stat -f%z "$WEIGHTS_BIN" 2>/dev/null || stat -c%s "$WEIGHTS_BIN" 2>/dev/null) - if [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE" ]; then + if [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE_F16" ] || [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE_F32" ]; then ok "Weights already converted ($((ACTUAL_SIZE / 1024 / 1024)) MB)" else - warn "Weight file exists but wrong size ($ACTUAL_SIZE vs $EXPECTED_WEIGHT_SIZE), reconverting" - python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" - ok "Weights converted" + warn "Weight file exists but unexpected size ($ACTUAL_SIZE), reconverting as F16" + python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" --f16 + ok "Weights converted (F16)" fi else - python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" - ok "Weights converted" + python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" --f16 + ok "Weights converted (F16)" fi # --- Step 5: Build binary --- @@ -113,8 +114,10 @@ elif [ "$SCRIPT_DIR/main.m" -nt "$BINARY" ] || \ fi if [ "$NEEDS_BUILD" -eq 1 ]; then - xcrun clang -O2 -framework Foundation -framework IOSurface \ - -framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \ + xcrun clang -O3 -ffast-math -mcpu=apple-m4 -flto \ + -framework Foundation -framework IOSurface \ + -framework CoreML -framework Accelerate -framework Metal \ + -ldl -lobjc -fobjc-arc \ -o "$BINARY" "$SCRIPT_DIR/main.m" ok "Binary built: $BINARY" else diff --git a/training/ane_mil_gen.h b/training/ane_mil_gen.h index 5e205c3..05a8c95 100644 --- a/training/ane_mil_gen.h +++ b/training/ane_mil_gen.h @@ -232,6 +232,124 @@ static NSData *mil_build_ffn_up_weight_blob(const float *w1, const float *w3, in return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES]; } +// Generate MIL for fused GQA QKV: Q, K, V have different output dimensions +// Qwen2.5-0.5B: Q=[q_dim, dim], K=[kv_dim, dim], V=[kv_dim, dim] +// Weight blob: Wq[q_dim,dim] @ chunk0, Wk[kv_dim,dim] @ chunk1, Wv[kv_dim,dim] @ chunk2 +static NSString *mil_gen_qkv_gqa(int dim, int q_dim, int kv_dim, int spatial) { + NSUInteger cs_q = 64 + (NSUInteger)q_dim * dim * 2; + NSUInteger cs_kv = 64 + (NSUInteger)kv_dim * dim * 2; + NSUInteger off_k = 64 + cs_q; + NSUInteger off_v = off_k + cs_kv; + if (g_fp16_io) { + return [NSString stringWithFormat: + @"program(1.0)\n" + "[buildInfo = dict, tensor>({{\"coremlc-version\", \"3505.4.1\"}})]\n" + "{\n" + " func main(tensor x) {\n" + " tensor c_pad_type = const()[name = tensor(\"c_pad_type\"), val = tensor(\"valid\")];\n" + " tensor c_strides = const()[name = tensor(\"c_strides\"), val = tensor([1, 1])];\n" + " tensor c_pad = const()[name = tensor(\"c_pad\"), val = tensor([0, 0, 0, 0])];\n" + " tensor c_dilations = const()[name = tensor(\"c_dilations\"), val = tensor([1, 1])];\n" + " tensor c_groups = const()[name = tensor(\"c_groups\"), val = tensor(1)];\n" + " tensor Wq = const()[name = tensor(\"Wq\"), " + "val = tensor(BLOBFILE(path = tensor(\"@model_path/weights/weight.bin\"), offset = tensor(64)))];\n" + " tensor Wk = const()[name = tensor(\"Wk\"), " + "val = tensor(BLOBFILE(path = tensor(\"@model_path/weights/weight.bin\"), offset = tensor(%lu)))];\n" + " tensor Wv = const()[name = tensor(\"Wv\"), " + "val = tensor(BLOBFILE(path = tensor(\"@model_path/weights/weight.bin\"), offset = tensor(%lu)))];\n" + " tensor q = conv(dilations = c_dilations, groups = c_groups, " + "pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wq, x = x)[name = tensor(\"conv_q\")];\n" + " tensor k = conv(dilations = c_dilations, groups = c_groups, " + "pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wk, x = x)[name = tensor(\"conv_k\")];\n" + " tensor v = conv(dilations = c_dilations, groups = c_groups, " + "pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wv, x = x)[name = tensor(\"conv_v\")];\n" + " } -> (q, k, v);\n" + "}\n", + dim, spatial, + q_dim, dim, q_dim, dim, + kv_dim, dim, kv_dim, dim, (unsigned long)off_k, + kv_dim, dim, kv_dim, dim, (unsigned long)off_v, + q_dim, spatial, kv_dim, spatial, kv_dim, spatial]; + } + return [NSString stringWithFormat: + @"program(1.0)\n" + "[buildInfo = dict, tensor>({{\"coremlc-version\", \"3505.4.1\"}})]\n" + "{\n" + " func main(tensor x) {\n" + " tensor c_pad_type = const()[name = tensor(\"c_pad_type\"), val = tensor(\"valid\")];\n" + " tensor c_strides = const()[name = tensor(\"c_strides\"), val = tensor([1, 1])];\n" + " tensor c_pad = const()[name = tensor(\"c_pad\"), val = tensor([0, 0, 0, 0])];\n" + " tensor c_dilations = const()[name = tensor(\"c_dilations\"), val = tensor([1, 1])];\n" + " tensor c_groups = const()[name = tensor(\"c_groups\"), val = tensor(1)];\n" + " tensor to_fp16 = const()[name = tensor(\"to_fp16\"), val = tensor(\"fp16\")];\n" + " tensor x16 = cast(dtype = to_fp16, x = x)[name = tensor(\"cast_in\")];\n" + " tensor Wq = const()[name = tensor(\"Wq\"), " + "val = tensor(BLOBFILE(path = tensor(\"@model_path/weights/weight.bin\"), offset = tensor(64)))];\n" + " tensor Wk = const()[name = tensor(\"Wk\"), " + "val = tensor(BLOBFILE(path = tensor(\"@model_path/weights/weight.bin\"), offset = tensor(%lu)))];\n" + " tensor Wv = const()[name = tensor(\"Wv\"), " + "val = tensor(BLOBFILE(path = tensor(\"@model_path/weights/weight.bin\"), offset = tensor(%lu)))];\n" + " tensor q16 = conv(dilations = c_dilations, groups = c_groups, " + "pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wq, x = x16)[name = tensor(\"conv_q\")];\n" + " tensor k16 = conv(dilations = c_dilations, groups = c_groups, " + "pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wk, x = x16)[name = tensor(\"conv_k\")];\n" + " tensor v16 = conv(dilations = c_dilations, groups = c_groups, " + "pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wv, x = x16)[name = tensor(\"conv_v\")];\n" + " tensor to_fp32 = const()[name = tensor(\"to_fp32\"), val = tensor(\"fp32\")];\n" + " tensor q = cast(dtype = to_fp32, x = q16)[name = tensor(\"cast_q\")];\n" + " tensor k = cast(dtype = to_fp32, x = k16)[name = tensor(\"cast_k\")];\n" + " tensor v = cast(dtype = to_fp32, x = v16)[name = tensor(\"cast_v\")];\n" + " } -> (q, k, v);\n" + "}\n", + dim, spatial, dim, spatial, + q_dim, dim, q_dim, dim, + kv_dim, dim, kv_dim, dim, (unsigned long)off_k, + kv_dim, dim, kv_dim, dim, (unsigned long)off_v, + q_dim, spatial, kv_dim, spatial, kv_dim, spatial, + q_dim, spatial, kv_dim, spatial, kv_dim, spatial]; +} + +// Build weight blob for GQA QKV (3 weight matrices with different shapes) +static NSData *mil_build_qkv_gqa_weight_blob(const float *wq, int q_dim, int dim, + const float *wk, const float *wv, int kv_dim) { + NSUInteger wsize_q = (NSUInteger)q_dim * dim * 2; + NSUInteger wsize_kv = (NSUInteger)kv_dim * dim * 2; + NSUInteger cs_q = 64 + wsize_q; + NSUInteger cs_kv = 64 + wsize_kv; + NSUInteger total = 64 + cs_q + 2 * cs_kv; + uint8_t *buf = (uint8_t*)calloc(total, 1); + buf[0] = 0x01; buf[4] = 0x02; + + // Chunk 0: Wq + { + uint8_t *chunk = buf + 64; + chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01; + *(uint32_t*)(chunk + 8) = (uint32_t)wsize_q; + *(uint32_t*)(chunk + 16) = (uint32_t)(64 + 64); + _Float16 *fp16 = (_Float16*)(chunk + 64); + for (NSUInteger i = 0; i < (NSUInteger)q_dim * dim; i++) fp16[i] = (_Float16)wq[i]; + } + // Chunk 1: Wk + { + uint8_t *chunk = buf + 64 + cs_q; + chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01; + *(uint32_t*)(chunk + 8) = (uint32_t)wsize_kv; + *(uint32_t*)(chunk + 16) = (uint32_t)(64 + cs_q + 64); + _Float16 *fp16 = (_Float16*)(chunk + 64); + for (NSUInteger i = 0; i < (NSUInteger)kv_dim * dim; i++) fp16[i] = (_Float16)wk[i]; + } + // Chunk 2: Wv + { + uint8_t *chunk = buf + 64 + cs_q + cs_kv; + chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01; + *(uint32_t*)(chunk + 8) = (uint32_t)wsize_kv; + *(uint32_t*)(chunk + 16) = (uint32_t)(64 + cs_q + cs_kv + 64); + _Float16 *fp16 = (_Float16*)(chunk + 64); + for (NSUInteger i = 0; i < (NSUInteger)kv_dim * dim; i++) fp16[i] = (_Float16)wv[i]; + } + return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES]; +} + // Generate MIL for fused FFN up: w1 + w3 parallel convs static NSString *mil_gen_ffn_up(int dim, int hidden_dim, int spatial) { NSUInteger cs = 64 + (NSUInteger)hidden_dim * dim * 2;