[feat] Inference server mode: keep ANE kernels loaded between prompts (stdin loop + Unix socket server). Subsequent queries respond in ~0.5s instead of ~6s. run.py auto-connects to socket server when available.

This commit is contained in:
Erik Bray 2026-03-03 17:34:54 +01:00
parent b4d81b71d4
commit 6f16dbefca
4 changed files with 403 additions and 98 deletions

View File

@ -42,16 +42,60 @@ python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct qwen05b.bin
# 2. Build # 2. Build
xcrun clang -O2 -framework Foundation -framework IOSurface \ xcrun clang -O2 -framework Foundation -framework IOSurface \
-framework CoreML -framework Accelerate -ldl -lobjc \ -framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
-o qwen_ane main.m -o qwen_ane main.m
# 3. Run (pass space-separated token IDs) # 3. Run (single-shot, pass space-separated token IDs)
./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20 ./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20
# 4. With tokenizer (requires transformers) # 4. With tokenizer (requires transformers)
python3 run.py "Say hello in one word." python3 run.py "Say hello in one word."
``` ```
## Server Mode (Recommended)
The first invocation compiles 169 ANE kernels (~5.5s). Server mode keeps them loaded, so subsequent prompts skip compilation and respond in ~0.5s.
### Socket server (best for `run.py` integration)
```bash
# Terminal 1: start the server (compiles once, stays running)
./qwen_ane qwen05b.bin --server /tmp/qwen_ane.sock
# Terminal 2: queries are instant (~0.5s instead of ~6s)
python3 run.py "What is 2+2?"
python3 run.py "Capital of France?"
python3 run.py "Count from 1 to 5"
```
`run.py` auto-detects the socket at `/tmp/qwen_ane.sock` and connects to it. If no server is running, it falls back to subprocess mode (slower).
You can also query the socket directly:
```bash
echo '{"tokens": [151644, 8948, 198], "max_tokens": 50}' | nc -U /tmp/qwen_ane.sock
```
Response format:
```json
{"output": [9707, 0, 151645], "prefill_tps": 68.4, "decode_tps": 67.8, "prompt_tokens": 28, "gen_tokens": 3}
```
### Stdin server (for piping/scripting)
```bash
./qwen_ane qwen05b.bin --server
# Waits for "READY", then send lines of space-separated token IDs:
# 151644 8948 198 2610 525|20
# (pipe character separates max_tokens)
```
### Performance comparison
| Mode | First prompt | Subsequent prompts |
|------|-------------|-------------------|
| Single-shot | ~6s | ~6s (recompiles) |
| Server | ~6s (startup) | ~0.5s |
## Output ## Output
``` ```
@ -104,7 +148,6 @@ Adapting to other architectures (LLaMA, Gemma, Mistral) requires:
## Known Limitations ## Known Limitations
- **CPU projections only** — ANE baked-weight conv kernels compile successfully but produce incorrect output (FP16 weight blob format mismatch). The `USE_ANE_PROJECTIONS` toggle exists but defaults to 0 (CPU via Accelerate BLAS). Fixing this would push decode speed from 82 t/s to 120+ t/s. - **CPU projections only** — ANE baked-weight conv kernels compile successfully but produce incorrect output (FP16 weight blob format mismatch). The `USE_ANE_PROJECTIONS` toggle exists but defaults to 0 (CPU via Accelerate BLAS). Fixing this would push decode speed from 82 t/s to 120+ t/s.
- **No persistent server** — each invocation recompiles 169 kernels (~5s). A server mode that compiles once and serves via HTTP would eliminate this overhead.
- **Single model** — hardcoded for Qwen2.5-0.5B. Needs parameterization for other sizes. - **Single model** — hardcoded for Qwen2.5-0.5B. Needs parameterization for other sizes.
- **f32 weights** — 1.9GB on disk. FP16 or quantized weight support would halve this. - **f32 weights** — 1.9GB on disk. FP16 or quantized weight support would halve this.

View File

@ -1,29 +1,43 @@
// main.m Qwen2.5-0.5B inference on Apple Neural Engine // main.m -- Qwen2.5-0.5B inference on Apple Neural Engine
// Compiles ANE kernels for all linear projections, runs autoregressive decode. // Supports three modes:
// 1. Single-shot: ./qwen_ane weights.bin "token_ids" [max_tokens]
// 2. Stdin server: ./qwen_ane weights.bin --server
// 3. Socket server: ./qwen_ane weights.bin --server /tmp/qwen_ane.sock
// //
// Build: // Build:
// xcrun clang -O2 -framework Foundation -framework IOSurface \ // xcrun clang -O2 -framework Foundation -framework IOSurface \
// -framework CoreML -framework Accelerate -ldl -lobjc \ // -framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
// -o qwen_ane main.m // -o qwen_ane main.m
// //
// Run:
// ./qwen_ane qwen05b.bin "Hello world"
//
#import <Foundation/Foundation.h> #import <Foundation/Foundation.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <time.h> #include <time.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <signal.h>
#include "qwen_ane_infer.h" #include "qwen_ane_infer.h"
int g_fp16_io = 0; int g_fp16_io = 0;
static QwenModel g_model; static QwenModel g_model;
/* Filesystem path of the Unix socket; stays NULL until socket-server mode sets it. */
static const char *g_sock_path = NULL;

/* Delete the socket file, if a socket path has been recorded. */
static void cleanup_socket(void) {
    if (g_sock_path == NULL) {
        return;
    }
    unlink(g_sock_path);
}
static void handle_signal(int sig) {
(void)sig;
cleanup_socket();
_exit(0);
}
static int load_weights(const char *path) { static int load_weights(const char *path) {
FILE *f = fopen(path, "rb"); FILE *f = fopen(path, "rb");
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; } if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
// Read config header
int config[7]; int config[7];
fread(config, sizeof(int), 7, f); fread(config, sizeof(int), 7, f);
int dim = config[0], hidden = config[1], n_layers = config[2]; int dim = config[0], hidden = config[1], n_layers = config[2];
@ -34,11 +48,9 @@ static int load_weights(const char *path) {
int q_dim = n_heads * QWEN_HEAD_DIM; int q_dim = n_heads * QWEN_HEAD_DIM;
int kv_dim = n_kv_heads * QWEN_HEAD_DIM; int kv_dim = n_kv_heads * QWEN_HEAD_DIM;
// Embedding
g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float)); g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float));
fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f); fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f);
// Per-layer
for (int l = 0; l < n_layers; l++) { for (int l = 0; l < n_layers; l++) {
g_model.rms_att[l] = (float*)malloc(dim * sizeof(float)); g_model.rms_att[l] = (float*)malloc(dim * sizeof(float));
fread(g_model.rms_att[l], sizeof(float), dim, f); fread(g_model.rms_att[l], sizeof(float), dim, f);
@ -49,10 +61,9 @@ static int load_weights(const char *path) {
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f); fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float)); g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f); fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); // o_proj is [dim, q_dim] g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f); fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
// Q/K/V biases
g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float)); g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float));
g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float)); g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float));
g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float)); g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float));
@ -74,89 +85,273 @@ static int load_weights(const char *path) {
g_model.rms_final = (float*)malloc(dim * sizeof(float)); g_model.rms_final = (float*)malloc(dim * sizeof(float));
fread(g_model.rms_final, sizeof(float), dim, f); fread(g_model.rms_final, sizeof(float), dim, f);
long file_size = ftell(f);
fclose(f); fclose(f);
printf("Weights loaded (%.0f MB)\n", printf("Weights loaded (%.0f MB)\n", (float)file_size / 1024 / 1024);
(float)ftell(f) / 1024 / 1024);
return 0; return 0;
} }
// Parse space-separated token IDs from a string. Returns count.
static int parse_tokens(const char *str, int *ids, int max_ids) {
int n = 0;
char *buf = strdup(str);
char *saveptr;
char *p = strtok_r(buf, " \t\n\r", &saveptr);
while (p && n < max_ids) {
ids[n++] = atoi(p);
p = strtok_r(NULL, " \t\n\r", &saveptr);
}
free(buf);
return n;
}
// Elapsed time from *a to *b in seconds (negative when b is earlier than a).
static double timespec_diff(struct timespec *a, struct timespec *b) {
    double whole_seconds = b->tv_sec - a->tv_sec;
    double fractional_seconds = (b->tv_nsec - a->tv_nsec) / 1e9;
    return whole_seconds + fractional_seconds;
}
// Run one generation pass on g_model: prefill the prompt, then decode.
//
//   prompt_ids / n_prompt  prompt token IDs, fed one at a time to qwen_forward.
//   max_gen                maximum number of tokens to emit.
//   out_ids / max_out      output buffer and its capacity.
//   prefill_tps            out: prompt tokens per second (0 if not measurable).
//   decode_tps             out: decoded tokens per second (0 if not measurable).
//
// Returns the number of token IDs written to out_ids. The first emitted token
// is the prediction returned by the last prefill step. Generation stops after
// an EOS token (which is included in the output), or at max_gen / max_out.
// An empty prompt yields no output, because there is no prediction to emit.
//
// This function does not reset the model; callers that serve several prompts
// call qwen_reset() between requests.
//
// NOTE(review): nothing here bounds n_prompt + max_gen by the KV-cache length
// (QWEN_MAX_SEQ) -- confirm that qwen_forward guards against running past it.
static int generate(int *prompt_ids, int n_prompt, int max_gen,
                    int *out_ids, int max_out,
                    double *prefill_tps, double *decode_tps) {
    *prefill_tps = 0;
    *decode_tps = 0;
    if (n_prompt <= 0) return 0;

    struct timespec t_start, t_prefill_done, t_end;

    // Prefill: feed every prompt token; only the last prediction is kept.
    clock_gettime(CLOCK_MONOTONIC, &t_start);
    int next = 0;
    for (int i = 0; i < n_prompt; i++)
        next = qwen_forward(&g_model, prompt_ids[i]);
    clock_gettime(CLOCK_MONOTONIC, &t_prefill_done);

    double prefill_sec = timespec_diff(&t_start, &t_prefill_done);
    if (prefill_sec > 0) *prefill_tps = n_prompt / prefill_sec;

    const int eos = 151645;   // <|im_end|>
    const int eos2 = 151643;  // <|endoftext|>

    // Decode: emit the pending prediction, then ask the model for the next
    // one only if there is still room to emit it.
    int n_out = 0;
    while (n_out < max_gen && n_out < max_out) {
        out_ids[n_out++] = next;
        if (next == eos || next == eos2) break;
        if (n_out >= max_gen || n_out >= max_out) break;  // skip a wasted forward pass
        next = qwen_forward(&g_model, next);
    }
    clock_gettime(CLOCK_MONOTONIC, &t_end);

    // The first output token was produced by prefill, so the decode phase
    // accounts for n_out - 1 tokens.
    double decode_sec = timespec_diff(&t_prefill_done, &t_end);
    int gen_tokens = n_out > 1 ? n_out - 1 : 0;
    if (decode_sec > 0) *decode_tps = gen_tokens / decode_sec;
    return n_out;
}
// --- Stdin server mode ---

// Serve prompts from stdin until EOF.
//
// Prints "READY" once, then for every input line of the form
//     "token_id token_id ... [|max_tokens]"
// runs one generation and prints an "OUT:" line with the generated token IDs
// followed by a "PERF:" line. The model is reset after every prompt.
//
// max_tokens defaults to 50; a missing, unparseable or non-positive value falls
// back to that default, and larger values are capped at the output buffer size.
// Errors are reported as "ERR: ..." lines and the loop continues.
static void run_stdin_server(void) {
    printf("READY\n");
    fflush(stdout);

    char line[65536];
    while (fgets(line, sizeof(line), stdin)) {
        // A full buffer without a trailing newline means the line was cut.
        // Discard the remainder so it is not misread as a second prompt.
        size_t len = strlen(line);
        if (len == sizeof(line) - 1 && line[len - 1] != '\n') {
            int c = getchar();
            if (c != EOF && c != '\n') {
                while ((c = getchar()) != EOF && c != '\n') { }
                printf("ERR: line too long\n");
                fflush(stdout);
                continue;
            }
        }

        int prompt_ids[2048];
        int out_ids[4096];
        const int max_prompt = (int)(sizeof(prompt_ids) / sizeof(prompt_ids[0]));
        const int max_out = (int)(sizeof(out_ids) / sizeof(out_ids[0]));

        // Optional "|max_tokens" suffix; cut it off before parsing the IDs.
        int max_gen = 50;
        char *sep = strchr(line, '|');
        if (sep) {
            char *end = NULL;
            long requested = strtol(sep + 1, &end, 10);
            if (end != sep + 1 && requested > 0)
                max_gen = requested > max_out ? max_out : (int)requested;
            *sep = '\0';
        }

        int n_prompt = parse_tokens(line, prompt_ids, max_prompt);
        if (n_prompt == 0) {
            printf("ERR: empty prompt\n");
            fflush(stdout);
            continue;
        }

        double p_tps, d_tps;
        int n_out = generate(prompt_ids, n_prompt, max_gen, out_ids, max_out, &p_tps, &d_tps);

        printf("OUT:");
        for (int i = 0; i < n_out; i++) printf(" %d", out_ids[i]);
        printf("\n");
        printf("PERF: prefill=%.1f decode=%.1f prompt=%d gen=%d\n",
               p_tps, d_tps, n_prompt, n_out);
        fflush(stdout);

        // Start the next prompt from a clean model state.
        qwen_reset(&g_model);
    }
}
// --- Socket server mode ---

// Write all `len` bytes of `data` to `fd`, retrying after short writes.
// Returns 0 on success, -1 as soon as write() fails or makes no progress.
static int sock_write_all(int fd, const char *data, size_t len) {
    while (len > 0) {
        ssize_t n = write(fd, data, len);
        if (n <= 0) return -1;
        data += n;
        len -= (size_t)n;
    }
    return 0;
}

// Extract the integer array that follows the "tokens" key in `json`.
// Returns the number of IDs stored (at most `max_ids`), 0 when the key or the
// opening '[' is missing or the array is empty, and -1 when an element is not
// a number.
static int sock_parse_tokens(const char *json, int *ids, int max_ids) {
    const char *key = strstr(json, "\"tokens\"");
    if (!key) return 0;
    const char *p = strchr(key, '[');
    if (!p) return 0;
    p++;

    int n = 0;
    while (n < max_ids) {
        while (*p == ' ' || *p == ',' || *p == '\t' || *p == '\n' || *p == '\r') p++;
        if (*p == '\0' || *p == ']') break;
        char *end = NULL;
        long value = strtol(p, &end, 10);
        // No digits consumed: strtol made no progress, so stop instead of
        // storing a run of zeros.
        if (end == p) return -1;
        ids[n++] = (int)value;
        p = end;
    }
    return n;
}

// Serve generation requests over a Unix stream socket bound at `sock_path`.
//
// Each connection carries one request of the form
//     {"tokens": [1,2,3], "max_tokens": 50}
// and receives one JSON line in reply: either the generated token IDs with
// throughput figures, or {"error": "..."}. Clients are handled one at a time
// and the model is reset after every request. max_tokens defaults to 50; a
// non-positive value falls back to that default and larger values are capped
// at the output buffer size.
//
// Returns only if the socket cannot be set up; otherwise loops forever and
// relies on the SIGINT/SIGTERM handler (or atexit) to remove the socket file.
static void run_socket_server(const char *sock_path) {
    struct sockaddr_un addr;
    // Refuse over-long paths: a silently truncated sun_path would bind a
    // different file than the one clients use and cleanup later unlinks.
    if (strlen(sock_path) >= sizeof(addr.sun_path)) {
        fprintf(stderr, "Socket path too long (max %zu bytes): %s\n",
                sizeof(addr.sun_path) - 1, sock_path);
        return;
    }

    g_sock_path = sock_path;
    signal(SIGINT, handle_signal);
    signal(SIGTERM, handle_signal);
    // A client that disconnects before the reply must not kill the server.
    signal(SIGPIPE, SIG_IGN);
    atexit(cleanup_socket);

    // Remove a stale socket file left behind by a previous run.
    unlink(sock_path);

    int srv = socket(AF_UNIX, SOCK_STREAM, 0);
    if (srv < 0) { perror("socket"); return; }

    memset(&addr, 0, sizeof(addr));
    addr.sun_family = AF_UNIX;
    strncpy(addr.sun_path, sock_path, sizeof(addr.sun_path) - 1);

    if (bind(srv, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
        perror("bind"); close(srv); return;
    }
    if (listen(srv, 4) < 0) {
        perror("listen"); close(srv); return;
    }

    printf("Listening on %s\n", sock_path);
    printf("READY\n");
    fflush(stdout);

    while (1) {
        int client = accept(srv, NULL, NULL);
        if (client < 0) { perror("accept"); continue; }

        // Read the request until a newline or closing brace has arrived, the
        // peer closes the connection, or the buffer is full.
        char buf[131072];
        ssize_t total = 0;
        while (total < (ssize_t)sizeof(buf) - 1) {
            ssize_t n = read(client, buf + total, sizeof(buf) - 1 - total);
            if (n <= 0) break;
            total += n;
            if (memchr(buf, '\n', total) || memchr(buf, '}', total)) break;
        }
        buf[total] = '\0';

        int prompt_ids[2048];
        int out_ids[4096];
        const int max_prompt = (int)(sizeof(prompt_ids) / sizeof(prompt_ids[0]));
        const int max_out = (int)(sizeof(out_ids) / sizeof(out_ids[0]));

        int n_prompt = sock_parse_tokens(buf, prompt_ids, max_prompt);
        if (n_prompt <= 0) {
            const char *err = n_prompt < 0
                ? "{\"error\": \"malformed tokens\"}\n"
                : "{\"error\": \"no tokens\"}\n";
            sock_write_all(client, err, strlen(err));
            close(client);
            continue;
        }

        int max_gen = 50;
        char *mt = strstr(buf, "\"max_tokens\"");
        if (mt) {
            char *colon = strchr(mt, ':');
            if (colon) {
                long requested = strtol(colon + 1, NULL, 10);
                if (requested > 0)
                    max_gen = requested > max_out ? max_out : (int)requested;
            }
        }

        double p_tps, d_tps;
        int n_out = generate(prompt_ids, n_prompt, max_gen, out_ids, max_out, &p_tps, &d_tps);

        // Build the JSON reply. At most 4096 IDs of up to 11 characters plus
        // ", " each, so the reply always fits in resp.
        char resp[131072];
        int off = snprintf(resp, sizeof(resp), "{\"output\": [");
        for (int i = 0; i < n_out; i++)
            off += snprintf(resp + off, sizeof(resp) - off,
                            "%s%d", i ? ", " : "", out_ids[i]);
        off += snprintf(resp + off, sizeof(resp) - off,
                        "], \"prefill_tps\": %.1f, \"decode_tps\": %.1f, "
                        "\"prompt_tokens\": %d, \"gen_tokens\": %d}\n",
                        p_tps, d_tps, n_prompt, n_out);

        if (sock_write_all(client, resp, (size_t)off) < 0)
            fprintf(stderr, "[socket] client disconnected before reply was sent\n");
        close(client);

        printf("[socket] prompt=%d gen=%d prefill=%.1f decode=%.1f t/s\n",
               n_prompt, n_out, p_tps, d_tps);
        fflush(stdout);

        // Start the next request from a clean model state.
        qwen_reset(&g_model);
    }
}
int main(int argc, char **argv) { int main(int argc, char **argv) {
@autoreleasepool { @autoreleasepool {
if (argc < 3) { if (argc < 2) {
fprintf(stderr, "Usage: %s <weights.bin> <prompt>\n", argv[0]); fprintf(stderr,
"Usage:\n"
" %s <weights.bin> \"token_ids\" [max_tokens] (single-shot)\n"
" %s <weights.bin> --server (stdin loop)\n"
" %s <weights.bin> --server /tmp/qwen_ane.sock (socket server)\n",
argv[0], argv[0], argv[0]);
return 1; return 1;
} }
printf("=== Qwen2.5-0.5B ANE Inference ===\n\n"); printf("=== Qwen2.5-0.5B ANE Inference ===\n\n");
// Load weights setbuf(stdout, NULL);
printf("Loading weights...\n"); printf("Loading weights...\n");
if (load_weights(argv[1]) != 0) return 1; if (load_weights(argv[1]) != 0) return 1;
// Allocate buffers
qwen_alloc(&g_model); qwen_alloc(&g_model);
// Compile ANE kernels
printf("Compiling ANE kernels (169 total)...\n"); printf("Compiling ANE kernels (169 total)...\n");
struct timespec t0, t1; struct timespec t0, t1;
clock_gettime(CLOCK_MONOTONIC, &t0); clock_gettime(CLOCK_MONOTONIC, &t0);
qwen_compile_kernels(&g_model); qwen_compile_kernels(&g_model);
clock_gettime(CLOCK_MONOTONIC, &t1); clock_gettime(CLOCK_MONOTONIC, &t1);
double compile_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9; double compile_sec = timespec_diff(&t0, &t1);
printf("Compile time: %.1fs\n\n", compile_sec); printf("Compile time: %.1fs\n\n", compile_sec);
// Parse token IDs from argv[2] (space-separated) // Check for --server flag
// argv[3] = max generation tokens int server_mode = 0;
int max_gen = 50; const char *sock_path = NULL;
if (argc >= 4) max_gen = atoi(argv[3]); for (int i = 2; i < argc; i++) {
if (strcmp(argv[i], "--server") == 0) {
// Parse input token IDs server_mode = 1;
int prompt_ids[2048]; if (i + 1 < argc && argv[i+1][0] != '-')
int n_prompt = 0; sock_path = argv[++i];
char *tok_str = strdup(argv[2]); }
char *saveptr;
char *p = strtok_r(tok_str, " ", &saveptr);
while (p && n_prompt < 2048) {
prompt_ids[n_prompt++] = atoi(p);
p = strtok_r(NULL, " ", &saveptr);
} }
free(tok_str);
if (server_mode) {
if (sock_path)
run_socket_server(sock_path);
else
run_stdin_server();
return 0;
}
// Single-shot mode (original behavior)
if (argc < 3) {
fprintf(stderr, "Error: provide token IDs or --server\n");
return 1;
}
int max_gen = 50;
if (argc >= 4 && strcmp(argv[3], "--server") != 0)
max_gen = atoi(argv[3]);
int prompt_ids[2048];
int n_prompt = parse_tokens(argv[2], prompt_ids, 2048);
printf("Prompt: %d tokens, generating up to %d\n", n_prompt, max_gen); printf("Prompt: %d tokens, generating up to %d\n", n_prompt, max_gen);
clock_gettime(CLOCK_MONOTONIC, &t0); int out_ids[4096];
double p_tps, d_tps;
int n_out = generate(prompt_ids, n_prompt, max_gen, out_ids, 4096, &p_tps, &d_tps);
// Prefill: feed all prompt tokens
int next = 0;
for (int i = 0; i < n_prompt; i++) {
next = qwen_forward(&g_model, prompt_ids[i]);
}
struct timespec t_prefill;
clock_gettime(CLOCK_MONOTONIC, &t_prefill);
double prefill_sec = (t_prefill.tv_sec - t0.tv_sec) + (t_prefill.tv_nsec - t0.tv_nsec) / 1e9;
printf("Prefill: %d tokens in %.2fs (%.1f t/s)\n", n_prompt, prefill_sec, n_prompt / prefill_sec);
// Generate
int eos = 151645; // <|im_end|>
int eos2 = 151643; // <|endoftext|>
printf("OUT:"); printf("OUT:");
for (int i = 0; i < max_gen; i++) { for (int i = 0; i < n_out; i++) printf(" %d", out_ids[i]);
printf(" %d", next);
fflush(stdout);
if (next == eos || next == eos2) break;
next = qwen_forward(&g_model, next);
}
printf("\n"); printf("\n");
clock_gettime(CLOCK_MONOTONIC, &t1); printf("\nPrefill: %.1f t/s (%d tokens)\n", p_tps, n_prompt);
double gen_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9; printf("Decode: %.1f t/s (%d tokens)\n", d_tps, n_out > 1 ? n_out - 1 : 0);
int total_tokens = g_model.pos;
int gen_tokens = total_tokens - n_prompt;
double decode_sec = gen_sec - prefill_sec;
printf("\nTotal: %d tokens in %.2fs\n", total_tokens, gen_sec);
printf("Prefill: %.1f t/s (%d tokens)\n", n_prompt / prefill_sec, n_prompt);
printf("Decode: %.1f t/s (%d tokens)\n",
decode_sec > 0 ? gen_tokens / decode_sec : 0, gen_tokens);
return 0; return 0;
} }

View File

@ -433,3 +433,11 @@ static void qwen_alloc(QwenModel *m) {
} }
m->pos = 0; m->pos = 0;
} }
// Return the model to its start-of-sequence state: zero the key and value
// caches of every layer and rewind the position counter to 0.
static void qwen_reset(QwenModel *m) {
    const size_t cache_bytes = QWEN_MAX_SEQ * QWEN_KV_DIM * sizeof(float);
    int layer = 0;
    while (layer < QWEN_LAYERS) {
        memset(m->kv_cache_k[layer], 0, cache_bytes);
        memset(m->kv_cache_v[layer], 0, cache_bytes);
        layer++;
    }
    m->pos = 0;
}

View File

@ -1,12 +1,21 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Run Qwen2.5-0.5B on ANE with proper tokenization. """Run Qwen2.5-0.5B on ANE with proper tokenization.
Auto-connects to a running socket server, which skips kernel compilation (~0.5s per query).
Falls back to subprocess mode if no server is running (~6s startup per call).
Usage: Usage:
python3 run.py "Your prompt here" [--max-tokens 50] python3 run.py "Your prompt here" [--max-tokens 50]
Server mode (start server first in another terminal):
./qwen_ane qwen05b.bin --server /tmp/qwen_ane.sock
python3 run.py "Your prompt here"
""" """
import argparse import argparse
import ctypes import json
import struct import os
import socket
import subprocess
import sys import sys
import time import time
from pathlib import Path from pathlib import Path
@ -14,12 +23,66 @@ from pathlib import Path
INFERENCE_DIR = Path(__file__).parent INFERENCE_DIR = Path(__file__).parent
WEIGHTS_PATH = INFERENCE_DIR / "qwen05b.bin" WEIGHTS_PATH = INFERENCE_DIR / "qwen05b.bin"
MODEL_DIR = Path.home() / "models" / "Qwen2.5-0.5B-Instruct" MODEL_DIR = Path.home() / "models" / "Qwen2.5-0.5B-Instruct"
DEFAULT_SOCK = "/tmp/qwen_ane.sock"
def query_socket(token_ids: list[int], max_tokens: int, sock_path: str = DEFAULT_SOCK) -> dict | None:
"""Send a request to the socket server. Returns parsed JSON or None on failure."""
try:
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.settimeout(120)
s.connect(sock_path)
req = json.dumps({"tokens": token_ids, "max_tokens": max_tokens}) + "\n"
s.sendall(req.encode())
data = b""
while True:
chunk = s.recv(131072)
if not chunk:
break
data += chunk
if b"\n" in data:
break
s.close()
return json.loads(data.decode().strip())
except (ConnectionRefusedError, FileNotFoundError, OSError):
return None
def query_subprocess(token_ids: list[int], max_tokens: int) -> dict | None:
    """Fall back to spawning the binary as a subprocess.

    Runs ``qwen_ane`` from INFERENCE_DIR in single-shot mode, echoes its stdout
    (and the first 500 characters of stderr), and collects the token IDs
    printed on lines starting with ``OUT:``.

    Args:
        token_ids: Prompt token IDs, passed as one space-separated argument.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        ``{"output": [...]}`` with the generated token IDs, or None if the
        binary is missing, cannot be started, exceeds the 120 s timeout, or
        prints no token IDs.
    """
    binary = str(INFERENCE_DIR / "qwen_ane")
    if not os.path.exists(binary):
        print(f"Binary not found: {binary}", file=sys.stderr)
        return None

    try:
        result = subprocess.run(
            [binary, str(WEIGHTS_PATH),
             " ".join(str(t) for t in token_ids),
             str(max_tokens)],
            capture_output=True, text=True, timeout=120,
        )
    except subprocess.TimeoutExpired:
        print("Inference timed out after 120s", file=sys.stderr)
        return None
    except OSError as exc:
        print(f"Failed to run {binary}: {exc}", file=sys.stderr)
        return None

    print(result.stdout)
    if result.stderr:
        print(result.stderr[:500], file=sys.stderr)

    output_ids = []
    for line in result.stdout.split("\n"):
        if not line.startswith("OUT:"):
            continue
        for field in line[4:].split():
            # Skip anything that is not an integer instead of crashing on it.
            try:
                output_ids.append(int(field))
            except ValueError:
                continue
    return {"output": output_ids} if output_ids else None
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser(description="Qwen2.5-0.5B ANE inference")
parser.add_argument("prompt", type=str) parser.add_argument("prompt", type=str)
parser.add_argument("--max-tokens", type=int, default=50) parser.add_argument("--max-tokens", type=int, default=50)
parser.add_argument("--no-server", action="store_true",
help="Force subprocess mode even if server is running")
parser.add_argument("--sock", type=str, default=DEFAULT_SOCK,
help="Socket path for server mode")
args = parser.parse_args() args = parser.parse_args()
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -27,47 +90,43 @@ def main():
print("Loading tokenizer...") print("Loading tokenizer...")
tok = AutoTokenizer.from_pretrained(str(MODEL_DIR), trust_remote_code=True) tok = AutoTokenizer.from_pretrained(str(MODEL_DIR), trust_remote_code=True)
# Build chat template
messages = [ messages = [
{"role": "system", "content": "You are a helpful assistant. Be concise."}, {"role": "system", "content": "You are a helpful assistant. Be concise."},
{"role": "user", "content": args.prompt}, {"role": "user", "content": args.prompt},
] ]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tok.encode(text) input_ids = tok.encode(text)
print(f"Prompt tokens: {len(input_ids)}") print(f"Prompt: {len(input_ids)} tokens")
# Run the C binary — pass token IDs as arguments # Try socket server first (instant response)
import subprocess result = None
binary = str(INFERENCE_DIR / "qwen_ane") if not args.no_server and os.path.exists(args.sock):
print(f"Connecting to server at {args.sock}...")
t0 = time.time()
result = query_socket(input_ids, args.max_tokens, args.sock)
elapsed = time.time() - t0
if result:
print(f"Server responded in {elapsed:.3f}s")
else:
print("Server not responding, falling back to subprocess...")
# We need to modify the binary to accept token IDs as input # Fall back to subprocess
# For now, print the token IDs so we can verify tokenization if result is None:
print(f"First 10 tokens: {input_ids[:10]}") print("Running inference (subprocess mode, ~6s startup)...")
print(f"Token text: {[tok.decode([t]) for t in input_ids[:10]]}") result = query_subprocess(input_ids, args.max_tokens)
print(f"\nRunning ANE inference with {len(input_ids)} prompt tokens + {args.max_tokens} generation...")
# Call binary with token IDs piped via stdin if not result or "output" not in result:
result = subprocess.run( print("(No output received)", file=sys.stderr)
[binary, str(WEIGHTS_PATH), " ".join(str(t) for t in input_ids), return
str(args.max_tokens)],
capture_output=True, text=True, timeout=120,
)
print(result.stdout)
if result.stderr:
print(result.stderr[:500], file=sys.stderr)
# Parse output token IDs from binary stdout
output_ids = []
for line in result.stdout.split("\n"):
if line.startswith("OUT:"):
ids = [int(x) for x in line[4:].split() if x.isdigit()]
output_ids.extend(ids)
output_ids = result["output"]
if output_ids: if output_ids:
decoded = tok.decode(output_ids, skip_special_tokens=True) decoded = tok.decode(output_ids, skip_special_tokens=True)
print(f"\n=== Response ===\n{decoded}") print(f"\n=== Response ===\n{decoded}")
else:
print("\n(No output tokens parsed — binary may need token ID input mode)") if "prefill_tps" in result:
print(f"\nPrefill: {result['prefill_tps']:.1f} t/s | "
f"Decode: {result['decode_tps']:.1f} t/s")
if __name__ == "__main__": if __name__ == "__main__":