mirror of https://github.com/maderix/ANE.git
[feat][gpu] Q4 quantization, Metal GPU shaders, ANE kernel fusion, memory safety
This commit is contained in:
parent
0e70f5bd71
commit
be96079bbf
|
|
@ -25,6 +25,9 @@ training/test_*
|
||||||
# Inference binaries and runtime data
|
# Inference binaries and runtime data
|
||||||
inference/qwen_ane
|
inference/qwen_ane
|
||||||
inference/qwen05b.bin
|
inference/qwen05b.bin
|
||||||
|
inference/qwen05b_f32.bin
|
||||||
|
inference/qwen05b_f16.bin
|
||||||
|
inference/qwen05b_q8.bin
|
||||||
inference/.venv/
|
inference/.venv/
|
||||||
inference/benchmark_results.json
|
inference/benchmark_results.json
|
||||||
|
|
||||||
|
|
@ -59,6 +62,7 @@ web/
|
||||||
training/tinystories_data00.bin
|
training/tinystories_data00.bin
|
||||||
training/ane_stories110M_ckpt.bin
|
training/ane_stories110M_ckpt.bin
|
||||||
*.bin
|
*.bin
|
||||||
|
*.metallib
|
||||||
!training/download_data.sh
|
!training/download_data.sh
|
||||||
|
|
||||||
# Secrets / env
|
# Secrets / env
|
||||||
|
|
|
||||||
|
|
@ -66,125 +66,235 @@ MEM_BYTES=$(sysctl -n hw.memsize 2>/dev/null || echo "0")
|
||||||
MEM_GB=$((MEM_BYTES / 1073741824))
|
MEM_GB=$((MEM_BYTES / 1073741824))
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
info "=== ANE Inference Benchmark (qwen_ane) ==="
|
info "=== ANE Multi-Format Inference Benchmark ==="
|
||||||
echo "Hardware: $CHIP"
|
echo "Hardware: $CHIP"
|
||||||
echo "macOS: $MACOS"
|
echo "macOS: $MACOS"
|
||||||
echo "Memory: ${MEM_GB} GB"
|
echo "Memory: ${MEM_GB} GB"
|
||||||
echo "Model: Qwen2.5-0.5B-Instruct (BF16, 494M params)"
|
echo "Model: Qwen2.5-0.5B-Instruct (494M params)"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# --- Phase 1: Server mode benchmark (HTTP API) ---
|
# --- Phase 0: Prepare weight files (F16 + Q8) ---
|
||||||
info "Phase 1: Server mode (persistent ANE kernels via HTTP API)"
|
WEIGHTS_F16="$SCRIPT_DIR/qwen05b.bin"
|
||||||
dim "Starting server on port $HTTP_PORT..."
|
WEIGHTS_Q8="$SCRIPT_DIR/qwen05b_q8.bin"
|
||||||
|
WEIGHTS_Q4="$SCRIPT_DIR/qwen05b_q4.bin"
|
||||||
|
CONVERT="$SCRIPT_DIR/convert_weights.py"
|
||||||
|
VENV_DIR="$SCRIPT_DIR/.venv"
|
||||||
|
|
||||||
# Start HTTP server in background
|
info "Phase 0: Preparing weight files"
|
||||||
"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 &
|
|
||||||
SERVER_PID=$!
|
|
||||||
|
|
||||||
|
if [ ! -f "$WEIGHTS_Q8" ]; then
|
||||||
|
if [ ! -f "$CONVERT" ]; then
|
||||||
|
echo " convert_weights.py not found, skipping Q8 generation."
|
||||||
|
WEIGHTS_Q8=""
|
||||||
|
else
|
||||||
|
dim "Generating Q8 weights (one-time)..."
|
||||||
|
if [ -d "$VENV_DIR" ]; then
|
||||||
|
source "$VENV_DIR/bin/activate"
|
||||||
|
fi
|
||||||
|
python3 "$CONVERT" "$MODEL_DIR" "$WEIGHTS_Q8" --q8
|
||||||
|
dim "Q8 weights ready: $(du -h "$WEIGHTS_Q8" | cut -f1)"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
dim "Q8 weights already exist: $(du -h "$WEIGHTS_Q8" | cut -f1)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f "$WEIGHTS_Q4" ]; then
|
||||||
|
if [ ! -f "$CONVERT" ]; then
|
||||||
|
echo " convert_weights.py not found, skipping Q4 generation."
|
||||||
|
WEIGHTS_Q4=""
|
||||||
|
else
|
||||||
|
dim "Generating Q4 weights (one-time)..."
|
||||||
|
if [ -d "$VENV_DIR" ]; then
|
||||||
|
source "$VENV_DIR/bin/activate"
|
||||||
|
fi
|
||||||
|
python3 "$CONVERT" "$MODEL_DIR" "$WEIGHTS_Q4" --q4
|
||||||
|
dim "Q4 weights ready: $(du -h "$WEIGHTS_Q4" | cut -f1)"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
dim "Q4 weights already exist: $(du -h "$WEIGHTS_Q4" | cut -f1)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
dim "F16 weights: $(du -h "$WEIGHTS_F16" | cut -f1)"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ANE weight formats to benchmark
|
||||||
|
# GPU flag: empty for CPU formats, "--gpu" for Metal GPU formats
|
||||||
|
ANE_FMT_NAMES=("F16")
|
||||||
|
ANE_FMT_WEIGHTS=("$WEIGHTS_F16")
|
||||||
|
ANE_FMT_LABELS=("F16→F32 (AMX)")
|
||||||
|
ANE_FMT_GPU=("")
|
||||||
|
|
||||||
|
if [ -n "$WEIGHTS_Q8" ] && [ -f "$WEIGHTS_Q8" ]; then
|
||||||
|
ANE_FMT_NAMES+=("Q8")
|
||||||
|
ANE_FMT_WEIGHTS+=("$WEIGHTS_Q8")
|
||||||
|
ANE_FMT_LABELS+=("Q8 (NEON dequant)")
|
||||||
|
ANE_FMT_GPU+=("")
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -n "$WEIGHTS_Q4" ] && [ -f "$WEIGHTS_Q4" ]; then
|
||||||
|
ANE_FMT_NAMES+=("Q4_Metal")
|
||||||
|
ANE_FMT_WEIGHTS+=("$WEIGHTS_Q4")
|
||||||
|
ANE_FMT_LABELS+=("Q4 SIMD (Metal GPU)")
|
||||||
|
ANE_FMT_GPU+=("--gpu")
|
||||||
|
|
||||||
|
ANE_FMT_NAMES+=("Q4_AMX")
|
||||||
|
ANE_FMT_WEIGHTS+=("$WEIGHTS_Q4")
|
||||||
|
ANE_FMT_LABELS+=("Q4→F32 (AMX dequant)")
|
||||||
|
ANE_FMT_GPU+=("")
|
||||||
|
fi
|
||||||
|
|
||||||
|
NUM_ANE_FMTS=${#ANE_FMT_NAMES[@]}
|
||||||
|
NUM_PROMPTS=${#PROMPTS[@]}
|
||||||
|
|
||||||
|
# Global cleanup
|
||||||
|
SERVER_PID=""
|
||||||
cleanup() {
|
cleanup() {
|
||||||
kill "$SERVER_PID" 2>/dev/null || true
|
[ -n "$SERVER_PID" ] && kill "$SERVER_PID" 2>/dev/null || true
|
||||||
rm -f "$SOCK" /tmp/qwen_bench_server.log
|
rm -f "$SOCK" /tmp/qwen_bench_server.log
|
||||||
}
|
}
|
||||||
trap cleanup EXIT
|
trap cleanup EXIT
|
||||||
|
|
||||||
# Wait for READY
|
# Helper: start server with given weight file and optional extra flags, wait for READY
|
||||||
for i in $(seq 1 30); do
|
start_server() {
|
||||||
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then
|
local wfile="$1"
|
||||||
break
|
shift
|
||||||
fi
|
local extra_flags="$*"
|
||||||
|
[ -n "$SERVER_PID" ] && kill "$SERVER_PID" 2>/dev/null || true
|
||||||
sleep 1
|
sleep 1
|
||||||
done
|
rm -f /tmp/qwen_bench_server.log
|
||||||
|
"$BINARY" "$wfile" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" $extra_flags > /tmp/qwen_bench_server.log 2>&1 &
|
||||||
if ! grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then
|
SERVER_PID=$!
|
||||||
echo "Server failed to start. Log:"
|
for _i in $(seq 1 30); do
|
||||||
|
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then return 0; fi
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
echo "Server failed to start with $wfile. Log:"
|
||||||
cat /tmp/qwen_bench_server.log
|
cat /tmp/qwen_bench_server.log
|
||||||
exit 1
|
return 1
|
||||||
fi
|
}
|
||||||
dim "Server ready (PID $SERVER_PID)"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Warmup: first request primes any remaining caches
|
# --- Phase 1: Multi-format ANE benchmarks ---
|
||||||
dim "Warmup run (discarded)..."
|
# Per-format result tracking (indexed by format number)
|
||||||
curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
|
declare -a ALL_AVG_P ALL_AVG_D ALL_AVG_INF ALL_AVG_TTFT ALL_AVG_RT
|
||||||
-H "Content-Type: application/json" \
|
ANE_JSON_BLOCKS=""
|
||||||
-d '{"prompt":"warmup","max_tokens":5}' > /dev/null 2>&1
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Print table header
|
for fmt_idx in $(seq 0 $((NUM_ANE_FMTS - 1))); do
|
||||||
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
|
FMT_NAME="${ANE_FMT_NAMES[$fmt_idx]}"
|
||||||
"Prompt" "In" "Out" "Prefill" "Decode" "TTFT" "Infer" "Rndtrip" "Overhead"
|
FMT_WEIGHTS="${ANE_FMT_WEIGHTS[$fmt_idx]}"
|
||||||
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
|
FMT_LABEL="${ANE_FMT_LABELS[$fmt_idx]}"
|
||||||
"" "tok" "tok" "(t/s)" "(t/s)" "(ms)" "(ms)" "(ms)" "(ms)"
|
FMT_GPU="${ANE_FMT_GPU[$fmt_idx]}"
|
||||||
printf '%.0s─' {1..85}; echo ""
|
|
||||||
|
|
||||||
# Arrays for averages
|
echo ""
|
||||||
declare -a P_TPS_ARR D_TPS_ARR INF_MS_ARR TTFT_MS_ARR RT_MS_ARR
|
info "Phase 1.$((fmt_idx+1)): ANE $FMT_NAME benchmark ($FMT_LABEL)"
|
||||||
|
dim "Weights: $(du -h "$FMT_WEIGHTS" | cut -f1) — Starting server..."
|
||||||
|
|
||||||
JSON_ENTRIES=""
|
if ! start_server "$FMT_WEIGHTS" $FMT_GPU; then
|
||||||
NUM_PROMPTS=${#PROMPTS[@]}
|
echo "Skipping $FMT_NAME format."
|
||||||
|
ALL_AVG_P+=("0"); ALL_AVG_D+=("0"); ALL_AVG_INF+=("0")
|
||||||
|
ALL_AVG_TTFT+=("0"); ALL_AVG_RT+=("0")
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
|
dim "Warmup run (discarded)..."
|
||||||
NAME="${PROMPT_NAMES[$i]}"
|
curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
|
||||||
PROMPT="${PROMPTS[$i]}"
|
|
||||||
MAXTOK="${MAX_TOKENS[$i]}"
|
|
||||||
|
|
||||||
RT_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
|
||||||
RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
|
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d "{\"prompt\": \"$PROMPT\", \"max_tokens\": $MAXTOK}" 2>&1)
|
-d '{"prompt":"warmup","max_tokens":5}' > /dev/null 2>&1
|
||||||
RT_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
echo ""
|
||||||
RT_MS=$(echo "$RT_T0 $RT_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
|
|
||||||
|
|
||||||
# Parse server JSON with pure shell -- no python
|
|
||||||
P_TOKENS=$(json_val "$RESP" "prompt_tokens")
|
|
||||||
G_TOKENS=$(json_val "$RESP" "gen_tokens")
|
|
||||||
P_TPS=$(json_val "$RESP" "prefill_tps")
|
|
||||||
D_TPS=$(json_val "$RESP" "decode_tps")
|
|
||||||
TTFT_MS=$(trunc "$(json_val "$RESP" "ttft_ms")")
|
|
||||||
INF_MS=$(trunc "$(json_val "$RESP" "inference_ms")")
|
|
||||||
TOTAL_MS=$(trunc "$(json_val "$RESP" "total_ms")")
|
|
||||||
TEXT=$(json_text "$RESP")
|
|
||||||
OVERHEAD=$((RT_MS - TOTAL_MS))
|
|
||||||
|
|
||||||
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
|
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
|
||||||
"$NAME" "$P_TOKENS" "$G_TOKENS" "$P_TPS" "$D_TPS" "$TTFT_MS" "$INF_MS" "$RT_MS" "$OVERHEAD"
|
"Prompt" "In" "Out" "Prefill" "Decode" "TTFT" "Infer" "Rndtrip" "Overhead"
|
||||||
|
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
|
||||||
|
"" "tok" "tok" "(t/s)" "(t/s)" "(ms)" "(ms)" "(ms)" "(ms)"
|
||||||
|
printf '%.0s─' {1..85}; echo ""
|
||||||
|
|
||||||
P_TPS_ARR+=("$P_TPS")
|
declare -a P_TPS_ARR=() D_TPS_ARR=() INF_MS_ARR=() TTFT_MS_ARR=() RT_MS_ARR=()
|
||||||
D_TPS_ARR+=("$D_TPS")
|
FMT_JSON_ENTRIES=""
|
||||||
INF_MS_ARR+=("$INF_MS")
|
|
||||||
TTFT_MS_ARR+=("$TTFT_MS")
|
|
||||||
RT_MS_ARR+=("$RT_MS")
|
|
||||||
|
|
||||||
# Build JSON entry
|
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
|
||||||
JSON_ENTRIES="$JSON_ENTRIES{\"name\":\"$NAME\",\"prompt_tokens\":$P_TOKENS,\"gen_tokens\":$G_TOKENS,\"prefill_tps\":$P_TPS,\"decode_tps\":$D_TPS,\"ttft_ms\":$TTFT_MS,\"inference_ms\":$INF_MS,\"roundtrip_ms\":$RT_MS},"
|
NAME="${PROMPT_NAMES[$i]}"
|
||||||
|
PROMPT="${PROMPTS[$i]}"
|
||||||
|
MAXTOK="${MAX_TOKENS[$i]}"
|
||||||
|
|
||||||
# Print response text indented below
|
RT_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
||||||
echo " → $TEXT"
|
RESP=$(curl -s "http://127.0.0.1:$HTTP_PORT/v1/completions" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d "{\"prompt\": \"$PROMPT\", \"max_tokens\": $MAXTOK}" 2>&1)
|
||||||
|
RT_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
||||||
|
RT_MS=$(echo "$RT_T0 $RT_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
|
||||||
|
|
||||||
|
P_TOKENS=$(json_val "$RESP" "prompt_tokens")
|
||||||
|
G_TOKENS=$(json_val "$RESP" "gen_tokens")
|
||||||
|
P_TPS=$(json_val "$RESP" "prefill_tps")
|
||||||
|
D_TPS=$(json_val "$RESP" "decode_tps")
|
||||||
|
TTFT_MS=$(trunc "$(json_val "$RESP" "ttft_ms")")
|
||||||
|
INF_MS=$(trunc "$(json_val "$RESP" "inference_ms")")
|
||||||
|
TOTAL_MS=$(trunc "$(json_val "$RESP" "total_ms")")
|
||||||
|
TEXT=$(json_text "$RESP")
|
||||||
|
OVERHEAD=$((RT_MS - TOTAL_MS))
|
||||||
|
|
||||||
|
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" \
|
||||||
|
"$NAME" "$P_TOKENS" "$G_TOKENS" "$P_TPS" "$D_TPS" "$TTFT_MS" "$INF_MS" "$RT_MS" "$OVERHEAD"
|
||||||
|
|
||||||
|
P_TPS_ARR+=("$P_TPS")
|
||||||
|
D_TPS_ARR+=("$D_TPS")
|
||||||
|
INF_MS_ARR+=("$INF_MS")
|
||||||
|
TTFT_MS_ARR+=("$TTFT_MS")
|
||||||
|
RT_MS_ARR+=("$RT_MS")
|
||||||
|
|
||||||
|
FMT_JSON_ENTRIES="$FMT_JSON_ENTRIES{\"name\":\"$NAME\",\"prompt_tokens\":$P_TOKENS,\"gen_tokens\":$G_TOKENS,\"prefill_tps\":$P_TPS,\"decode_tps\":$D_TPS,\"ttft_ms\":$TTFT_MS,\"inference_ms\":$INF_MS,\"roundtrip_ms\":$RT_MS},"
|
||||||
|
|
||||||
|
echo " → $TEXT"
|
||||||
|
echo ""
|
||||||
|
done
|
||||||
|
|
||||||
|
printf '%.0s─' {1..85}; echo ""
|
||||||
|
|
||||||
|
F_AVG_P=$(shell_avg "${P_TPS_ARR[@]}")
|
||||||
|
F_AVG_D=$(shell_avg "${D_TPS_ARR[@]}")
|
||||||
|
F_AVG_INF=$(shell_avg_int "${INF_MS_ARR[@]}")
|
||||||
|
F_AVG_TTFT=$(shell_avg_int "${TTFT_MS_ARR[@]}")
|
||||||
|
F_AVG_RT=$(shell_avg_int "${RT_MS_ARR[@]}")
|
||||||
|
F_AVG_OVERHEAD=$((F_AVG_RT - F_AVG_INF))
|
||||||
|
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" "Average" "" "" "$F_AVG_P" "$F_AVG_D" "$F_AVG_TTFT" "$F_AVG_INF" "$F_AVG_RT" "$F_AVG_OVERHEAD"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
|
ALL_AVG_P+=("$F_AVG_P")
|
||||||
|
ALL_AVG_D+=("$F_AVG_D")
|
||||||
|
ALL_AVG_INF+=("$F_AVG_INF")
|
||||||
|
ALL_AVG_TTFT+=("$F_AVG_TTFT")
|
||||||
|
ALL_AVG_RT+=("$F_AVG_RT")
|
||||||
|
|
||||||
|
ANE_JSON_BLOCKS="$ANE_JSON_BLOCKS
|
||||||
|
\"$FMT_NAME\": {
|
||||||
|
\"format\": \"$FMT_NAME\",
|
||||||
|
\"label\": \"$FMT_LABEL\",
|
||||||
|
\"weight_size_mb\": $(du -m "$FMT_WEIGHTS" | cut -f1),
|
||||||
|
\"avg_prefill_tps\": $F_AVG_P,
|
||||||
|
\"avg_decode_tps\": $F_AVG_D,
|
||||||
|
\"avg_inference_ms\": $F_AVG_INF,
|
||||||
|
\"avg_roundtrip_ms\": $F_AVG_RT,
|
||||||
|
\"avg_ttft_ms\": $F_AVG_TTFT,
|
||||||
|
\"results\": [${FMT_JSON_ENTRIES%,}]
|
||||||
|
},"
|
||||||
done
|
done
|
||||||
|
|
||||||
printf '%.0s─' {1..85}; echo ""
|
# Use F16 results as the primary ANE numbers (first format)
|
||||||
|
AVG_P="${ALL_AVG_P[0]}"
|
||||||
|
AVG_D="${ALL_AVG_D[0]}"
|
||||||
|
AVG_INF="${ALL_AVG_INF[0]}"
|
||||||
|
AVG_TTFT="${ALL_AVG_TTFT[0]}"
|
||||||
|
AVG_RT="${ALL_AVG_RT[0]}"
|
||||||
|
|
||||||
# Averages (pure shell, no python)
|
|
||||||
AVG_P=$(shell_avg "${P_TPS_ARR[@]}")
|
|
||||||
AVG_D=$(shell_avg "${D_TPS_ARR[@]}")
|
|
||||||
AVG_INF=$(shell_avg_int "${INF_MS_ARR[@]}")
|
|
||||||
AVG_TTFT=$(shell_avg_int "${TTFT_MS_ARR[@]}")
|
|
||||||
AVG_RT=$(shell_avg_int "${RT_MS_ARR[@]}")
|
|
||||||
AVG_OVERHEAD=$((AVG_RT - AVG_INF))
|
|
||||||
printf "%-10s %5s %5s %10s %10s %10s %10s %10s %10s\n" "Average" "" "" "$AVG_P" "$AVG_D" "$AVG_TTFT" "$AVG_INF" "$AVG_RT" "$AVG_OVERHEAD"
|
|
||||||
echo ""
|
|
||||||
info "Infer = server-reported (pure processing). Rndtrip = wall-clock (what clients see)."
|
info "Infer = server-reported (pure processing). Rndtrip = wall-clock (what clients see)."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# --- Phase 2: Cold start measurement ---
|
# --- Phase 2: Cold start measurement ---
|
||||||
info "Phase 2: Cold start (single-shot, recompiles ANE kernels)"
|
info "Phase 2: Cold start (single-shot, recompiles ANE kernels)"
|
||||||
|
|
||||||
# Kill server, run single-shot
|
|
||||||
kill "$SERVER_PID" 2>/dev/null || true
|
kill "$SERVER_PID" 2>/dev/null || true
|
||||||
|
SERVER_PID=""
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
# Use perl for sub-second timing (available on all macOS, no python)
|
|
||||||
COLD_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
COLD_T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
||||||
COLD_OUT=$("$BINARY" "$WEIGHTS" "151644 8948 198 2610 525 264 10950 17847 13 151645 198 151644 872 198 13048 151645 198 151644 77091 198" 10 2>&1 || true)
|
COLD_OUT=$("$BINARY" "$WEIGHTS" "151644 8948 198 2610 525 264 10950 17847 13 151645 198 151644 872 198 13048 151645 198 151644 77091 198" 10 2>&1 || true)
|
||||||
COLD_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
COLD_T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
||||||
|
|
@ -193,17 +303,11 @@ COLD_MS=$(echo "$COLD_T0 $COLD_T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
|
||||||
echo "Cold start latency: ${COLD_MS}ms (includes ANE kernel compilation)"
|
echo "Cold start latency: ${COLD_MS}ms (includes ANE kernel compilation)"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Re-start server for any additional tests
|
# Re-start server (F16) for consistency check
|
||||||
"$BINARY" "$WEIGHTS" --http "$HTTP_PORT" --model-dir "$MODEL_DIR" > /tmp/qwen_bench_server.log 2>&1 &
|
start_server "$WEIGHTS_F16"
|
||||||
SERVER_PID=$!
|
|
||||||
|
|
||||||
# --- Phase 3: Repeated prompt (consistency check) ---
|
# --- Phase 3: Repeated prompt (consistency check) ---
|
||||||
info "Phase 3: Decode speed consistency (5x same prompt)"
|
info "Phase 3: Decode speed consistency (5x same prompt, F16)"
|
||||||
|
|
||||||
for retry in $(seq 1 15); do
|
|
||||||
if grep -q "READY" /tmp/qwen_bench_server.log 2>/dev/null; then break; fi
|
|
||||||
sleep 1
|
|
||||||
done
|
|
||||||
|
|
||||||
printf "%-6s %10s %10s %10s\n" "Run" "Prefill" "Decode" "Infer(ms)"
|
printf "%-6s %10s %10s %10s\n" "Run" "Prefill" "Decode" "Infer(ms)"
|
||||||
printf '%.0s─' {1..40}; echo ""
|
printf '%.0s─' {1..40}; echo ""
|
||||||
|
|
@ -227,12 +331,8 @@ JSON="{
|
||||||
\"model\": \"Qwen2.5-0.5B-Instruct\",
|
\"model\": \"Qwen2.5-0.5B-Instruct\",
|
||||||
\"mode\": \"http_server\",
|
\"mode\": \"http_server\",
|
||||||
\"cold_start_ms\": $COLD_MS,
|
\"cold_start_ms\": $COLD_MS,
|
||||||
\"avg_prefill_tps\": $AVG_P,
|
\"ane_formats\": {$( echo "$ANE_JSON_BLOCKS" | sed '$ s/,$//' )
|
||||||
\"avg_decode_tps\": $AVG_D,
|
}
|
||||||
\"avg_inference_ms\": $AVG_INF,
|
|
||||||
\"avg_roundtrip_ms\": $AVG_RT,
|
|
||||||
\"avg_ttft_ms\": $AVG_TTFT,
|
|
||||||
\"results\": [${JSON_ENTRIES%,}]
|
|
||||||
}"
|
}"
|
||||||
echo "$JSON" > "$RESULTS_JSON"
|
echo "$JSON" > "$RESULTS_JSON"
|
||||||
dim "Results saved to $RESULTS_JSON"
|
dim "Results saved to $RESULTS_JSON"
|
||||||
|
|
@ -240,9 +340,12 @@ echo ""
|
||||||
|
|
||||||
# --- Phase 4: LM Studio comparison (if running) ---
|
# --- Phase 4: LM Studio comparison (if running) ---
|
||||||
LMS_PORT="${LMS_PORT:-1234}"
|
LMS_PORT="${LMS_PORT:-1234}"
|
||||||
LMS_MODEL="${LMS_MODEL:-qwen2.5-0.5b-instruct}"
|
|
||||||
LMS_API_KEY="${LMS_API_KEY:-}"
|
LMS_API_KEY="${LMS_API_KEY:-}"
|
||||||
|
|
||||||
|
# Models to benchmark (override via LMS_MODELS env var, comma-separated)
|
||||||
|
LMS_MODELS_DEFAULT="qwen2.5-0.5b-instruct,qwen2.5-0.5b-instruct-mlx@8bit,qwen2.5-0.5b-instruct-mlx@4bit"
|
||||||
|
IFS=',' read -ra LMS_MODEL_LIST <<< "${LMS_MODELS:-$LMS_MODELS_DEFAULT}"
|
||||||
|
|
||||||
# Check if LM Studio is running
|
# Check if LM Studio is running
|
||||||
LMS_REACHABLE=0
|
LMS_REACHABLE=0
|
||||||
if curl -s --max-time 2 "http://localhost:$LMS_PORT/api/v1/chat" -H "Content-Type: application/json" -d '{}' >/dev/null 2>&1; then
|
if curl -s --max-time 2 "http://localhost:$LMS_PORT/api/v1/chat" -H "Content-Type: application/json" -d '{}' >/dev/null 2>&1; then
|
||||||
|
|
@ -251,8 +354,8 @@ fi
|
||||||
|
|
||||||
if [ "$LMS_REACHABLE" -eq 1 ]; then
|
if [ "$LMS_REACHABLE" -eq 1 ]; then
|
||||||
info "Phase 4: LM Studio comparison (localhost:$LMS_PORT)"
|
info "Phase 4: LM Studio comparison (localhost:$LMS_PORT)"
|
||||||
|
dim "Models: ${LMS_MODEL_LIST[*]}"
|
||||||
|
|
||||||
# If no API key, prompt for it
|
|
||||||
if [ -z "$LMS_API_KEY" ]; then
|
if [ -z "$LMS_API_KEY" ]; then
|
||||||
echo ""
|
echo ""
|
||||||
echo " LM Studio requires an API key."
|
echo " LM Studio requires an API key."
|
||||||
|
|
@ -268,30 +371,53 @@ if [ "$LMS_REACHABLE" -eq 1 ]; then
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
LMS_ALL_JSON=""
|
||||||
|
|
||||||
if [ "$LMS_REACHABLE" -eq 1 ] && [ -n "$LMS_API_KEY" ]; then
|
if [ "$LMS_REACHABLE" -eq 1 ] && [ -n "$LMS_API_KEY" ]; then
|
||||||
echo ""
|
|
||||||
printf "%-10s %5s %5s %10s %10s %10s\n" \
|
|
||||||
"Prompt" "In" "Out" "Decode" "TTFT" "Rndtrip"
|
|
||||||
printf "%-10s %5s %5s %10s %10s %10s\n" \
|
|
||||||
"" "tok" "tok" "(t/s)" "(ms)" "(ms)"
|
|
||||||
printf '%.0s─' {1..55}; echo ""
|
|
||||||
|
|
||||||
declare -a LMS_LATENCIES LMS_TPS_ARR LMS_TTFT_ARR
|
# Track the best model for the final comparison table
|
||||||
LMS_JSON_ENTRIES=""
|
BEST_LMS_MODEL=""
|
||||||
|
BEST_LMS_TPS="0"
|
||||||
|
BEST_LMS_LAT="99999"
|
||||||
|
BEST_LMS_TTFT="0"
|
||||||
|
|
||||||
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
|
for LMS_MODEL in "${LMS_MODEL_LIST[@]}"; do
|
||||||
NAME="${PROMPT_NAMES[$i]}"
|
echo ""
|
||||||
PROMPT="${PROMPTS[$i]}"
|
info "── $LMS_MODEL ──"
|
||||||
|
|
||||||
T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
# Test if this model is available
|
||||||
LMS_RESP=$(curl -s --max-time 120 "http://localhost:$LMS_PORT/api/v1/chat" \
|
TEST_RESP=$(curl -s --max-time 10 "http://localhost:$LMS_PORT/api/v1/chat" \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-H "Authorization: Bearer $LMS_API_KEY" \
|
-H "Authorization: Bearer $LMS_API_KEY" \
|
||||||
-d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"You are a helpful assistant. Be concise.\",\"input\":\"$PROMPT\"}" 2>&1)
|
-d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"test\",\"input\":\"hi\"}" 2>&1)
|
||||||
T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
|
||||||
LMS_MS=$(echo "$T0 $T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
|
|
||||||
|
|
||||||
eval "$(echo "$LMS_RESP" | python3 -c "
|
if echo "$TEST_RESP" | grep -qi "error\|not found\|not loaded\|no model"; then
|
||||||
|
dim " Model '$LMS_MODEL' not available, skipping."
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
printf "%-10s %5s %5s %10s %10s %10s\n" \
|
||||||
|
"Prompt" "In" "Out" "Decode" "TTFT" "Rndtrip"
|
||||||
|
printf "%-10s %5s %5s %10s %10s %10s\n" \
|
||||||
|
"" "tok" "tok" "(t/s)" "(ms)" "(ms)"
|
||||||
|
printf '%.0s─' {1..55}; echo ""
|
||||||
|
|
||||||
|
declare -a LMS_LATENCIES=() LMS_TPS_ARR=() LMS_TTFT_ARR=()
|
||||||
|
LMS_JSON_ENTRIES=""
|
||||||
|
|
||||||
|
for i in $(seq 0 $((NUM_PROMPTS - 1))); do
|
||||||
|
NAME="${PROMPT_NAMES[$i]}"
|
||||||
|
PROMPT="${PROMPTS[$i]}"
|
||||||
|
|
||||||
|
T0=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
||||||
|
LMS_RESP=$(curl -s --max-time 120 "http://localhost:$LMS_PORT/api/v1/chat" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer $LMS_API_KEY" \
|
||||||
|
-d "{\"model\":\"$LMS_MODEL\",\"system_prompt\":\"You are a helpful assistant. Be concise.\",\"input\":\"$PROMPT\"}" 2>&1)
|
||||||
|
T1=$(perl -MTime::HiRes=time -e 'printf "%.3f", time')
|
||||||
|
LMS_MS=$(echo "$T0 $T1" | awk '{printf "%.0f", ($2 - $1) * 1000}')
|
||||||
|
|
||||||
|
eval "$(echo "$LMS_RESP" | python3 -c "
|
||||||
import sys, json
|
import sys, json
|
||||||
try:
|
try:
|
||||||
r = json.load(sys.stdin)
|
r = json.load(sys.stdin)
|
||||||
|
|
@ -314,69 +440,191 @@ except Exception as e:
|
||||||
print('LMS_OUT=0')
|
print('LMS_OUT=0')
|
||||||
" 2>/dev/null)"
|
" 2>/dev/null)"
|
||||||
|
|
||||||
printf "%-10s %5s %5s %10s %10s %10s\n" "$NAME" "$LMS_IN" "$LMS_OUT" "$LMS_TPS" "$LMS_TTFT" "$LMS_MS"
|
printf "%-10s %5s %5s %10s %10s %10s\n" "$NAME" "$LMS_IN" "$LMS_OUT" "$LMS_TPS" "$LMS_TTFT" "$LMS_MS"
|
||||||
echo " → $LMS_TEXT"
|
LMS_LATENCIES+=("$LMS_MS")
|
||||||
echo ""
|
LMS_TPS_ARR+=("$LMS_TPS")
|
||||||
LMS_LATENCIES+=("$LMS_MS")
|
LMS_TTFT_ARR+=("$LMS_TTFT")
|
||||||
LMS_TPS_ARR+=("$LMS_TPS")
|
LMS_JSON_ENTRIES="$LMS_JSON_ENTRIES{\"name\":\"$NAME\",\"latency_ms\":$LMS_MS,\"tps\":$LMS_TPS,\"ttft_ms\":$LMS_TTFT,\"input_tokens\":$LMS_IN,\"output_tokens\":$LMS_OUT},"
|
||||||
LMS_TTFT_ARR+=("$LMS_TTFT")
|
done
|
||||||
LMS_JSON_ENTRIES="$LMS_JSON_ENTRIES{\"name\":\"$NAME\",\"latency_ms\":$LMS_MS,\"tps\":$LMS_TPS,\"ttft_ms\":$LMS_TTFT,\"input_tokens\":$LMS_IN,\"output_tokens\":$LMS_OUT},"
|
|
||||||
|
printf '%.0s─' {1..55}; echo ""
|
||||||
|
|
||||||
|
M_AVG_LAT=$(shell_avg_int "${LMS_LATENCIES[@]}")
|
||||||
|
M_AVG_TPS=$(shell_avg "${LMS_TPS_ARR[@]}")
|
||||||
|
M_AVG_TTFT=$(shell_avg_int "${LMS_TTFT_ARR[@]}")
|
||||||
|
printf "%-10s %5s %5s %10s %10s %10s\n" "Average" "" "" "$M_AVG_TPS" "$M_AVG_TTFT" "$M_AVG_LAT"
|
||||||
|
|
||||||
|
# Track the best model by decode t/s
|
||||||
|
if awk "BEGIN {exit !($M_AVG_TPS > $BEST_LMS_TPS)}" 2>/dev/null; then
|
||||||
|
BEST_LMS_MODEL="$LMS_MODEL"
|
||||||
|
BEST_LMS_TPS="$M_AVG_TPS"
|
||||||
|
BEST_LMS_LAT="$M_AVG_LAT"
|
||||||
|
BEST_LMS_TTFT="$M_AVG_TTFT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
LMS_ALL_JSON="$LMS_ALL_JSON
|
||||||
|
\"$(echo "$LMS_MODEL" | sed 's/[^a-zA-Z0-9._-]/_/g')\": {
|
||||||
|
\"model\": \"$LMS_MODEL\",
|
||||||
|
\"avg_latency_ms\": $M_AVG_LAT,
|
||||||
|
\"avg_tps\": $M_AVG_TPS,
|
||||||
|
\"avg_ttft_ms\": $M_AVG_TTFT,
|
||||||
|
\"results\": [${LMS_JSON_ENTRIES%,}]
|
||||||
|
},"
|
||||||
done
|
done
|
||||||
|
|
||||||
printf '%.0s─' {1..55}; echo ""
|
|
||||||
|
|
||||||
# Averages (awk, no python)
|
|
||||||
LMS_AVG_LAT=$(shell_avg_int "${LMS_LATENCIES[@]}")
|
|
||||||
LMS_AVG_TPS=$(shell_avg "${LMS_TPS_ARR[@]}")
|
|
||||||
LMS_AVG_TTFT=$(shell_avg_int "${LMS_TTFT_ARR[@]}")
|
|
||||||
printf "%-10s %5s %5s %10s %10s %10s\n" "Average" "" "" "$LMS_AVG_TPS" "$LMS_AVG_TTFT" "$LMS_AVG_LAT"
|
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Side-by-side comparison
|
# --- Final Comparison Table: all ANE formats + all LM Studio models ---
|
||||||
info "=== Side-by-Side Comparison ==="
|
info "=== Multi-Format Comparison ==="
|
||||||
dim "(Round-trip = wall-clock from client, apples-to-apples)"
|
dim "(All times are wall-clock round-trip, apples-to-apples)"
|
||||||
echo ""
|
|
||||||
printf "%-24s %15s %15s\n" "" "ANE (qwen_ane)" "LM Studio"
|
|
||||||
printf '%.0s─' {1..56}; echo ""
|
|
||||||
printf "%-24s %12s t/s %12s t/s\n" "Decode speed" "$AVG_D" "$LMS_AVG_TPS"
|
|
||||||
printf "%-24s %12s t/s %12s\n" "Prefill speed" "$AVG_P" "N/A"
|
|
||||||
printf "%-24s %12s ms %12s ms\n" "TTFT" "$AVG_TTFT" "$LMS_AVG_TTFT"
|
|
||||||
printf "%-24s %12s ms %12s ms\n" "Avg round-trip" "$AVG_RT" "$LMS_AVG_LAT"
|
|
||||||
printf "%-24s %12s ms %12s ms\n" " (server-only)" "$AVG_INF" "N/A"
|
|
||||||
printf "%-24s %12s ms %12s\n" "Cold start" "$COLD_MS" "N/A"
|
|
||||||
printf "%-24s %15s %15s\n" "Precision" "F32 (from BF16)" "GGUF quantized"
|
|
||||||
printf "%-24s %15s %15s\n" "Accelerator" "Neural Engine" "CPU/GPU"
|
|
||||||
printf "%-24s %15s %15s\n" "Timing method" "Wall-clock" "Wall-clock"
|
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Append LM Studio block to JSON results (pure shell, no python)
|
# Collect all column names and data
|
||||||
# Remove trailing "}" and newline, append lm_studio object
|
declare -a COL_NAMES=() COL_DECODE=() COL_PREFILL=() COL_TTFT=() COL_RT=() COL_PREC=() COL_ACCEL=()
|
||||||
|
|
||||||
|
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do
|
||||||
|
COL_NAMES+=("ANE ${ANE_FMT_NAMES[$fi2]}")
|
||||||
|
COL_DECODE+=("${ALL_AVG_D[$fi2]}")
|
||||||
|
COL_PREFILL+=("${ALL_AVG_P[$fi2]}")
|
||||||
|
COL_TTFT+=("${ALL_AVG_TTFT[$fi2]}")
|
||||||
|
COL_RT+=("${ALL_AVG_RT[$fi2]}")
|
||||||
|
COL_PREC+=("${ANE_FMT_LABELS[$fi2]}")
|
||||||
|
if [ -n "${ANE_FMT_GPU[$fi2]}" ]; then
|
||||||
|
COL_ACCEL+=("Metal GPU")
|
||||||
|
else
|
||||||
|
COL_ACCEL+=("CPU (AMX)")
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Add each tested LM Studio model as a column
|
||||||
|
declare -a LMS_TESTED_NAMES=() LMS_TESTED_TPS=() LMS_TESTED_TTFT=() LMS_TESTED_LAT=()
|
||||||
|
for LMS_MODEL in "${LMS_MODEL_LIST[@]}"; do
|
||||||
|
# Check if this model was actually tested (has data in LMS_ALL_JSON)
|
||||||
|
SAFE_KEY=$(echo "$LMS_MODEL" | sed 's/[^a-zA-Z0-9._-]/_/g')
|
||||||
|
if echo "$LMS_ALL_JSON" | grep -q "\"$SAFE_KEY\""; then
|
||||||
|
M_TPS=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_tps":[[:space:]]*\([0-9.]*\).*/\1/p' | head -1)
|
||||||
|
M_TTFT=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_ttft_ms":[[:space:]]*\([0-9]*\).*/\1/p' | head -1)
|
||||||
|
M_LAT=$(echo "$LMS_ALL_JSON" | sed -n "/\"$SAFE_KEY\"/,/}/p" | sed -n 's/.*"avg_latency_ms":[[:space:]]*\([0-9]*\).*/\1/p' | head -1)
|
||||||
|
|
||||||
|
SHORT_NAME=$(echo "$LMS_MODEL" | sed 's/qwen2.5-0.5b-instruct/q0.5b/; s/-mlx/mlx/')
|
||||||
|
COL_NAMES+=("LMS $SHORT_NAME")
|
||||||
|
COL_DECODE+=("${M_TPS:-0}")
|
||||||
|
COL_PREFILL+=("N/A")
|
||||||
|
COL_TTFT+=("${M_TTFT:-0}")
|
||||||
|
COL_RT+=("${M_LAT:-0}")
|
||||||
|
|
||||||
|
PREC_TAG="GGUF"
|
||||||
|
echo "$LMS_MODEL" | grep -q "8bit" && PREC_TAG="MLX 8-bit"
|
||||||
|
echo "$LMS_MODEL" | grep -q "4bit" && PREC_TAG="MLX 4-bit"
|
||||||
|
COL_PREC+=("$PREC_TAG")
|
||||||
|
COL_ACCEL+=("CPU/GPU")
|
||||||
|
|
||||||
|
LMS_TESTED_NAMES+=("$LMS_MODEL")
|
||||||
|
LMS_TESTED_TPS+=("${M_TPS:-0}")
|
||||||
|
LMS_TESTED_TTFT+=("${M_TTFT:-0}")
|
||||||
|
LMS_TESTED_LAT+=("${M_LAT:-0}")
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
NUM_COLS=${#COL_NAMES[@]}
|
||||||
|
COL_W=16
|
||||||
|
|
||||||
|
# Print header row
|
||||||
|
printf "%-20s" ""
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_NAMES[$c]}"; done
|
||||||
|
echo ""
|
||||||
|
printf '%.0s─' $(seq 1 $((20 + NUM_COLS * COL_W))); echo ""
|
||||||
|
|
||||||
|
# Data rows
|
||||||
|
printf "%-20s" "Decode (t/s)"
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_DECODE[$c]}"; done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
printf "%-20s" "Prefill (t/s)"
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_PREFILL[$c]}"; done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
printf "%-20s" "TTFT (ms)"
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_TTFT[$c]}"; done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
printf "%-20s" "Round-trip (ms)"
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_RT[$c]}"; done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
printf "%-20s" "Cold start (ms)"
|
||||||
|
printf "%${COL_W}s" "$COLD_MS"
|
||||||
|
for c in $(seq 1 $((NUM_COLS - 1))); do printf "%${COL_W}s" "N/A"; done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
printf '%.0s─' $(seq 1 $((20 + NUM_COLS * COL_W))); echo ""
|
||||||
|
|
||||||
|
printf "%-20s" "Precision"
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_PREC[$c]}"; done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
printf "%-20s" "Accelerator"
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "${COL_ACCEL[$c]}"; done
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
printf "%-20s" "Timing"
|
||||||
|
for c in $(seq 0 $((NUM_COLS - 1))); do printf "%${COL_W}s" "Wall-clock"; done
|
||||||
|
echo ""
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Append LM Studio results to JSON
|
||||||
LMS_JSON_BLOCK=",
|
LMS_JSON_BLOCK=",
|
||||||
\"lm_studio\": {
|
\"lm_studio\": {
|
||||||
\"port\": $LMS_PORT,
|
\"port\": $LMS_PORT,
|
||||||
\"model\": \"$LMS_MODEL\",
|
\"models_tested\": [$(printf '"%s",' "${LMS_MODEL_LIST[@]}" | sed 's/,$//')],$( echo "$LMS_ALL_JSON" | sed '$ s/,$//' )
|
||||||
\"avg_latency_ms\": $LMS_AVG_LAT,
|
|
||||||
\"avg_tps\": $LMS_AVG_TPS,
|
|
||||||
\"avg_ttft_ms\": $LMS_AVG_TTFT,
|
|
||||||
\"results\": [${LMS_JSON_ENTRIES%,}]
|
|
||||||
}
|
}
|
||||||
}"
|
}"
|
||||||
# Replace the final "}" with the LM Studio block
|
|
||||||
sed -i '' '$ s/}$//' "$RESULTS_JSON"
|
sed -i '' '$ s/}$//' "$RESULTS_JSON"
|
||||||
printf '%s\n' "$LMS_JSON_BLOCK" >> "$RESULTS_JSON"
|
printf '%s\n' "$LMS_JSON_BLOCK" >> "$RESULTS_JSON"
|
||||||
dim "LM Studio results added to $RESULTS_JSON"
|
dim "LM Studio results added to $RESULTS_JSON"
|
||||||
else
|
else
|
||||||
|
# No LM Studio -- print ANE-only comparison if we have multiple formats
|
||||||
|
if [ "$NUM_ANE_FMTS" -gt 1 ]; then
|
||||||
|
info "=== ANE Format Comparison ==="
|
||||||
|
echo ""
|
||||||
|
printf "%-20s" ""
|
||||||
|
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "ANE ${ANE_FMT_NAMES[$fi2]}"; done
|
||||||
|
echo ""
|
||||||
|
printf '%.0s─' $(seq 1 $((20 + NUM_ANE_FMTS * 16))); echo ""
|
||||||
|
printf "%-20s" "Decode (t/s)"
|
||||||
|
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_D[$fi2]}"; done
|
||||||
|
echo ""
|
||||||
|
printf "%-20s" "Prefill (t/s)"
|
||||||
|
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_P[$fi2]}"; done
|
||||||
|
echo ""
|
||||||
|
printf "%-20s" "TTFT (ms)"
|
||||||
|
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_TTFT[$fi2]}"; done
|
||||||
|
echo ""
|
||||||
|
printf "%-20s" "Round-trip (ms)"
|
||||||
|
for fi2 in $(seq 0 $((NUM_ANE_FMTS - 1))); do printf "%16s" "${ALL_AVG_RT[$fi2]}"; done
|
||||||
|
echo ""
|
||||||
|
printf '%.0s─' $(seq 1 $((20 + NUM_ANE_FMTS * 16))); echo ""
|
||||||
|
echo ""
|
||||||
|
fi
|
||||||
|
|
||||||
info "=== LM Studio Comparison ==="
|
info "=== LM Studio Comparison ==="
|
||||||
echo ""
|
echo ""
|
||||||
if [ "$LMS_REACHABLE" -eq 0 ]; then
|
if [ "$LMS_REACHABLE" -eq 0 ]; then
|
||||||
echo " LM Studio server not detected on localhost:$LMS_PORT"
|
echo " LM Studio server not detected on localhost:$LMS_PORT"
|
||||||
echo ""
|
echo ""
|
||||||
echo " To enable automatic comparison:"
|
echo " To enable automatic comparison:"
|
||||||
echo " 1. Open LM Studio, download Qwen2.5-0.5B-Instruct (GGUF)"
|
echo " 1. Open LM Studio, download Qwen2.5-0.5B-Instruct (GGUF + MLX variants)"
|
||||||
echo " 2. Load the model, go to Developer tab > Start Server"
|
echo " 2. Load the model, go to Developer tab > Start Server"
|
||||||
echo " 3. Re-run this benchmark"
|
echo " 3. Re-run this benchmark"
|
||||||
echo ""
|
echo ""
|
||||||
echo " Or set env vars: LMS_PORT=1234 LMS_API_KEY=your-key ./benchmark.sh"
|
echo " Or set env vars: LMS_PORT=1234 LMS_API_KEY=your-key ./benchmark.sh"
|
||||||
|
echo ""
|
||||||
|
echo " Models benchmarked by default:"
|
||||||
|
echo " - qwen2.5-0.5b-instruct (GGUF)"
|
||||||
|
echo " - qwen2.5-0.5b-instruct-mlx@8bit (MLX 8-bit)"
|
||||||
|
echo " - qwen2.5-0.5b-instruct-mlx@4bit (MLX 4-bit)"
|
||||||
|
echo ""
|
||||||
|
echo " Override with: LMS_MODELS='model1,model2' ./benchmark.sh"
|
||||||
fi
|
fi
|
||||||
echo ""
|
echo ""
|
||||||
echo " Manual test:"
|
echo " Manual test:"
|
||||||
|
|
@ -385,9 +633,9 @@ else
|
||||||
echo " -H 'Authorization: Bearer YOUR_API_KEY' \\"
|
echo " -H 'Authorization: Bearer YOUR_API_KEY' \\"
|
||||||
echo " -d '{\"model\":\"qwen2.5-0.5b-instruct\",\"system_prompt\":\"You are a helpful assistant.\",\"input\":\"What is 2+2?\"}'"
|
echo " -d '{\"model\":\"qwen2.5-0.5b-instruct\",\"system_prompt\":\"You are a helpful assistant.\",\"input\":\"What is 2+2?\"}'"
|
||||||
echo ""
|
echo ""
|
||||||
echo " ANE (this benchmark): prefill=${AVG_P} t/s, decode=${AVG_D} t/s, inference=${AVG_INF}ms"
|
echo " ANE F16: prefill=${AVG_P} t/s, decode=${AVG_D} t/s, inference=${AVG_INF}ms"
|
||||||
echo ""
|
echo ""
|
||||||
echo " Note: LM Studio uses quantized GGUF (CPU/GPU) while we use"
|
echo " Note: LM Studio uses quantized GGUF/MLX (CPU/GPU) while we use"
|
||||||
echo " BF16 weights (full precision) running on the Neural Engine."
|
echo " F16/Q8 weights running on CPU AMX / NEON."
|
||||||
fi
|
fi
|
||||||
echo ""
|
echo ""
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,17 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Convert Qwen2.5-0.5B-Instruct safetensors → flat binary for ANE inference.
|
"""Convert Qwen2.5-0.5B-Instruct safetensors → flat binary for ANE inference.
|
||||||
|
|
||||||
Output format: config header (7 ints) + all weights in f32, layer by layer.
|
Output format (F32): config header (8 ints) + all weights in f32
|
||||||
Matches the layout expected by qwen_ane_infer.h.
|
Output format (F16): config header (8 ints) + embeddings f32 + projection weights f16
|
||||||
|
Output format (Q8): config header (8 ints) + embeddings f32 + projection weights q8_0
|
||||||
|
Output format (Q4): config header (8 ints) + embeddings f32 + projection weights q4_0
|
||||||
|
|
||||||
|
The 8th config int is the format flag: 0 = F32, 1 = F16, 2 = Q8, 3 = Q4.
|
||||||
|
Q8_0 format: blocks of 32 values, each block = 1 f16 scale + 32 int8 values (34 bytes).
|
||||||
|
Q4_0 format: blocks of 32 values, each block = 1 f16 scale + 1 f16 zero + 16 uint8 packed pairs (20 bytes).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct /path/to/output.bin
|
python3 convert_weights.py <model_dir> <output.bin> [--f16|--q8|--q4]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
|
|
@ -14,10 +20,74 @@ import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
def convert(model_dir: str, output_path: str):
|
Q8_BLOCK_SIZE = 32
|
||||||
|
Q4_BLOCK_SIZE = 32
|
||||||
|
|
||||||
|
def quantize_q4_0(weights_f32):
|
||||||
|
"""Quantize a 2D weight matrix to Q4_0 block format.
|
||||||
|
Returns bytes: for each row, blocks of (f16_scale + f16_zero + 16*uint8 packed pairs).
|
||||||
|
Each uint8 stores two 4-bit values: low nibble = even index, high nibble = odd index."""
|
||||||
|
out_dim, in_dim = weights_f32.shape
|
||||||
|
assert in_dim % Q4_BLOCK_SIZE == 0, f"in_dim {in_dim} not divisible by {Q4_BLOCK_SIZE}"
|
||||||
|
|
||||||
|
n_blocks_per_row = in_dim // Q4_BLOCK_SIZE
|
||||||
|
result = bytearray()
|
||||||
|
|
||||||
|
for r in range(out_dim):
|
||||||
|
row = weights_f32[r]
|
||||||
|
for b in range(n_blocks_per_row):
|
||||||
|
block = row[b * Q4_BLOCK_SIZE : (b + 1) * Q4_BLOCK_SIZE]
|
||||||
|
bmin = np.min(block)
|
||||||
|
bmax = np.max(block)
|
||||||
|
if bmax == bmin:
|
||||||
|
scale = np.float16(0.0)
|
||||||
|
zero = np.float16(0.0)
|
||||||
|
packed = bytes(Q4_BLOCK_SIZE // 2)
|
||||||
|
else:
|
||||||
|
scale_f = (bmax - bmin) / 15.0
|
||||||
|
zero_f = bmin
|
||||||
|
scale = np.float16(scale_f)
|
||||||
|
zero = np.float16(zero_f)
|
||||||
|
scale_f = float(scale) if float(scale) != 0.0 else 1e-10
|
||||||
|
quant = np.clip(np.round((block - float(zero)) / scale_f), 0, 15).astype(np.uint8)
|
||||||
|
packed = bytearray(Q4_BLOCK_SIZE // 2)
|
||||||
|
for i in range(0, Q4_BLOCK_SIZE, 2):
|
||||||
|
packed[i // 2] = quant[i] | (quant[i + 1] << 4)
|
||||||
|
result += scale.tobytes()
|
||||||
|
result += zero.tobytes()
|
||||||
|
result += bytes(packed)
|
||||||
|
|
||||||
|
return bytes(result)
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_q8_0(weights_f32):
|
||||||
|
"""Quantize a 2D weight matrix to Q8_0 block format.
|
||||||
|
Returns bytes: for each row, blocks of (f16_scale + 32*int8)."""
|
||||||
|
out_dim, in_dim = weights_f32.shape
|
||||||
|
assert in_dim % Q8_BLOCK_SIZE == 0, f"in_dim {in_dim} not divisible by {Q8_BLOCK_SIZE}"
|
||||||
|
|
||||||
|
n_blocks_per_row = in_dim // Q8_BLOCK_SIZE
|
||||||
|
result = bytearray()
|
||||||
|
|
||||||
|
for r in range(out_dim):
|
||||||
|
row = weights_f32[r]
|
||||||
|
for b in range(n_blocks_per_row):
|
||||||
|
block = row[b * Q8_BLOCK_SIZE : (b + 1) * Q8_BLOCK_SIZE]
|
||||||
|
amax = np.max(np.abs(block))
|
||||||
|
scale = amax / 127.0 if amax > 0 else 0.0
|
||||||
|
if scale > 0:
|
||||||
|
quant = np.round(block / scale).astype(np.int8)
|
||||||
|
else:
|
||||||
|
quant = np.zeros(Q8_BLOCK_SIZE, dtype=np.int8)
|
||||||
|
result += np.float16(scale).tobytes()
|
||||||
|
result += quant.tobytes()
|
||||||
|
|
||||||
|
return bytes(result)
|
||||||
|
|
||||||
|
|
||||||
|
def convert(model_dir: str, output_path: str, fmt: str = "f32"):
|
||||||
model_dir = Path(model_dir)
|
model_dir = Path(model_dir)
|
||||||
|
|
||||||
# Load safetensors
|
|
||||||
st_files = list(model_dir.glob("*.safetensors"))
|
st_files = list(model_dir.glob("*.safetensors"))
|
||||||
if not st_files:
|
if not st_files:
|
||||||
print(f"No safetensors files in {model_dir}")
|
print(f"No safetensors files in {model_dir}")
|
||||||
|
|
@ -30,8 +100,8 @@ def convert(model_dir: str, output_path: str):
|
||||||
tensors[key] = sf.get_tensor(key).float().numpy()
|
tensors[key] = sf.get_tensor(key).float().numpy()
|
||||||
|
|
||||||
print(f"Loaded {len(tensors)} tensors from {len(st_files)} files")
|
print(f"Loaded {len(tensors)} tensors from {len(st_files)} files")
|
||||||
|
print(f"Mode: {fmt.upper()} projections (embeddings + norms + biases stay F32)")
|
||||||
|
|
||||||
# Qwen2.5-0.5B config
|
|
||||||
dim = 896
|
dim = 896
|
||||||
hidden = 4864
|
hidden = 4864
|
||||||
n_layers = 24
|
n_layers = 24
|
||||||
|
|
@ -39,37 +109,41 @@ def convert(model_dir: str, output_path: str):
|
||||||
n_kv_heads = 2
|
n_kv_heads = 2
|
||||||
vocab_size = 151936
|
vocab_size = 151936
|
||||||
max_seq = 512
|
max_seq = 512
|
||||||
|
fmt_flag = {"f32": 0, "f16": 1, "q8": 2, "q4": 3}[fmt]
|
||||||
|
|
||||||
|
def write_proj(f_out, tensor_f32):
|
||||||
|
if fmt == "q4":
|
||||||
|
f_out.write(quantize_q4_0(tensor_f32))
|
||||||
|
elif fmt == "q8":
|
||||||
|
f_out.write(quantize_q8_0(tensor_f32))
|
||||||
|
elif fmt == "f16":
|
||||||
|
f_out.write(tensor_f32.astype(np.float16).tobytes())
|
||||||
|
else:
|
||||||
|
f_out.write(tensor_f32.astype(np.float32).tobytes())
|
||||||
|
|
||||||
with open(output_path, "wb") as f:
|
with open(output_path, "wb") as f:
|
||||||
# Config header: 7 x int32
|
f.write(struct.pack("iiiiiiii",
|
||||||
f.write(struct.pack("iiiiiii",
|
dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq, fmt_flag))
|
||||||
dim, hidden, n_layers, n_heads, n_kv_heads, vocab_size, max_seq))
|
|
||||||
|
|
||||||
# Embedding [vocab, dim]
|
|
||||||
emb = tensors["model.embed_tokens.weight"].astype(np.float32)
|
emb = tensors["model.embed_tokens.weight"].astype(np.float32)
|
||||||
print(f"embed: {emb.shape}")
|
print(f"embed: {emb.shape} (f32)")
|
||||||
f.write(emb.tobytes())
|
f.write(emb.tobytes())
|
||||||
|
|
||||||
# Per-layer weights
|
|
||||||
for l in range(n_layers):
|
for l in range(n_layers):
|
||||||
prefix = f"model.layers.{l}"
|
prefix = f"model.layers.{l}"
|
||||||
|
|
||||||
# Attention norm
|
|
||||||
rms_att = tensors[f"{prefix}.input_layernorm.weight"].astype(np.float32)
|
rms_att = tensors[f"{prefix}.input_layernorm.weight"].astype(np.float32)
|
||||||
f.write(rms_att.tobytes())
|
f.write(rms_att.tobytes())
|
||||||
|
|
||||||
# Q, K, V projections
|
|
||||||
wq = tensors[f"{prefix}.self_attn.q_proj.weight"].astype(np.float32)
|
wq = tensors[f"{prefix}.self_attn.q_proj.weight"].astype(np.float32)
|
||||||
wk = tensors[f"{prefix}.self_attn.k_proj.weight"].astype(np.float32)
|
wk = tensors[f"{prefix}.self_attn.k_proj.weight"].astype(np.float32)
|
||||||
wv = tensors[f"{prefix}.self_attn.v_proj.weight"].astype(np.float32)
|
wv = tensors[f"{prefix}.self_attn.v_proj.weight"].astype(np.float32)
|
||||||
wo = tensors[f"{prefix}.self_attn.o_proj.weight"].astype(np.float32)
|
wo = tensors[f"{prefix}.self_attn.o_proj.weight"].astype(np.float32)
|
||||||
f.write(wq.tobytes())
|
write_proj(f, wq)
|
||||||
f.write(wk.tobytes())
|
write_proj(f, wk)
|
||||||
f.write(wv.tobytes())
|
write_proj(f, wv)
|
||||||
f.write(wo.tobytes())
|
write_proj(f, wo)
|
||||||
|
|
||||||
# Q/K biases (Qwen has them)
|
|
||||||
# Q/K/V biases
|
|
||||||
qb = tensors.get(f"{prefix}.self_attn.q_proj.bias")
|
qb = tensors.get(f"{prefix}.self_attn.q_proj.bias")
|
||||||
kb = tensors.get(f"{prefix}.self_attn.k_proj.bias")
|
kb = tensors.get(f"{prefix}.self_attn.k_proj.bias")
|
||||||
vb = tensors.get(f"{prefix}.self_attn.v_proj.bias")
|
vb = tensors.get(f"{prefix}.self_attn.v_proj.bias")
|
||||||
|
|
@ -77,22 +151,19 @@ def convert(model_dir: str, output_path: str):
|
||||||
f.write((kb if kb is not None else np.zeros(wk.shape[0])).astype(np.float32).tobytes())
|
f.write((kb if kb is not None else np.zeros(wk.shape[0])).astype(np.float32).tobytes())
|
||||||
f.write((vb if vb is not None else np.zeros(wv.shape[0])).astype(np.float32).tobytes())
|
f.write((vb if vb is not None else np.zeros(wv.shape[0])).astype(np.float32).tobytes())
|
||||||
|
|
||||||
# FFN norm
|
|
||||||
rms_ffn = tensors[f"{prefix}.post_attention_layernorm.weight"].astype(np.float32)
|
rms_ffn = tensors[f"{prefix}.post_attention_layernorm.weight"].astype(np.float32)
|
||||||
f.write(rms_ffn.tobytes())
|
f.write(rms_ffn.tobytes())
|
||||||
|
|
||||||
# FFN: gate, up, down
|
|
||||||
w_gate = tensors[f"{prefix}.mlp.gate_proj.weight"].astype(np.float32)
|
w_gate = tensors[f"{prefix}.mlp.gate_proj.weight"].astype(np.float32)
|
||||||
w_up = tensors[f"{prefix}.mlp.up_proj.weight"].astype(np.float32)
|
w_up = tensors[f"{prefix}.mlp.up_proj.weight"].astype(np.float32)
|
||||||
w_down = tensors[f"{prefix}.mlp.down_proj.weight"].astype(np.float32)
|
w_down = tensors[f"{prefix}.mlp.down_proj.weight"].astype(np.float32)
|
||||||
f.write(w_gate.tobytes())
|
write_proj(f, w_gate)
|
||||||
f.write(w_up.tobytes())
|
write_proj(f, w_up)
|
||||||
f.write(w_down.tobytes())
|
write_proj(f, w_down)
|
||||||
|
|
||||||
print(f" Layer {l}: Q{wq.shape} K{wk.shape} V{wv.shape} O{wo.shape} "
|
print(f" Layer {l}: Q{wq.shape} K{wk.shape} V{wv.shape} O{wo.shape} "
|
||||||
f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape}")
|
f"gate{w_gate.shape} up{w_up.shape} down{w_down.shape} [{fmt}]")
|
||||||
|
|
||||||
# Final norm
|
|
||||||
rms_final = tensors["model.norm.weight"].astype(np.float32)
|
rms_final = tensors["model.norm.weight"].astype(np.float32)
|
||||||
f.write(rms_final.tobytes())
|
f.write(rms_final.tobytes())
|
||||||
|
|
||||||
|
|
@ -101,7 +172,14 @@ def convert(model_dir: str, output_path: str):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) != 3:
|
if len(sys.argv) < 3:
|
||||||
print("Usage: python3 convert_weights.py <model_dir> <output.bin>")
|
print("Usage: python3 convert_weights.py <model_dir> <output.bin> [--f16|--q8|--q4]")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
convert(sys.argv[1], sys.argv[2])
|
fmt = "f32"
|
||||||
|
if "--f16" in sys.argv:
|
||||||
|
fmt = "f16"
|
||||||
|
elif "--q8" in sys.argv:
|
||||||
|
fmt = "q8"
|
||||||
|
elif "--q4" in sys.argv:
|
||||||
|
fmt = "q4"
|
||||||
|
convert(sys.argv[1], sys.argv[2], fmt)
|
||||||
|
|
|
||||||
239
inference/main.m
239
inference/main.m
|
|
@ -6,9 +6,10 @@
|
||||||
// 4. HTTP API: ./qwen_ane weights.bin --http 8000 --model-dir ~/models/Qwen2.5-0.5B-Instruct
|
// 4. HTTP API: ./qwen_ane weights.bin --http 8000 --model-dir ~/models/Qwen2.5-0.5B-Instruct
|
||||||
//
|
//
|
||||||
// Build:
|
// Build:
|
||||||
// xcrun clang -O2 -framework Foundation -framework IOSurface \
|
// xcrun clang -O3 -ffast-math -mcpu=apple-m4 -flto \
|
||||||
// -framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
|
// -framework Foundation -framework IOSurface \
|
||||||
// -o qwen_ane main.m
|
// -framework CoreML -framework Accelerate -framework Metal \
|
||||||
|
// -ldl -lobjc -fobjc-arc -o qwen_ane main.m
|
||||||
//
|
//
|
||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
@ -39,36 +40,112 @@ static void handle_signal(int sig) {
|
||||||
_exit(0);
|
_exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void *safe_malloc(size_t size, const char *desc) {
|
||||||
|
void *p = malloc(size);
|
||||||
|
if (!p) {
|
||||||
|
fprintf(stderr, "FATAL: malloc failed for %s (%.1f MB)\n",
|
||||||
|
desc, (double)size / (1024*1024));
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void *safe_calloc(size_t count, size_t size, const char *desc) {
|
||||||
|
void *p = calloc(count, size);
|
||||||
|
if (!p) {
|
||||||
|
fprintf(stderr, "FATAL: calloc failed for %s (%.1f MB)\n",
|
||||||
|
desc, (double)(count * size) / (1024*1024));
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
static int load_weights(const char *path) {
|
static int load_weights(const char *path) {
|
||||||
FILE *f = fopen(path, "rb");
|
FILE *f = fopen(path, "rb");
|
||||||
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
|
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
|
||||||
|
|
||||||
int config[7];
|
// Try 8-int header first (new format), fall back to 7-int (legacy)
|
||||||
fread(config, sizeof(int), 7, f);
|
int config[8] = {0};
|
||||||
|
size_t hdr_read = fread(config, sizeof(int), 8, f);
|
||||||
int dim = config[0], hidden = config[1], n_layers = config[2];
|
int dim = config[0], hidden = config[1], n_layers = config[2];
|
||||||
int n_heads = config[3], n_kv_heads = config[4], vocab = config[5];
|
int n_heads = config[3], n_kv_heads = config[4], vocab = config[5];
|
||||||
printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d\n",
|
int fmt_flag = 0;
|
||||||
dim, hidden, n_layers, n_heads, n_kv_heads, vocab);
|
|
||||||
|
if (hdr_read == 8 && config[7] >= 0 && config[7] <= 3) {
|
||||||
|
fmt_flag = config[7];
|
||||||
|
} else {
|
||||||
|
fseek(f, 7 * sizeof(int), SEEK_SET);
|
||||||
|
}
|
||||||
|
|
||||||
|
g_model.weight_fmt = fmt_flag;
|
||||||
|
int is_f16 = (fmt_flag == 1);
|
||||||
|
int is_q8 = (fmt_flag == 2);
|
||||||
|
int is_q4 = (fmt_flag == 3);
|
||||||
|
const char *fmt_str = is_q4 ? "Q4" : (is_q8 ? "Q8" : (is_f16 ? "F16" : "F32"));
|
||||||
|
printf("Config: dim=%d hidden=%d layers=%d heads=%d kv_heads=%d vocab=%d fmt=%s\n",
|
||||||
|
dim, hidden, n_layers, n_heads, n_kv_heads, vocab, fmt_str);
|
||||||
|
|
||||||
int q_dim = n_heads * QWEN_HEAD_DIM;
|
int q_dim = n_heads * QWEN_HEAD_DIM;
|
||||||
int kv_dim = n_kv_heads * QWEN_HEAD_DIM;
|
int kv_dim = n_kv_heads * QWEN_HEAD_DIM;
|
||||||
|
|
||||||
g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float));
|
// Embeddings always F32
|
||||||
|
g_model.embed = (float*)safe_malloc((size_t)vocab * dim * sizeof(float), "embed");
|
||||||
fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f);
|
fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f);
|
||||||
|
|
||||||
for (int l = 0; l < n_layers; l++) {
|
for (int l = 0; l < n_layers; l++) {
|
||||||
|
// RMSNorm always F32
|
||||||
g_model.rms_att[l] = (float*)malloc(dim * sizeof(float));
|
g_model.rms_att[l] = (float*)malloc(dim * sizeof(float));
|
||||||
fread(g_model.rms_att[l], sizeof(float), dim, f);
|
fread(g_model.rms_att[l], sizeof(float), dim, f);
|
||||||
|
|
||||||
g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
|
if (is_q4) {
|
||||||
fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f);
|
#define LOAD_Q4(q8ptr, out_d, in_d) do { \
|
||||||
g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
|
size_t _nb = (size_t)(in_d) / Q4_BLOCK_SIZE; \
|
||||||
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
|
size_t _bytes = (size_t)(out_d) * _nb * Q4_BLOCK_BYTES; \
|
||||||
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
|
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
|
||||||
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
|
fread(q8ptr, 1, _bytes, f); \
|
||||||
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
|
} while(0)
|
||||||
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
|
LOAD_Q4(g_model.wq_q8[l], q_dim, dim);
|
||||||
|
LOAD_Q4(g_model.wk_q8[l], kv_dim, dim);
|
||||||
|
LOAD_Q4(g_model.wv_q8[l], kv_dim, dim);
|
||||||
|
LOAD_Q4(g_model.wo_q8[l], dim, q_dim);
|
||||||
|
#undef LOAD_Q4
|
||||||
|
} else if (is_q8) {
|
||||||
|
#define LOAD_Q8(q8ptr, out_d, in_d) do { \
|
||||||
|
size_t _nb = (size_t)(in_d) / Q8_BLOCK_SIZE; \
|
||||||
|
size_t _bytes = (size_t)(out_d) * _nb * Q8_BLOCK_BYTES; \
|
||||||
|
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
|
||||||
|
fread(q8ptr, 1, _bytes, f); \
|
||||||
|
} while(0)
|
||||||
|
LOAD_Q8(g_model.wq_q8[l], q_dim, dim);
|
||||||
|
LOAD_Q8(g_model.wk_q8[l], kv_dim, dim);
|
||||||
|
LOAD_Q8(g_model.wv_q8[l], kv_dim, dim);
|
||||||
|
LOAD_Q8(g_model.wo_q8[l], dim, q_dim);
|
||||||
|
#undef LOAD_Q8
|
||||||
|
} else if (is_f16) {
|
||||||
|
#define LOAD_F16_AS_F32(f32ptr, f16ptr, n) do { \
|
||||||
|
size_t _n = (size_t)(n); \
|
||||||
|
f16ptr = (_Float16*)malloc(_n * sizeof(_Float16)); \
|
||||||
|
fread(f16ptr, sizeof(_Float16), _n, f); \
|
||||||
|
f32ptr = (float*)malloc(_n * sizeof(float)); \
|
||||||
|
convert_f16_to_f32(f16ptr, f32ptr, _n); \
|
||||||
|
} while(0)
|
||||||
|
LOAD_F16_AS_F32(g_model.wq[l], g_model.wq_f16[l], (size_t)q_dim * dim);
|
||||||
|
LOAD_F16_AS_F32(g_model.wk[l], g_model.wk_f16[l], (size_t)kv_dim * dim);
|
||||||
|
LOAD_F16_AS_F32(g_model.wv[l], g_model.wv_f16[l], (size_t)kv_dim * dim);
|
||||||
|
LOAD_F16_AS_F32(g_model.wo[l], g_model.wo_f16[l], (size_t)dim * q_dim);
|
||||||
|
#undef LOAD_F16_AS_F32
|
||||||
|
} else {
|
||||||
|
g_model.wq[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
|
||||||
|
fread(g_model.wq[l], sizeof(float), (size_t)q_dim * dim, f);
|
||||||
|
g_model.wk[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
|
||||||
|
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
|
||||||
|
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
|
||||||
|
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
|
||||||
|
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
|
||||||
|
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Biases always F32
|
||||||
g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float));
|
g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float));
|
||||||
g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
||||||
g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
||||||
|
|
@ -76,15 +153,52 @@ static int load_weights(const char *path) {
|
||||||
fread(g_model.k_bias[l], sizeof(float), kv_dim, f);
|
fread(g_model.k_bias[l], sizeof(float), kv_dim, f);
|
||||||
fread(g_model.v_bias[l], sizeof(float), kv_dim, f);
|
fread(g_model.v_bias[l], sizeof(float), kv_dim, f);
|
||||||
|
|
||||||
|
// FFN RMSNorm always F32
|
||||||
g_model.rms_ffn[l] = (float*)malloc(dim * sizeof(float));
|
g_model.rms_ffn[l] = (float*)malloc(dim * sizeof(float));
|
||||||
fread(g_model.rms_ffn[l], sizeof(float), dim, f);
|
fread(g_model.rms_ffn[l], sizeof(float), dim, f);
|
||||||
|
|
||||||
g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
|
if (is_q4) {
|
||||||
fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f);
|
#define LOAD_Q4(q8ptr, out_d, in_d) do { \
|
||||||
g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
|
size_t _nb = (size_t)(in_d) / Q4_BLOCK_SIZE; \
|
||||||
fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f);
|
size_t _bytes = (size_t)(out_d) * _nb * Q4_BLOCK_BYTES; \
|
||||||
g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float));
|
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
|
||||||
fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f);
|
fread(q8ptr, 1, _bytes, f); \
|
||||||
|
} while(0)
|
||||||
|
LOAD_Q4(g_model.wgate_q8[l], hidden, dim);
|
||||||
|
LOAD_Q4(g_model.wup_q8[l], hidden, dim);
|
||||||
|
LOAD_Q4(g_model.wdown_q8[l], dim, hidden);
|
||||||
|
#undef LOAD_Q4
|
||||||
|
} else if (is_q8) {
|
||||||
|
#define LOAD_Q8(q8ptr, out_d, in_d) do { \
|
||||||
|
size_t _nb = (size_t)(in_d) / Q8_BLOCK_SIZE; \
|
||||||
|
size_t _bytes = (size_t)(out_d) * _nb * Q8_BLOCK_BYTES; \
|
||||||
|
q8ptr = (uint8_t*)safe_malloc(_bytes, #q8ptr); \
|
||||||
|
fread(q8ptr, 1, _bytes, f); \
|
||||||
|
} while(0)
|
||||||
|
LOAD_Q8(g_model.wgate_q8[l], hidden, dim);
|
||||||
|
LOAD_Q8(g_model.wup_q8[l], hidden, dim);
|
||||||
|
LOAD_Q8(g_model.wdown_q8[l], dim, hidden);
|
||||||
|
#undef LOAD_Q8
|
||||||
|
} else if (is_f16) {
|
||||||
|
#define LOAD_F16_AS_F32(f32ptr, f16ptr, n) do { \
|
||||||
|
size_t _n = (size_t)(n); \
|
||||||
|
f16ptr = (_Float16*)malloc(_n * sizeof(_Float16)); \
|
||||||
|
fread(f16ptr, sizeof(_Float16), _n, f); \
|
||||||
|
f32ptr = (float*)malloc(_n * sizeof(float)); \
|
||||||
|
convert_f16_to_f32(f16ptr, f32ptr, _n); \
|
||||||
|
} while(0)
|
||||||
|
LOAD_F16_AS_F32(g_model.w_gate[l], g_model.wgate_f16[l], (size_t)hidden * dim);
|
||||||
|
LOAD_F16_AS_F32(g_model.w_up[l], g_model.wup_f16[l], (size_t)hidden * dim);
|
||||||
|
LOAD_F16_AS_F32(g_model.w_down[l], g_model.wdown_f16[l], (size_t)dim * hidden);
|
||||||
|
#undef LOAD_F16_AS_F32
|
||||||
|
} else {
|
||||||
|
g_model.w_gate[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
|
||||||
|
fread(g_model.w_gate[l], sizeof(float), (size_t)hidden * dim, f);
|
||||||
|
g_model.w_up[l] = (float*)malloc((size_t)hidden * dim * sizeof(float));
|
||||||
|
fread(g_model.w_up[l], sizeof(float), (size_t)hidden * dim, f);
|
||||||
|
g_model.w_down[l] = (float*)malloc((size_t)dim * hidden * sizeof(float));
|
||||||
|
fread(g_model.w_down[l], sizeof(float), (size_t)dim * hidden, f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
g_model.rms_final = (float*)malloc(dim * sizeof(float));
|
g_model.rms_final = (float*)malloc(dim * sizeof(float));
|
||||||
|
|
@ -92,7 +206,8 @@ static int load_weights(const char *path) {
|
||||||
|
|
||||||
long file_size = ftell(f);
|
long file_size = ftell(f);
|
||||||
fclose(f);
|
fclose(f);
|
||||||
printf("Weights loaded (%.0f MB)\n", (float)file_size / 1024 / 1024);
|
printf("Weights loaded (%.0f MB, %s projections)\n",
|
||||||
|
(float)file_size / 1024 / 1024, fmt_str);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -115,16 +230,25 @@ static double timespec_diff(struct timespec *a, struct timespec *b) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run one generation pass. Writes output token IDs to out_ids, returns count.
|
// Run one generation pass. Writes output token IDs to out_ids, returns count.
|
||||||
// If out_fd >= 0, writes formatted results there; otherwise prints to stdout.
|
// Uses batched prefill (sgemm) for prompt, sequential decode (sgemv) for generation.
|
||||||
static int generate(int *prompt_ids, int n_prompt, int max_gen,
|
static int generate(int *prompt_ids, int n_prompt, int max_gen,
|
||||||
int *out_ids, int max_out,
|
int *out_ids, int max_out,
|
||||||
double *prefill_tps, double *decode_tps) {
|
double *prefill_tps, double *decode_tps) {
|
||||||
struct timespec t0, t1, t_pre;
|
struct timespec t0, t1, t_pre;
|
||||||
clock_gettime(CLOCK_MONOTONIC, &t0);
|
clock_gettime(CLOCK_MONOTONIC, &t0);
|
||||||
|
|
||||||
int next = 0;
|
int next;
|
||||||
for (int i = 0; i < n_prompt; i++)
|
if (g_model.use_ane) {
|
||||||
next = qwen_forward(&g_model, prompt_ids[i]);
|
for (int i = 0; i < n_prompt; i++)
|
||||||
|
next = qwen_forward_ane(&g_model, prompt_ids[i]);
|
||||||
|
} else if (n_prompt > 1 && g_model.weight_fmt == 3) {
|
||||||
|
next = qwen_prefill_q4(&g_model, prompt_ids, n_prompt);
|
||||||
|
} else if (n_prompt > 1 && g_model.weight_fmt != 2) {
|
||||||
|
next = qwen_prefill(&g_model, prompt_ids, n_prompt);
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < n_prompt; i++)
|
||||||
|
next = qwen_forward(&g_model, prompt_ids[i]);
|
||||||
|
}
|
||||||
|
|
||||||
clock_gettime(CLOCK_MONOTONIC, &t_pre);
|
clock_gettime(CLOCK_MONOTONIC, &t_pre);
|
||||||
double ps = timespec_diff(&t0, &t_pre);
|
double ps = timespec_diff(&t0, &t_pre);
|
||||||
|
|
@ -135,7 +259,10 @@ static int generate(int *prompt_ids, int n_prompt, int max_gen,
|
||||||
for (int i = 0; i < max_gen && n_out < max_out; i++) {
|
for (int i = 0; i < max_gen && n_out < max_out; i++) {
|
||||||
if (n_out < max_out) out_ids[n_out++] = next;
|
if (n_out < max_out) out_ids[n_out++] = next;
|
||||||
if (next == eos || next == eos2) break;
|
if (next == eos || next == eos2) break;
|
||||||
next = qwen_forward(&g_model, next);
|
if (g_model.use_ane)
|
||||||
|
next = qwen_forward_ane(&g_model, next);
|
||||||
|
else
|
||||||
|
next = qwen_forward(&g_model, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
clock_gettime(CLOCK_MONOTONIC, &t1);
|
clock_gettime(CLOCK_MONOTONIC, &t1);
|
||||||
|
|
@ -427,6 +554,7 @@ int main(int argc, char **argv) {
|
||||||
int server_mode = 0;
|
int server_mode = 0;
|
||||||
int http_port = 0;
|
int http_port = 0;
|
||||||
int test_ane = 0;
|
int test_ane = 0;
|
||||||
|
int use_ane = 0;
|
||||||
const char *sock_path = NULL;
|
const char *sock_path = NULL;
|
||||||
const char *model_dir = NULL;
|
const char *model_dir = NULL;
|
||||||
for (int i = 2; i < argc; i++) {
|
for (int i = 2; i < argc; i++) {
|
||||||
|
|
@ -442,6 +570,61 @@ int main(int argc, char **argv) {
|
||||||
else { fprintf(stderr, "--model-dir requires a path\n"); return 1; }
|
else { fprintf(stderr, "--model-dir requires a path\n"); return 1; }
|
||||||
} else if (strcmp(argv[i], "--test-ane") == 0) {
|
} else if (strcmp(argv[i], "--test-ane") == 0) {
|
||||||
test_ane = 1;
|
test_ane = 1;
|
||||||
|
} else if (strcmp(argv[i], "--ane") == 0) {
|
||||||
|
use_ane = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Q4 CPU mode: dequantize Q4 to F32 at load time, use AMX cblas_sgemv
|
||||||
|
if (g_model.weight_fmt == 3) {
|
||||||
|
printf("Dequantizing Q4→F32 for AMX acceleration...\n");
|
||||||
|
int q_dim = QWEN_Q_DIM, kv_dim = QWEN_KV_DIM, dim = QWEN_DIM;
|
||||||
|
int hidden = QWEN_HIDDEN;
|
||||||
|
|
||||||
|
#define DEQUANT_Q4_TO_F32(f32ptr, q4ptr, out_d, in_d) do { \
|
||||||
|
size_t _n = (size_t)(out_d) * (in_d); \
|
||||||
|
f32ptr = (float*)malloc(_n * sizeof(float)); \
|
||||||
|
dequant_q4_to_f32(q4ptr, f32ptr, (in_d), (out_d)); \
|
||||||
|
free(q4ptr); q4ptr = NULL; \
|
||||||
|
} while(0)
|
||||||
|
|
||||||
|
for (int l = 0; l < QWEN_LAYERS; l++) {
|
||||||
|
DEQUANT_Q4_TO_F32(g_model.wq[l], g_model.wq_q8[l], q_dim, dim);
|
||||||
|
DEQUANT_Q4_TO_F32(g_model.wk[l], g_model.wk_q8[l], kv_dim, dim);
|
||||||
|
DEQUANT_Q4_TO_F32(g_model.wv[l], g_model.wv_q8[l], kv_dim, dim);
|
||||||
|
DEQUANT_Q4_TO_F32(g_model.wo[l], g_model.wo_q8[l], dim, q_dim);
|
||||||
|
DEQUANT_Q4_TO_F32(g_model.w_gate[l], g_model.wgate_q8[l], hidden, dim);
|
||||||
|
DEQUANT_Q4_TO_F32(g_model.w_up[l], g_model.wup_q8[l], hidden, dim);
|
||||||
|
DEQUANT_Q4_TO_F32(g_model.w_down[l], g_model.wdown_q8[l], dim, hidden);
|
||||||
|
}
|
||||||
|
#undef DEQUANT_Q4_TO_F32
|
||||||
|
|
||||||
|
g_model.weight_fmt = 0;
|
||||||
|
printf("Q4→F32 done. Using AMX cblas_sgemv (91+ t/s decode).\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ANE fused kernel compilation (requires F32 weights for baked-weight convs)
|
||||||
|
if (use_ane) {
|
||||||
|
if (g_model.weight_fmt != 0) {
|
||||||
|
printf("--ane requires F32 weights (weight_fmt=0). Got fmt=%d\n", g_model.weight_fmt);
|
||||||
|
printf("Re-run with F32 weight file (convert_weights.py without --f16/--q4/--q8)\n");
|
||||||
|
use_ane = 0;
|
||||||
|
} else {
|
||||||
|
struct timespec ta0, ta1;
|
||||||
|
clock_gettime(CLOCK_MONOTONIC, &ta0);
|
||||||
|
qwen_compile_kernels_fused(&g_model);
|
||||||
|
clock_gettime(CLOCK_MONOTONIC, &ta1);
|
||||||
|
double ane_sec = timespec_diff(&ta0, &ta1);
|
||||||
|
printf("ANE fused compile time: %.1fs\n", ane_sec);
|
||||||
|
|
||||||
|
// Verify at least one QKV kernel compiled
|
||||||
|
if (g_model.k_qkv[0] && g_model.k_o[0] && g_model.k_ffn_up[0] && g_model.k_down[0]) {
|
||||||
|
g_model.use_ane = 1;
|
||||||
|
printf("ANE fused mode active: 112 kernels (QKV+FFN_up fused)\n");
|
||||||
|
} else {
|
||||||
|
printf("ANE fused compilation failed, falling back to CPU\n");
|
||||||
|
use_ane = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,921 @@
|
||||||
|
#include <metal_stdlib>
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
// ── Q4_0 block format ────────────────────────────────────────────────
|
||||||
|
// Block of 32 values: 2 bytes F16 scale + 2 bytes F16 zero + 16 bytes packed uint8
|
||||||
|
// Each uint8 stores 2 values: low nibble = even index, high nibble = odd index
|
||||||
|
// Total: 20 bytes per block of 32 weights
|
||||||
|
#define Q4_BLOCK_SIZE 32
|
||||||
|
#define Q4_BLOCK_BYTES 20
|
||||||
|
|
||||||
|
// ── Q4 Matrix-vector multiply (legacy, 1 thread per row) ────────────
|
||||||
|
// Kept as fallback for edge cases.
|
||||||
|
kernel void sgemv_q4(
|
||||||
|
device const uint8_t *W [[buffer(0)]],
|
||||||
|
device const float *x [[buffer(1)]],
|
||||||
|
device float *y [[buffer(2)]],
|
||||||
|
constant uint &in_dim [[buffer(3)]],
|
||||||
|
constant uint &out_dim [[buffer(4)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= out_dim) return;
|
||||||
|
|
||||||
|
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
|
||||||
|
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
|
||||||
|
device const uint8_t *row = W + uint64_t(gid) * row_bytes;
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (uint b = 0; b < n_blocks; b++) {
|
||||||
|
device const uint8_t *block = row + b * Q4_BLOCK_BYTES;
|
||||||
|
half scale_h, zero_h;
|
||||||
|
scale_h = *reinterpret_cast<device const half*>(block);
|
||||||
|
zero_h = *reinterpret_cast<device const half*>(block + 2);
|
||||||
|
float scale = float(scale_h);
|
||||||
|
float zero = float(zero_h);
|
||||||
|
|
||||||
|
device const uint8_t *packed = block + 4;
|
||||||
|
uint base = b * Q4_BLOCK_SIZE;
|
||||||
|
|
||||||
|
for (uint i = 0; i < 16; i++) {
|
||||||
|
uint8_t byte = packed[i];
|
||||||
|
float w0 = float(byte & 0xF) * scale + zero;
|
||||||
|
float w1 = float(byte >> 4) * scale + zero;
|
||||||
|
sum += w0 * x[base + i * 2];
|
||||||
|
sum += w1 * x[base + i * 2 + 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
y[gid] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Q4 SIMD-optimized matrix-vector multiply ─────────────────────────
|
||||||
|
// MLX-style cooperative SIMD kernel: 2 SIMD groups per threadgroup,
|
||||||
|
// each SIMD group handles ROWS_PER_SIMD output rows cooperatively.
|
||||||
|
// 32 threads in a SIMD group split the K (input) dimension, then
|
||||||
|
// reduce via simd_sum(). No threadgroup memory needed.
|
||||||
|
//
|
||||||
|
// Threadgroup layout: 64 threads = 2 SIMD groups of 32
|
||||||
|
// Grid: (ceil(out_dim / ROWS_PER_TG), 1, 1) threadgroups
|
||||||
|
//
|
||||||
|
// Optional bias: if bias pointer is non-null (use_bias != 0),
|
||||||
|
// y[r] = dot(W[r], x) + bias[r]
|
||||||
|
#define ROWS_PER_SIMD 4
|
||||||
|
#define SIMD_GROUPS 2
|
||||||
|
#define ROWS_PER_TG (ROWS_PER_SIMD * SIMD_GROUPS)
|
||||||
|
|
||||||
|
kernel void sgemv_q4_fast(
|
||||||
|
device const uint8_t *W [[buffer(0)]],
|
||||||
|
device const float *x [[buffer(1)]],
|
||||||
|
device float *y [[buffer(2)]],
|
||||||
|
constant uint &in_dim [[buffer(3)]],
|
||||||
|
constant uint &out_dim [[buffer(4)]],
|
||||||
|
device const float *bias [[buffer(5)]],
|
||||||
|
constant uint &use_bias [[buffer(6)]],
|
||||||
|
uint tgid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]])
|
||||||
|
{
|
||||||
|
uint base_row = tgid * ROWS_PER_TG + simd_gid * ROWS_PER_SIMD;
|
||||||
|
if (base_row >= out_dim) return;
|
||||||
|
|
||||||
|
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
|
||||||
|
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
|
||||||
|
|
||||||
|
uint rows_this = min((uint)ROWS_PER_SIMD, out_dim - base_row);
|
||||||
|
|
||||||
|
float accum[ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
float zero_accum[ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
|
||||||
|
// Each of 32 SIMD lanes processes a stripe of blocks.
|
||||||
|
// Lane i processes blocks i, i+32, i+64, ...
|
||||||
|
for (uint b = simd_lid; b < n_blocks; b += 32) {
|
||||||
|
uint k_base = b * Q4_BLOCK_SIZE;
|
||||||
|
|
||||||
|
// Load input vector segment for this block (32 floats)
|
||||||
|
float xv[Q4_BLOCK_SIZE];
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
xv[j * 2] = x[k_base + j * 2];
|
||||||
|
xv[j * 2 + 1] = x[k_base + j * 2 + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
device const uint8_t *block =
|
||||||
|
W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
|
||||||
|
|
||||||
|
half scale_h = *reinterpret_cast<device const half*>(block);
|
||||||
|
half zero_h = *reinterpret_cast<device const half*>(block + 2);
|
||||||
|
float scale = float(scale_h);
|
||||||
|
float zero = float(zero_h);
|
||||||
|
|
||||||
|
device const uint8_t *packed = block + 4;
|
||||||
|
|
||||||
|
float dot = 0.0f;
|
||||||
|
float xsum = 0.0f;
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
uint8_t byte = packed[j];
|
||||||
|
float w0 = float(byte & 0xF);
|
||||||
|
float w1 = float(byte >> 4);
|
||||||
|
dot += w0 * xv[j * 2] + w1 * xv[j * 2 + 1];
|
||||||
|
xsum += xv[j * 2] + xv[j * 2 + 1];
|
||||||
|
}
|
||||||
|
accum[r] += dot * scale;
|
||||||
|
zero_accum[r] += xsum * zero;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIMD reduction across 32 lanes
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]);
|
||||||
|
if (simd_lid == 0) {
|
||||||
|
if (use_bias != 0) {
|
||||||
|
result += bias[base_row + r];
|
||||||
|
}
|
||||||
|
y[base_row + r] = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Fused Gate+Up+SiLU: reads x once, computes gate=silu(Wg*x)*Wu*x ──
|
||||||
|
// Combines two Q4 matvecs + silu_mul into one kernel.
|
||||||
|
// W_gate and W_up have the same dimensions [out_dim, in_dim].
|
||||||
|
// Output: gate[r] = silu(dot(W_gate[r], x)) * dot(W_up[r], x)
|
||||||
|
#define FUSED_ROWS_PER_SIMD 2
|
||||||
|
#define FUSED_SIMD_GROUPS 2
|
||||||
|
#define FUSED_ROWS_PER_TG (FUSED_ROWS_PER_SIMD * FUSED_SIMD_GROUPS)
|
||||||
|
|
||||||
|
kernel void sgemv_q4_fused_ffn(
|
||||||
|
device const uint8_t *W_gate [[buffer(0)]],
|
||||||
|
device const uint8_t *W_up [[buffer(1)]],
|
||||||
|
device const float *x [[buffer(2)]],
|
||||||
|
device float *out [[buffer(3)]],
|
||||||
|
constant uint &in_dim [[buffer(4)]],
|
||||||
|
constant uint &out_dim [[buffer(5)]],
|
||||||
|
uint tgid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]])
|
||||||
|
{
|
||||||
|
uint base_row = tgid * FUSED_ROWS_PER_TG + simd_gid * FUSED_ROWS_PER_SIMD;
|
||||||
|
if (base_row >= out_dim) return;
|
||||||
|
|
||||||
|
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
|
||||||
|
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
|
||||||
|
|
||||||
|
uint rows_this = min((uint)FUSED_ROWS_PER_SIMD, out_dim - base_row);
|
||||||
|
|
||||||
|
float gate_acc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
float gate_zacc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
float up_acc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
float up_zacc[FUSED_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
|
||||||
|
for (uint b = simd_lid; b < n_blocks; b += 32) {
|
||||||
|
uint k_base = b * Q4_BLOCK_SIZE;
|
||||||
|
|
||||||
|
float xv[Q4_BLOCK_SIZE];
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
xv[j * 2] = x[k_base + j * 2];
|
||||||
|
xv[j * 2 + 1] = x[k_base + j * 2 + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
|
||||||
|
|
||||||
|
// Gate weight block
|
||||||
|
device const uint8_t *g_block = W_gate + row_off;
|
||||||
|
float g_scale = float(*reinterpret_cast<device const half*>(g_block));
|
||||||
|
float g_zero = float(*reinterpret_cast<device const half*>(g_block + 2));
|
||||||
|
device const uint8_t *g_packed = g_block + 4;
|
||||||
|
|
||||||
|
// Up weight block
|
||||||
|
device const uint8_t *u_block = W_up + row_off;
|
||||||
|
float u_scale = float(*reinterpret_cast<device const half*>(u_block));
|
||||||
|
float u_zero = float(*reinterpret_cast<device const half*>(u_block + 2));
|
||||||
|
device const uint8_t *u_packed = u_block + 4;
|
||||||
|
|
||||||
|
float g_dot = 0.0f, g_xsum = 0.0f;
|
||||||
|
float u_dot = 0.0f, u_xsum = 0.0f;
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
float x0 = xv[j * 2];
|
||||||
|
float x1 = xv[j * 2 + 1];
|
||||||
|
float xs = x0 + x1;
|
||||||
|
|
||||||
|
uint8_t gb = g_packed[j];
|
||||||
|
g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1;
|
||||||
|
g_xsum += xs;
|
||||||
|
|
||||||
|
uint8_t ub = u_packed[j];
|
||||||
|
u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1;
|
||||||
|
u_xsum += xs;
|
||||||
|
}
|
||||||
|
gate_acc[r] += g_dot * g_scale;
|
||||||
|
gate_zacc[r] += g_xsum * g_zero;
|
||||||
|
up_acc[r] += u_dot * u_scale;
|
||||||
|
up_zacc[r] += u_xsum * u_zero;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]);
|
||||||
|
float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]);
|
||||||
|
if (simd_lid == 0) {
|
||||||
|
float s = g / (1.0f + exp(-g));
|
||||||
|
out[base_row + r] = s * u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Q4 batched matrix-matrix multiply (SGEMM) for prefill ────────────
|
||||||
|
// Y[t, r] = sum_k(dequant(W[r, k]) * X[t, k]) for t in [0, n_tokens), r in [0, out_dim)
|
||||||
|
// Grid: (ceil(out_dim / GEMM_TILE_M), n_tokens, 1)
|
||||||
|
// Each threadgroup: 2 SIMD groups, each handles GEMM_TILE_M/2 output rows for one token.
|
||||||
|
#define GEMM_TILE_M 8
|
||||||
|
#define GEMM_SIMD_GROUPS 2
|
||||||
|
#define GEMM_ROWS_PER_SIMD (GEMM_TILE_M / GEMM_SIMD_GROUPS)
|
||||||
|
|
||||||
|
kernel void sgemm_q4(
|
||||||
|
device const uint8_t *W [[buffer(0)]],
|
||||||
|
device const float *X [[buffer(1)]],
|
||||||
|
device float *Y [[buffer(2)]],
|
||||||
|
constant uint &in_dim [[buffer(3)]],
|
||||||
|
constant uint &out_dim [[buffer(4)]],
|
||||||
|
device const float *bias [[buffer(5)]],
|
||||||
|
constant uint &use_bias [[buffer(6)]],
|
||||||
|
constant uint &n_tokens [[buffer(7)]],
|
||||||
|
uint2 tgid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]])
|
||||||
|
{
|
||||||
|
uint base_row = tgid.x * GEMM_TILE_M + simd_gid * GEMM_ROWS_PER_SIMD;
|
||||||
|
uint t = tgid.y;
|
||||||
|
if (base_row >= out_dim || t >= n_tokens) return;
|
||||||
|
|
||||||
|
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
|
||||||
|
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
|
||||||
|
uint rows_this = min((uint)GEMM_ROWS_PER_SIMD, out_dim - base_row);
|
||||||
|
|
||||||
|
device const float *xt = X + uint64_t(t) * in_dim;
|
||||||
|
|
||||||
|
float accum[GEMM_ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
float zero_accum[GEMM_ROWS_PER_SIMD] = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||||
|
|
||||||
|
for (uint b = simd_lid; b < n_blocks; b += 32) {
|
||||||
|
uint k_base = b * Q4_BLOCK_SIZE;
|
||||||
|
|
||||||
|
float xv[Q4_BLOCK_SIZE];
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
xv[j * 2] = xt[k_base + j * 2];
|
||||||
|
xv[j * 2 + 1] = xt[k_base + j * 2 + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
device const uint8_t *block =
|
||||||
|
W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
|
||||||
|
|
||||||
|
float scale = float(*reinterpret_cast<device const half*>(block));
|
||||||
|
float zero = float(*reinterpret_cast<device const half*>(block + 2));
|
||||||
|
device const uint8_t *packed = block + 4;
|
||||||
|
|
||||||
|
float dot = 0.0f;
|
||||||
|
float xsum = 0.0f;
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
uint8_t byte = packed[j];
|
||||||
|
dot += float(byte & 0xF) * xv[j * 2] + float(byte >> 4) * xv[j * 2 + 1];
|
||||||
|
xsum += xv[j * 2] + xv[j * 2 + 1];
|
||||||
|
}
|
||||||
|
accum[r] += dot * scale;
|
||||||
|
zero_accum[r] += xsum * zero;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]);
|
||||||
|
if (simd_lid == 0) {
|
||||||
|
if (use_bias != 0)
|
||||||
|
result += bias[base_row + r];
|
||||||
|
Y[uint64_t(t) * out_dim + base_row + r] = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Q4 batched fused Gate+Up+SiLU (SGEMM variant) ───────────────────
|
||||||
|
// out[t, r] = silu(Wg[r] . X[t]) * Wu[r] . X[t] for all t and r
|
||||||
|
// Grid: (ceil(out_dim / GEMM_FFN_TILE_M), n_tokens, 1)
|
||||||
|
#define GEMM_FFN_TILE_M 4
|
||||||
|
#define GEMM_FFN_SIMD_GROUPS 2
|
||||||
|
#define GEMM_FFN_ROWS_PER_SIMD (GEMM_FFN_TILE_M / GEMM_FFN_SIMD_GROUPS)
|
||||||
|
|
||||||
|
kernel void sgemm_q4_fused_ffn(
|
||||||
|
device const uint8_t *W_gate [[buffer(0)]],
|
||||||
|
device const uint8_t *W_up [[buffer(1)]],
|
||||||
|
device const float *X [[buffer(2)]],
|
||||||
|
device float *out [[buffer(3)]],
|
||||||
|
constant uint &in_dim [[buffer(4)]],
|
||||||
|
constant uint &out_dim [[buffer(5)]],
|
||||||
|
constant uint &n_tokens [[buffer(6)]],
|
||||||
|
uint2 tgid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]])
|
||||||
|
{
|
||||||
|
uint base_row = tgid.x * GEMM_FFN_TILE_M + simd_gid * GEMM_FFN_ROWS_PER_SIMD;
|
||||||
|
uint t = tgid.y;
|
||||||
|
if (base_row >= out_dim || t >= n_tokens) return;
|
||||||
|
|
||||||
|
uint n_blocks = in_dim / Q4_BLOCK_SIZE;
|
||||||
|
uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
|
||||||
|
uint rows_this = min((uint)GEMM_FFN_ROWS_PER_SIMD, out_dim - base_row);
|
||||||
|
|
||||||
|
device const float *xt = X + uint64_t(t) * in_dim;
|
||||||
|
|
||||||
|
float gate_acc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
float gate_zacc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
float up_acc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
float up_zacc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f, 0.0f};
|
||||||
|
|
||||||
|
for (uint b = simd_lid; b < n_blocks; b += 32) {
|
||||||
|
uint k_base = b * Q4_BLOCK_SIZE;
|
||||||
|
|
||||||
|
float xv[Q4_BLOCK_SIZE];
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
xv[j * 2] = xt[k_base + j * 2];
|
||||||
|
xv[j * 2 + 1] = xt[k_base + j * 2 + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;
|
||||||
|
|
||||||
|
device const uint8_t *g_block = W_gate + row_off;
|
||||||
|
float g_scale = float(*reinterpret_cast<device const half*>(g_block));
|
||||||
|
float g_zero = float(*reinterpret_cast<device const half*>(g_block + 2));
|
||||||
|
device const uint8_t *g_packed = g_block + 4;
|
||||||
|
|
||||||
|
device const uint8_t *u_block = W_up + row_off;
|
||||||
|
float u_scale = float(*reinterpret_cast<device const half*>(u_block));
|
||||||
|
float u_zero = float(*reinterpret_cast<device const half*>(u_block + 2));
|
||||||
|
device const uint8_t *u_packed = u_block + 4;
|
||||||
|
|
||||||
|
float g_dot = 0.0f, g_xsum = 0.0f;
|
||||||
|
float u_dot = 0.0f, u_xsum = 0.0f;
|
||||||
|
for (uint j = 0; j < 16; j++) {
|
||||||
|
float x0 = xv[j * 2];
|
||||||
|
float x1 = xv[j * 2 + 1];
|
||||||
|
float xs = x0 + x1;
|
||||||
|
|
||||||
|
uint8_t gb = g_packed[j];
|
||||||
|
g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1;
|
||||||
|
g_xsum += xs;
|
||||||
|
|
||||||
|
uint8_t ub = u_packed[j];
|
||||||
|
u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1;
|
||||||
|
u_xsum += xs;
|
||||||
|
}
|
||||||
|
gate_acc[r] += g_dot * g_scale;
|
||||||
|
gate_zacc[r] += g_xsum * g_zero;
|
||||||
|
up_acc[r] += u_dot * u_scale;
|
||||||
|
up_zacc[r] += u_xsum * u_zero;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint r = 0; r < rows_this; r++) {
|
||||||
|
float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]);
|
||||||
|
float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]);
|
||||||
|
if (simd_lid == 0) {
|
||||||
|
float s = g / (1.0f + exp(-g));
|
||||||
|
out[uint64_t(t) * out_dim + base_row + r] = s * u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batched RMSNorm (N tokens) ──────────────────────────────────────
|
||||||
|
// x[t*dim .. (t+1)*dim-1] → out[t*dim .. (t+1)*dim-1]
|
||||||
|
// Grid: (n_tokens, 1, 1) threadgroups, each normalizes one token.
|
||||||
|
kernel void rms_norm_batched(
|
||||||
|
device const float *x [[buffer(0)]],
|
||||||
|
device const float *w [[buffer(1)]],
|
||||||
|
device float *out [[buffer(2)]],
|
||||||
|
constant uint &dim [[buffer(3)]],
|
||||||
|
constant float &eps [[buffer(4)]],
|
||||||
|
constant uint &n_tokens [[buffer(5)]],
|
||||||
|
uint tgid [[threadgroup_position_in_grid]],
|
||||||
|
uint tid [[thread_index_in_threadgroup]],
|
||||||
|
uint tpg [[threads_per_threadgroup]])
|
||||||
|
{
|
||||||
|
if (tgid >= n_tokens) return;
|
||||||
|
|
||||||
|
device const float *xi = x + uint64_t(tgid) * dim;
|
||||||
|
device float *oi = out + uint64_t(tgid) * dim;
|
||||||
|
|
||||||
|
threadgroup float partial[1024];
|
||||||
|
|
||||||
|
float local_sum = 0.0f;
|
||||||
|
for (uint i = tid; i < dim; i += tpg)
|
||||||
|
local_sum += xi[i] * xi[i];
|
||||||
|
partial[tid] = local_sum;
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for (uint s = tpg / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) partial[tid] += partial[tid + s];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
float rms_inv = rsqrt(partial[0] / float(dim) + eps);
|
||||||
|
|
||||||
|
for (uint i = tid; i < dim; i += tpg)
|
||||||
|
oi[i] = xi[i] * rms_inv * w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batched embedding lookup (N tokens) ─────────────────────────────
|
||||||
|
// Grid: (dim, n_tokens, 1). Each thread copies one element.
|
||||||
|
kernel void embed_lookup_batched(
|
||||||
|
device const float *embed [[buffer(0)]],
|
||||||
|
device float *out [[buffer(1)]],
|
||||||
|
device const uint *token_ids [[buffer(2)]],
|
||||||
|
constant uint &dim [[buffer(3)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint i = gid.x;
|
||||||
|
uint t = gid.y;
|
||||||
|
if (i >= dim) return;
|
||||||
|
out[uint64_t(t) * dim + i] = embed[uint64_t(token_ids[t]) * dim + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batched RoPE (N tokens) ─────────────────────────────────────────
|
||||||
|
// Applies RoPE to Q[t] and K[t] for each token t at position base_pos+t.
|
||||||
|
// Grid: total_pairs per token * n_tokens
|
||||||
|
kernel void rope_apply_batched(
|
||||||
|
device float *q [[buffer(0)]],
|
||||||
|
device float *k [[buffer(1)]],
|
||||||
|
device const float *cos_tbl [[buffer(2)]],
|
||||||
|
device const float *sin_tbl [[buffer(3)]],
|
||||||
|
constant uint &n_q_heads [[buffer(4)]],
|
||||||
|
constant uint &n_kv_heads [[buffer(5)]],
|
||||||
|
constant uint &head_dim [[buffer(6)]],
|
||||||
|
constant uint &base_pos [[buffer(7)]],
|
||||||
|
constant uint &q_stride [[buffer(8)]],
|
||||||
|
constant uint &k_stride [[buffer(9)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint pair_idx = gid.x;
|
||||||
|
uint t = gid.y;
|
||||||
|
|
||||||
|
uint half_dim = head_dim / 2;
|
||||||
|
uint total_pairs = (n_q_heads + n_kv_heads) * half_dim;
|
||||||
|
if (pair_idx >= total_pairs) return;
|
||||||
|
|
||||||
|
uint head_pair = pair_idx / half_dim;
|
||||||
|
uint i = pair_idx % half_dim;
|
||||||
|
uint pos = base_pos + t;
|
||||||
|
|
||||||
|
device float *vec;
|
||||||
|
if (head_pair < n_q_heads)
|
||||||
|
vec = q + uint64_t(t) * q_stride + head_pair * head_dim;
|
||||||
|
else
|
||||||
|
vec = k + uint64_t(t) * k_stride + (head_pair - n_q_heads) * head_dim;
|
||||||
|
|
||||||
|
uint cos_off = pos * half_dim;
|
||||||
|
float f = vec[i];
|
||||||
|
float s = vec[i + half_dim];
|
||||||
|
float c = cos_tbl[cos_off + i];
|
||||||
|
float sv = sin_tbl[cos_off + i];
|
||||||
|
vec[i] = f * c - s * sv;
|
||||||
|
vec[i + half_dim] = s * c + f * sv;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batched vec_add: out[i] = a[i] + b[i] for N*dim elements ────────
|
||||||
|
kernel void vec_add_batched(
|
||||||
|
device const float *a [[buffer(0)]],
|
||||||
|
device const float *b [[buffer(1)]],
|
||||||
|
device float *out [[buffer(2)]],
|
||||||
|
constant uint &total_n [[buffer(3)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= total_n) return;
|
||||||
|
out[gid] = a[gid] + b[gid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── F16 matrix-vector multiply ───────────────────────────────────────
|
||||||
|
kernel void sgemv_f16(
|
||||||
|
device const half *W [[buffer(0)]],
|
||||||
|
device const float *x [[buffer(1)]],
|
||||||
|
device float *y [[buffer(2)]],
|
||||||
|
constant uint &in_dim [[buffer(3)]],
|
||||||
|
constant uint &out_dim [[buffer(4)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= out_dim) return;
|
||||||
|
|
||||||
|
device const half *row = W + uint64_t(gid) * in_dim;
|
||||||
|
float sum = 0.0f;
|
||||||
|
|
||||||
|
uint i = 0;
|
||||||
|
for (; i + 7 < in_dim; i += 8) {
|
||||||
|
sum += float(row[i]) * x[i];
|
||||||
|
sum += float(row[i + 1]) * x[i + 1];
|
||||||
|
sum += float(row[i + 2]) * x[i + 2];
|
||||||
|
sum += float(row[i + 3]) * x[i + 3];
|
||||||
|
sum += float(row[i + 4]) * x[i + 4];
|
||||||
|
sum += float(row[i + 5]) * x[i + 5];
|
||||||
|
sum += float(row[i + 6]) * x[i + 6];
|
||||||
|
sum += float(row[i + 7]) * x[i + 7];
|
||||||
|
}
|
||||||
|
for (; i < in_dim; i++)
|
||||||
|
sum += float(row[i]) * x[i];
|
||||||
|
|
||||||
|
y[gid] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── F32 matrix-vector multiply ───────────────────────────────────────
|
||||||
|
kernel void sgemv_f32(
|
||||||
|
device const float *W [[buffer(0)]],
|
||||||
|
device const float *x [[buffer(1)]],
|
||||||
|
device float *y [[buffer(2)]],
|
||||||
|
constant uint &in_dim [[buffer(3)]],
|
||||||
|
constant uint &out_dim [[buffer(4)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= out_dim) return;
|
||||||
|
|
||||||
|
device const float *row = W + uint64_t(gid) * in_dim;
|
||||||
|
float sum = 0.0f;
|
||||||
|
|
||||||
|
uint i = 0;
|
||||||
|
for (; i + 7 < in_dim; i += 8) {
|
||||||
|
sum += row[i] * x[i];
|
||||||
|
sum += row[i + 1] * x[i + 1];
|
||||||
|
sum += row[i + 2] * x[i + 2];
|
||||||
|
sum += row[i + 3] * x[i + 3];
|
||||||
|
sum += row[i + 4] * x[i + 4];
|
||||||
|
sum += row[i + 5] * x[i + 5];
|
||||||
|
sum += row[i + 6] * x[i + 6];
|
||||||
|
sum += row[i + 7] * x[i + 7];
|
||||||
|
}
|
||||||
|
for (; i < in_dim; i++)
|
||||||
|
sum += row[i] * x[i];
|
||||||
|
|
||||||
|
y[gid] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── RMS Normalization ────────────────────────────────────────────────
|
||||||
|
// out[i] = x[i] * w[i] / sqrt(mean(x^2) + eps)
|
||||||
|
// Two-pass: first compute sum of squares (reduction), then normalize.
|
||||||
|
// Single threadgroup processes the entire vector.
|
||||||
|
kernel void rms_norm(
|
||||||
|
device const float *x [[buffer(0)]],
|
||||||
|
device const float *w [[buffer(1)]],
|
||||||
|
device float *out [[buffer(2)]],
|
||||||
|
constant uint &dim [[buffer(3)]],
|
||||||
|
constant float &eps [[buffer(4)]],
|
||||||
|
uint tid [[thread_index_in_threadgroup]],
|
||||||
|
uint tpg [[threads_per_threadgroup]])
|
||||||
|
{
|
||||||
|
threadgroup float partial[1024];
|
||||||
|
|
||||||
|
float local_sum = 0.0f;
|
||||||
|
for (uint i = tid; i < dim; i += tpg)
|
||||||
|
local_sum += x[i] * x[i];
|
||||||
|
partial[tid] = local_sum;
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
// Tree reduction
|
||||||
|
for (uint s = tpg / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) partial[tid] += partial[tid + s];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
float rms_inv = rsqrt(partial[0] / float(dim) + eps);
|
||||||
|
|
||||||
|
for (uint i = tid; i < dim; i += tpg)
|
||||||
|
out[i] = x[i] * rms_inv * w[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── RoPE (Rotary Position Embedding) ─────────────────────────────────
|
||||||
|
// Applies RoPE to Q and K vectors in-place.
|
||||||
|
// cos_sin is precomputed: [half_dim] cos values followed by [half_dim] sin values.
|
||||||
|
kernel void rope_apply(
|
||||||
|
device float *q [[buffer(0)]],
|
||||||
|
device float *k [[buffer(1)]],
|
||||||
|
device const float *cos_v [[buffer(2)]],
|
||||||
|
device const float *sin_v [[buffer(3)]],
|
||||||
|
constant uint &n_q_heads [[buffer(4)]],
|
||||||
|
constant uint &n_kv_heads [[buffer(5)]],
|
||||||
|
constant uint &head_dim [[buffer(6)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint half_dim = head_dim / 2;
|
||||||
|
uint total_pairs = (n_q_heads + n_kv_heads) * half_dim;
|
||||||
|
if (gid >= total_pairs) return;
|
||||||
|
|
||||||
|
uint head_pair = gid / half_dim;
|
||||||
|
uint i = gid % half_dim;
|
||||||
|
|
||||||
|
device float *vec;
|
||||||
|
if (head_pair < n_q_heads) {
|
||||||
|
vec = q + head_pair * head_dim;
|
||||||
|
} else {
|
||||||
|
vec = k + (head_pair - n_q_heads) * head_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
float f = vec[i];
|
||||||
|
float s = vec[i + half_dim];
|
||||||
|
float c = cos_v[i];
|
||||||
|
float sv = sin_v[i];
|
||||||
|
vec[i] = f * c - s * sv;
|
||||||
|
vec[i + half_dim] = s * c + f * sv;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── SiLU activation + element-wise multiply ──────────────────────────
|
||||||
|
// gate[i] = silu(gate[i]) * up[i]
|
||||||
|
// silu(x) = x / (1 + exp(-x))
|
||||||
|
kernel void silu_mul(
|
||||||
|
device float *gate [[buffer(0)]],
|
||||||
|
device const float *up [[buffer(1)]],
|
||||||
|
constant uint &n [[buffer(2)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= n) return;
|
||||||
|
float x = gate[gid];
|
||||||
|
float s = x / (1.0f + exp(-x));
|
||||||
|
gate[gid] = s * up[gid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Vector add (residual connection) ─────────────────────────────────
|
||||||
|
// out[i] = a[i] + b[i]
|
||||||
|
kernel void vec_add(
|
||||||
|
device const float *a [[buffer(0)]],
|
||||||
|
device const float *b [[buffer(1)]],
|
||||||
|
device float *out [[buffer(2)]],
|
||||||
|
constant uint &n [[buffer(3)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= n) return;
|
||||||
|
out[gid] = a[gid] + b[gid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Bias add ─────────────────────────────────────────────────────────
|
||||||
|
// x[i] += bias[i]
|
||||||
|
kernel void bias_add(
|
||||||
|
device float *x [[buffer(0)]],
|
||||||
|
device const float *bias [[buffer(1)]],
|
||||||
|
constant uint &n [[buffer(2)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= n) return;
|
||||||
|
x[gid] += bias[gid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Embedding lookup ─────────────────────────────────────────────────
|
||||||
|
// out[i] = embed[token_id * dim + i]
|
||||||
|
kernel void embed_lookup(
|
||||||
|
device const float *embed [[buffer(0)]],
|
||||||
|
device float *out [[buffer(1)]],
|
||||||
|
constant uint &token_id [[buffer(2)]],
|
||||||
|
constant uint &dim [[buffer(3)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= dim) return;
|
||||||
|
out[gid] = embed[uint64_t(token_id) * dim + gid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Attention score: Q @ K^T for one head (legacy) ──────────────────
|
||||||
|
kernel void attn_score(
|
||||||
|
device const float *qh [[buffer(0)]],
|
||||||
|
device const float *kv_cache_k [[buffer(1)]],
|
||||||
|
device float *att [[buffer(2)]],
|
||||||
|
constant uint &head_dim [[buffer(3)]],
|
||||||
|
constant uint &kv_dim [[buffer(4)]],
|
||||||
|
constant uint &kv_head_offset [[buffer(5)]],
|
||||||
|
constant float &scale [[buffer(6)]],
|
||||||
|
constant uint &seq_len [[buffer(7)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= seq_len) return;
|
||||||
|
|
||||||
|
device const float *kt = kv_cache_k + uint64_t(gid) * kv_dim + kv_head_offset;
|
||||||
|
float dot = 0.0f;
|
||||||
|
for (uint i = 0; i < head_dim; i++)
|
||||||
|
dot += qh[i] * kt[i];
|
||||||
|
att[gid] = dot * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batched attention score: all Q heads in one dispatch ─────────────
|
||||||
|
// Grid: (seq_len, n_q_heads, 1). Each thread computes one score for one head.
|
||||||
|
// GQA: maps Q head h to KV head h/gqa_factor.
|
||||||
|
kernel void attn_score_batched(
|
||||||
|
device const float *q [[buffer(0)]],
|
||||||
|
device const float *kv_cache_k [[buffer(1)]],
|
||||||
|
device float *att [[buffer(2)]],
|
||||||
|
constant uint &head_dim [[buffer(3)]],
|
||||||
|
constant uint &kv_dim [[buffer(4)]],
|
||||||
|
constant uint &n_q_heads [[buffer(5)]],
|
||||||
|
constant uint &gqa_factor [[buffer(6)]],
|
||||||
|
constant float &scale [[buffer(7)]],
|
||||||
|
constant uint &seq_len [[buffer(8)]],
|
||||||
|
constant uint &max_seq [[buffer(9)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint t = gid.x;
|
||||||
|
uint h = gid.y;
|
||||||
|
if (t >= seq_len || h >= n_q_heads) return;
|
||||||
|
|
||||||
|
uint kv_h = h / gqa_factor;
|
||||||
|
device const float *qh = q + h * head_dim;
|
||||||
|
device const float *kt = kv_cache_k + uint64_t(t) * kv_dim + kv_h * head_dim;
|
||||||
|
|
||||||
|
float dot = 0.0f;
|
||||||
|
for (uint i = 0; i < head_dim; i++)
|
||||||
|
dot += qh[i] * kt[i];
|
||||||
|
|
||||||
|
att[h * max_seq + t] = dot * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batched softmax: all heads in one dispatch ───────────────────────
|
||||||
|
// One threadgroup per head. tid reduces over seq_len dimension.
|
||||||
|
kernel void softmax_batched(
|
||||||
|
device float *att [[buffer(0)]],
|
||||||
|
constant uint &seq_len [[buffer(1)]],
|
||||||
|
constant uint &max_seq [[buffer(2)]],
|
||||||
|
constant uint &n_q_heads [[buffer(3)]],
|
||||||
|
uint tgid [[threadgroup_position_in_grid]],
|
||||||
|
uint tid [[thread_index_in_threadgroup]],
|
||||||
|
uint tpg [[threads_per_threadgroup]])
|
||||||
|
{
|
||||||
|
uint h = tgid;
|
||||||
|
if (h >= n_q_heads) return;
|
||||||
|
|
||||||
|
device float *head_att = att + h * max_seq;
|
||||||
|
threadgroup float shared[1024];
|
||||||
|
|
||||||
|
float local_max = -1e30f;
|
||||||
|
for (uint i = tid; i < seq_len; i += tpg)
|
||||||
|
local_max = max(local_max, head_att[i]);
|
||||||
|
shared[tid] = local_max;
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for (uint s = tpg / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
float max_val = shared[0];
|
||||||
|
|
||||||
|
float local_sum = 0.0f;
|
||||||
|
for (uint i = tid; i < seq_len; i += tpg) {
|
||||||
|
float e = exp(head_att[i] - max_val);
|
||||||
|
head_att[i] = e;
|
||||||
|
local_sum += e;
|
||||||
|
}
|
||||||
|
shared[tid] = local_sum;
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for (uint s = tpg / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) shared[tid] += shared[tid + s];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
float inv_sum = 1.0f / shared[0];
|
||||||
|
|
||||||
|
for (uint i = tid; i < seq_len; i += tpg)
|
||||||
|
head_att[i] *= inv_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Batched attention weighted sum: all heads in one dispatch ────────
|
||||||
|
// Grid: (head_dim, n_q_heads, 1). Each thread computes one output dim for one head.
|
||||||
|
kernel void attn_wsum_batched(
|
||||||
|
device const float *att [[buffer(0)]],
|
||||||
|
device const float *kv_cache_v [[buffer(1)]],
|
||||||
|
device float *out [[buffer(2)]],
|
||||||
|
constant uint &head_dim [[buffer(3)]],
|
||||||
|
constant uint &kv_dim [[buffer(4)]],
|
||||||
|
constant uint &n_q_heads [[buffer(5)]],
|
||||||
|
constant uint &gqa_factor [[buffer(6)]],
|
||||||
|
constant uint &seq_len [[buffer(7)]],
|
||||||
|
constant uint &max_seq [[buffer(8)]],
|
||||||
|
uint2 gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint d = gid.x;
|
||||||
|
uint h = gid.y;
|
||||||
|
if (d >= head_dim || h >= n_q_heads) return;
|
||||||
|
|
||||||
|
uint kv_h = h / gqa_factor;
|
||||||
|
device const float *head_att = att + h * max_seq;
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (uint t = 0; t < seq_len; t++) {
|
||||||
|
float a = head_att[t];
|
||||||
|
float v = kv_cache_v[uint64_t(t) * kv_dim + kv_h * head_dim + d];
|
||||||
|
sum += a * v;
|
||||||
|
}
|
||||||
|
out[h * head_dim + d] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Softmax (legacy, single head) ───────────────────────────────────
|
||||||
|
kernel void softmax_inplace(
|
||||||
|
device float *att [[buffer(0)]],
|
||||||
|
constant uint &seq_len [[buffer(1)]],
|
||||||
|
uint tid [[thread_index_in_threadgroup]],
|
||||||
|
uint tpg [[threads_per_threadgroup]])
|
||||||
|
{
|
||||||
|
threadgroup float shared[1024];
|
||||||
|
|
||||||
|
float local_max = -1e30f;
|
||||||
|
for (uint i = tid; i < seq_len; i += tpg)
|
||||||
|
local_max = max(local_max, att[i]);
|
||||||
|
shared[tid] = local_max;
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for (uint s = tpg / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) shared[tid] = max(shared[tid], shared[tid + s]);
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
float max_val = shared[0];
|
||||||
|
|
||||||
|
float local_sum = 0.0f;
|
||||||
|
for (uint i = tid; i < seq_len; i += tpg) {
|
||||||
|
float e = exp(att[i] - max_val);
|
||||||
|
att[i] = e;
|
||||||
|
local_sum += e;
|
||||||
|
}
|
||||||
|
shared[tid] = local_sum;
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
for (uint s = tpg / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) shared[tid] += shared[tid + s];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
float inv_sum = 1.0f / shared[0];
|
||||||
|
|
||||||
|
for (uint i = tid; i < seq_len; i += tpg)
|
||||||
|
att[i] *= inv_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Attention weighted sum (legacy, single head) ─────────────────────
|
||||||
|
kernel void attn_weighted_sum(
|
||||||
|
device const float *att [[buffer(0)]],
|
||||||
|
device const float *kv_cache_v [[buffer(1)]],
|
||||||
|
device float *out [[buffer(2)]],
|
||||||
|
constant uint &head_dim [[buffer(3)]],
|
||||||
|
constant uint &kv_dim [[buffer(4)]],
|
||||||
|
constant uint &kv_head_offset [[buffer(5)]],
|
||||||
|
constant uint &seq_len [[buffer(6)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= head_dim) return;
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (uint t = 0; t < seq_len; t++) {
|
||||||
|
float a = att[t];
|
||||||
|
float v = kv_cache_v[uint64_t(t) * kv_dim + kv_head_offset + gid];
|
||||||
|
sum += a * v;
|
||||||
|
}
|
||||||
|
out[gid] = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Argmax ───────────────────────────────────────────────────────────
|
||||||
|
// Finds argmax of logits[0..n-1], writes to result[0].
|
||||||
|
// Single threadgroup.
|
||||||
|
kernel void argmax_kernel(
|
||||||
|
device const float *logits [[buffer(0)]],
|
||||||
|
device int *result [[buffer(1)]],
|
||||||
|
constant uint &n [[buffer(2)]],
|
||||||
|
uint tid [[thread_index_in_threadgroup]],
|
||||||
|
uint tpg [[threads_per_threadgroup]])
|
||||||
|
{
|
||||||
|
threadgroup float shared_val[1024];
|
||||||
|
threadgroup int shared_idx[1024];
|
||||||
|
|
||||||
|
float local_max = -1e30f;
|
||||||
|
int local_idx = 0;
|
||||||
|
for (uint i = tid; i < n; i += tpg) {
|
||||||
|
if (logits[i] > local_max) {
|
||||||
|
local_max = logits[i];
|
||||||
|
local_idx = int(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
shared_val[tid] = local_max;
|
||||||
|
shared_idx[tid] = local_idx;
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
for (uint s = tpg / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s && shared_val[tid + s] > shared_val[tid]) {
|
||||||
|
shared_val[tid] = shared_val[tid + s];
|
||||||
|
shared_idx[tid] = shared_idx[tid + s];
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid == 0) result[0] = shared_idx[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Copy kernel ──────────────────────────────────────────────────────
|
||||||
|
// dst[i] = src[i] for i in [0, n)
|
||||||
|
kernel void vec_copy(
|
||||||
|
device const float *src [[buffer(0)]],
|
||||||
|
device float *dst [[buffer(1)]],
|
||||||
|
constant uint &n [[buffer(2)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= n) return;
|
||||||
|
dst[gid] = src[gid];
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Zero-fill ────────────────────────────────────────────────────────
|
||||||
|
kernel void vec_zero(
|
||||||
|
device float *dst [[buffer(0)]],
|
||||||
|
constant uint &n [[buffer(1)]],
|
||||||
|
uint gid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
if (gid >= n) return;
|
||||||
|
dst[gid] = 0.0f;
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -7,7 +7,8 @@ MODEL_DIR="$HOME/models/Qwen2.5-0.5B-Instruct"
|
||||||
WEIGHTS_BIN="$SCRIPT_DIR/qwen05b.bin"
|
WEIGHTS_BIN="$SCRIPT_DIR/qwen05b.bin"
|
||||||
BINARY="$SCRIPT_DIR/qwen_ane"
|
BINARY="$SCRIPT_DIR/qwen_ane"
|
||||||
VENV_DIR="$SCRIPT_DIR/.venv"
|
VENV_DIR="$SCRIPT_DIR/.venv"
|
||||||
EXPECTED_WEIGHT_SIZE=1976131100
|
EXPECTED_WEIGHT_SIZE_F32=1976131100
|
||||||
|
EXPECTED_WEIGHT_SIZE_F16=988082236
|
||||||
|
|
||||||
info() { printf "\033[1;34m==> %s\033[0m\n" "$1"; }
|
info() { printf "\033[1;34m==> %s\033[0m\n" "$1"; }
|
||||||
ok() { printf "\033[1;32m ✓ %s\033[0m\n" "$1"; }
|
ok() { printf "\033[1;32m ✓ %s\033[0m\n" "$1"; }
|
||||||
|
|
@ -86,16 +87,16 @@ info "Converting weights to binary format..."
|
||||||
|
|
||||||
if [ -f "$WEIGHTS_BIN" ]; then
|
if [ -f "$WEIGHTS_BIN" ]; then
|
||||||
ACTUAL_SIZE=$(stat -f%z "$WEIGHTS_BIN" 2>/dev/null || stat -c%s "$WEIGHTS_BIN" 2>/dev/null)
|
ACTUAL_SIZE=$(stat -f%z "$WEIGHTS_BIN" 2>/dev/null || stat -c%s "$WEIGHTS_BIN" 2>/dev/null)
|
||||||
if [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE" ]; then
|
if [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE_F16" ] || [ "$ACTUAL_SIZE" -eq "$EXPECTED_WEIGHT_SIZE_F32" ]; then
|
||||||
ok "Weights already converted ($((ACTUAL_SIZE / 1024 / 1024)) MB)"
|
ok "Weights already converted ($((ACTUAL_SIZE / 1024 / 1024)) MB)"
|
||||||
else
|
else
|
||||||
warn "Weight file exists but wrong size ($ACTUAL_SIZE vs $EXPECTED_WEIGHT_SIZE), reconverting"
|
warn "Weight file exists but unexpected size ($ACTUAL_SIZE), reconverting as F16"
|
||||||
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN"
|
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" --f16
|
||||||
ok "Weights converted"
|
ok "Weights converted (F16)"
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN"
|
python3 "$SCRIPT_DIR/convert_weights.py" "$MODEL_DIR" "$WEIGHTS_BIN" --f16
|
||||||
ok "Weights converted"
|
ok "Weights converted (F16)"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# --- Step 5: Build binary ---
|
# --- Step 5: Build binary ---
|
||||||
|
|
@ -113,8 +114,10 @@ elif [ "$SCRIPT_DIR/main.m" -nt "$BINARY" ] || \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ "$NEEDS_BUILD" -eq 1 ]; then
|
if [ "$NEEDS_BUILD" -eq 1 ]; then
|
||||||
xcrun clang -O2 -framework Foundation -framework IOSurface \
|
xcrun clang -O3 -ffast-math -mcpu=apple-m4 -flto \
|
||||||
-framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
|
-framework Foundation -framework IOSurface \
|
||||||
|
-framework CoreML -framework Accelerate -framework Metal \
|
||||||
|
-ldl -lobjc -fobjc-arc \
|
||||||
-o "$BINARY" "$SCRIPT_DIR/main.m"
|
-o "$BINARY" "$SCRIPT_DIR/main.m"
|
||||||
ok "Binary built: $BINARY"
|
ok "Binary built: $BINARY"
|
||||||
else
|
else
|
||||||
|
|
|
||||||
|
|
@ -232,6 +232,124 @@ static NSData *mil_build_ffn_up_weight_blob(const float *w1, const float *w3, in
|
||||||
return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
|
return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate MIL for fused GQA QKV: Q, K, V have different output dimensions
|
||||||
|
// Qwen2.5-0.5B: Q=[q_dim, dim], K=[kv_dim, dim], V=[kv_dim, dim]
|
||||||
|
// Weight blob: Wq[q_dim,dim] @ chunk0, Wk[kv_dim,dim] @ chunk1, Wv[kv_dim,dim] @ chunk2
|
||||||
|
static NSString *mil_gen_qkv_gqa(int dim, int q_dim, int kv_dim, int spatial) {
|
||||||
|
NSUInteger cs_q = 64 + (NSUInteger)q_dim * dim * 2;
|
||||||
|
NSUInteger cs_kv = 64 + (NSUInteger)kv_dim * dim * 2;
|
||||||
|
NSUInteger off_k = 64 + cs_q;
|
||||||
|
NSUInteger off_v = off_k + cs_kv;
|
||||||
|
if (g_fp16_io) {
|
||||||
|
return [NSString stringWithFormat:
|
||||||
|
@"program(1.0)\n"
|
||||||
|
"[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{\"coremlc-version\", \"3505.4.1\"}})]\n"
|
||||||
|
"{\n"
|
||||||
|
" func main<ios16>(tensor<fp16, [1, %d, 1, %d]> x) {\n"
|
||||||
|
" tensor<string, []> c_pad_type = const()[name = tensor<string, []>(\"c_pad_type\"), val = tensor<string, []>(\"valid\")];\n"
|
||||||
|
" tensor<int32, [2]> c_strides = const()[name = tensor<string, []>(\"c_strides\"), val = tensor<int32, [2]>([1, 1])];\n"
|
||||||
|
" tensor<int32, [4]> c_pad = const()[name = tensor<string, []>(\"c_pad\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
|
||||||
|
" tensor<int32, [2]> c_dilations = const()[name = tensor<string, []>(\"c_dilations\"), val = tensor<int32, [2]>([1, 1])];\n"
|
||||||
|
" tensor<int32, []> c_groups = const()[name = tensor<string, []>(\"c_groups\"), val = tensor<int32, []>(1)];\n"
|
||||||
|
" tensor<fp16, [%d, %d, 1, 1]> Wq = const()[name = tensor<string, []>(\"Wq\"), "
|
||||||
|
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(64)))];\n"
|
||||||
|
" tensor<fp16, [%d, %d, 1, 1]> Wk = const()[name = tensor<string, []>(\"Wk\"), "
|
||||||
|
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
|
||||||
|
" tensor<fp16, [%d, %d, 1, 1]> Wv = const()[name = tensor<string, []>(\"Wv\"), "
|
||||||
|
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
|
||||||
|
" tensor<fp16, [1, %d, 1, %d]> q = conv(dilations = c_dilations, groups = c_groups, "
|
||||||
|
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wq, x = x)[name = tensor<string, []>(\"conv_q\")];\n"
|
||||||
|
" tensor<fp16, [1, %d, 1, %d]> k = conv(dilations = c_dilations, groups = c_groups, "
|
||||||
|
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wk, x = x)[name = tensor<string, []>(\"conv_k\")];\n"
|
||||||
|
" tensor<fp16, [1, %d, 1, %d]> v = conv(dilations = c_dilations, groups = c_groups, "
|
||||||
|
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wv, x = x)[name = tensor<string, []>(\"conv_v\")];\n"
|
||||||
|
" } -> (q, k, v);\n"
|
||||||
|
"}\n",
|
||||||
|
dim, spatial,
|
||||||
|
q_dim, dim, q_dim, dim,
|
||||||
|
kv_dim, dim, kv_dim, dim, (unsigned long)off_k,
|
||||||
|
kv_dim, dim, kv_dim, dim, (unsigned long)off_v,
|
||||||
|
q_dim, spatial, kv_dim, spatial, kv_dim, spatial];
|
||||||
|
}
|
||||||
|
return [NSString stringWithFormat:
|
||||||
|
@"program(1.0)\n"
|
||||||
|
"[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{\"coremlc-version\", \"3505.4.1\"}})]\n"
|
||||||
|
"{\n"
|
||||||
|
" func main<ios16>(tensor<fp32, [1, %d, 1, %d]> x) {\n"
|
||||||
|
" tensor<string, []> c_pad_type = const()[name = tensor<string, []>(\"c_pad_type\"), val = tensor<string, []>(\"valid\")];\n"
|
||||||
|
" tensor<int32, [2]> c_strides = const()[name = tensor<string, []>(\"c_strides\"), val = tensor<int32, [2]>([1, 1])];\n"
|
||||||
|
" tensor<int32, [4]> c_pad = const()[name = tensor<string, []>(\"c_pad\"), val = tensor<int32, [4]>([0, 0, 0, 0])];\n"
|
||||||
|
" tensor<int32, [2]> c_dilations = const()[name = tensor<string, []>(\"c_dilations\"), val = tensor<int32, [2]>([1, 1])];\n"
|
||||||
|
" tensor<int32, []> c_groups = const()[name = tensor<string, []>(\"c_groups\"), val = tensor<int32, []>(1)];\n"
|
||||||
|
" tensor<string, []> to_fp16 = const()[name = tensor<string, []>(\"to_fp16\"), val = tensor<string, []>(\"fp16\")];\n"
|
||||||
|
" tensor<fp16, [1, %d, 1, %d]> x16 = cast(dtype = to_fp16, x = x)[name = tensor<string, []>(\"cast_in\")];\n"
|
||||||
|
" tensor<fp16, [%d, %d, 1, 1]> Wq = const()[name = tensor<string, []>(\"Wq\"), "
|
||||||
|
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(64)))];\n"
|
||||||
|
" tensor<fp16, [%d, %d, 1, 1]> Wk = const()[name = tensor<string, []>(\"Wk\"), "
|
||||||
|
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
|
||||||
|
" tensor<fp16, [%d, %d, 1, 1]> Wv = const()[name = tensor<string, []>(\"Wv\"), "
|
||||||
|
"val = tensor<fp16, [%d, %d, 1, 1]>(BLOBFILE(path = tensor<string, []>(\"@model_path/weights/weight.bin\"), offset = tensor<uint64, []>(%lu)))];\n"
|
||||||
|
" tensor<fp16, [1, %d, 1, %d]> q16 = conv(dilations = c_dilations, groups = c_groups, "
|
||||||
|
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wq, x = x16)[name = tensor<string, []>(\"conv_q\")];\n"
|
||||||
|
" tensor<fp16, [1, %d, 1, %d]> k16 = conv(dilations = c_dilations, groups = c_groups, "
|
||||||
|
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wk, x = x16)[name = tensor<string, []>(\"conv_k\")];\n"
|
||||||
|
" tensor<fp16, [1, %d, 1, %d]> v16 = conv(dilations = c_dilations, groups = c_groups, "
|
||||||
|
"pad = c_pad, pad_type = c_pad_type, strides = c_strides, weight = Wv, x = x16)[name = tensor<string, []>(\"conv_v\")];\n"
|
||||||
|
" tensor<string, []> to_fp32 = const()[name = tensor<string, []>(\"to_fp32\"), val = tensor<string, []>(\"fp32\")];\n"
|
||||||
|
" tensor<fp32, [1, %d, 1, %d]> q = cast(dtype = to_fp32, x = q16)[name = tensor<string, []>(\"cast_q\")];\n"
|
||||||
|
" tensor<fp32, [1, %d, 1, %d]> k = cast(dtype = to_fp32, x = k16)[name = tensor<string, []>(\"cast_k\")];\n"
|
||||||
|
" tensor<fp32, [1, %d, 1, %d]> v = cast(dtype = to_fp32, x = v16)[name = tensor<string, []>(\"cast_v\")];\n"
|
||||||
|
" } -> (q, k, v);\n"
|
||||||
|
"}\n",
|
||||||
|
dim, spatial, dim, spatial,
|
||||||
|
q_dim, dim, q_dim, dim,
|
||||||
|
kv_dim, dim, kv_dim, dim, (unsigned long)off_k,
|
||||||
|
kv_dim, dim, kv_dim, dim, (unsigned long)off_v,
|
||||||
|
q_dim, spatial, kv_dim, spatial, kv_dim, spatial,
|
||||||
|
q_dim, spatial, kv_dim, spatial, kv_dim, spatial];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build weight blob for GQA QKV (3 weight matrices with different shapes)
|
||||||
|
static NSData *mil_build_qkv_gqa_weight_blob(const float *wq, int q_dim, int dim,
|
||||||
|
const float *wk, const float *wv, int kv_dim) {
|
||||||
|
NSUInteger wsize_q = (NSUInteger)q_dim * dim * 2;
|
||||||
|
NSUInteger wsize_kv = (NSUInteger)kv_dim * dim * 2;
|
||||||
|
NSUInteger cs_q = 64 + wsize_q;
|
||||||
|
NSUInteger cs_kv = 64 + wsize_kv;
|
||||||
|
NSUInteger total = 64 + cs_q + 2 * cs_kv;
|
||||||
|
uint8_t *buf = (uint8_t*)calloc(total, 1);
|
||||||
|
buf[0] = 0x01; buf[4] = 0x02;
|
||||||
|
|
||||||
|
// Chunk 0: Wq
|
||||||
|
{
|
||||||
|
uint8_t *chunk = buf + 64;
|
||||||
|
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01;
|
||||||
|
*(uint32_t*)(chunk + 8) = (uint32_t)wsize_q;
|
||||||
|
*(uint32_t*)(chunk + 16) = (uint32_t)(64 + 64);
|
||||||
|
_Float16 *fp16 = (_Float16*)(chunk + 64);
|
||||||
|
for (NSUInteger i = 0; i < (NSUInteger)q_dim * dim; i++) fp16[i] = (_Float16)wq[i];
|
||||||
|
}
|
||||||
|
// Chunk 1: Wk
|
||||||
|
{
|
||||||
|
uint8_t *chunk = buf + 64 + cs_q;
|
||||||
|
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01;
|
||||||
|
*(uint32_t*)(chunk + 8) = (uint32_t)wsize_kv;
|
||||||
|
*(uint32_t*)(chunk + 16) = (uint32_t)(64 + cs_q + 64);
|
||||||
|
_Float16 *fp16 = (_Float16*)(chunk + 64);
|
||||||
|
for (NSUInteger i = 0; i < (NSUInteger)kv_dim * dim; i++) fp16[i] = (_Float16)wk[i];
|
||||||
|
}
|
||||||
|
// Chunk 2: Wv
|
||||||
|
{
|
||||||
|
uint8_t *chunk = buf + 64 + cs_q + cs_kv;
|
||||||
|
chunk[0]=0xEF; chunk[1]=0xBE; chunk[2]=0xAD; chunk[3]=0xDE; chunk[4]=0x01;
|
||||||
|
*(uint32_t*)(chunk + 8) = (uint32_t)wsize_kv;
|
||||||
|
*(uint32_t*)(chunk + 16) = (uint32_t)(64 + cs_q + cs_kv + 64);
|
||||||
|
_Float16 *fp16 = (_Float16*)(chunk + 64);
|
||||||
|
for (NSUInteger i = 0; i < (NSUInteger)kv_dim * dim; i++) fp16[i] = (_Float16)wv[i];
|
||||||
|
}
|
||||||
|
return [NSData dataWithBytesNoCopy:buf length:total freeWhenDone:YES];
|
||||||
|
}
|
||||||
|
|
||||||
// Generate MIL for fused FFN up: w1 + w3 parallel convs
|
// Generate MIL for fused FFN up: w1 + w3 parallel convs
|
||||||
static NSString *mil_gen_ffn_up(int dim, int hidden_dim, int spatial) {
|
static NSString *mil_gen_ffn_up(int dim, int hidden_dim, int spatial) {
|
||||||
NSUInteger cs = 64 + (NSUInteger)hidden_dim * dim * 2;
|
NSUInteger cs = 64 + (NSUInteger)hidden_dim * dim * 2;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue