mirror of https://github.com/maderix/ANE.git
[feat] Inference server mode: keep ANE kernels loaded between prompts (stdin loop + Unix socket server). Subsequent queries respond in ~0.5s instead of ~6s. run.py auto-connects to socket server when available.
This commit is contained in:
parent
b4d81b71d4
commit
6f16dbefca
|
|
@ -42,16 +42,60 @@ python3 convert_weights.py /path/to/Qwen2.5-0.5B-Instruct qwen05b.bin
|
|||
|
||||
# 2. Build
|
||||
xcrun clang -O2 -framework Foundation -framework IOSurface \
|
||||
-framework CoreML -framework Accelerate -ldl -lobjc \
|
||||
-framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
|
||||
-o qwen_ane main.m
|
||||
|
||||
# 3. Run (pass space-separated token IDs)
|
||||
# 3. Run (single-shot, pass space-separated token IDs)
|
||||
./qwen_ane qwen05b.bin "151644 8948 198 2610 525 264 10950 17847 13" 20
|
||||
|
||||
# 4. With tokenizer (requires transformers)
|
||||
python3 run.py "Say hello in one word."
|
||||
```
|
||||
|
||||
## Server Mode (Recommended)
|
||||
|
||||
The first invocation compiles 169 ANE kernels (~5.5s). Server mode keeps them loaded so subsequent prompts respond instantly.
|
||||
|
||||
### Socket server (best for `run.py` integration)
|
||||
|
||||
```bash
|
||||
# Terminal 1: start the server (compiles once, stays running)
|
||||
./qwen_ane qwen05b.bin --server /tmp/qwen_ane.sock
|
||||
|
||||
# Terminal 2: queries are instant (~0.5s instead of ~6s)
|
||||
python3 run.py "What is 2+2?"
|
||||
python3 run.py "Capital of France?"
|
||||
python3 run.py "Count from 1 to 5"
|
||||
```
|
||||
|
||||
`run.py` auto-detects the socket at `/tmp/qwen_ane.sock` and connects to it. If no server is running, it falls back to subprocess mode (slower).
|
||||
|
||||
You can also query the socket directly:
|
||||
```bash
|
||||
echo '{"tokens": [151644, 8948, 198], "max_tokens": 50}' | nc -U /tmp/qwen_ane.sock
|
||||
```
|
||||
|
||||
Response format:
|
||||
```json
|
||||
{"output": [9707, 0, 151645], "prefill_tps": 68.4, "decode_tps": 67.8, "prompt_tokens": 28, "gen_tokens": 3}
|
||||
```
|
||||
|
||||
### Stdin server (for piping/scripting)
|
||||
|
||||
```bash
|
||||
./qwen_ane qwen05b.bin --server
|
||||
# Waits for "READY", then send lines of space-separated token IDs:
|
||||
# 151644 8948 198 2610 525|20
|
||||
# (pipe character separates max_tokens)
|
||||
```
|
||||
|
||||
### Performance comparison
|
||||
|
||||
| Mode | First prompt | Subsequent prompts |
|
||||
|------|-------------|-------------------|
|
||||
| Single-shot | ~6s | ~6s (recompiles) |
|
||||
| Server | ~6s (startup) | ~0.5s |
|
||||
|
||||
## Output
|
||||
|
||||
```
|
||||
|
|
@ -104,7 +148,6 @@ Adapting to other architectures (LLaMA, Gemma, Mistral) requires:
|
|||
## Known Limitations
|
||||
|
||||
- **CPU projections only** — ANE baked-weight conv kernels compile successfully but produce incorrect output (FP16 weight blob format mismatch). The `USE_ANE_PROJECTIONS` toggle exists but defaults to 0 (CPU via Accelerate BLAS). Fixing this would push decode speed from 82 t/s to 120+ t/s.
|
||||
- **No persistent server** — each invocation recompiles 169 kernels (~5s). A server mode that compiles once and serves via HTTP would eliminate this overhead.
|
||||
- **Single model** — hardcoded for Qwen2.5-0.5B. Needs parameterization for other sizes.
|
||||
- **f32 weights** — 1.9GB on disk. FP16 or quantized weight support would halve this.
|
||||
|
||||
|
|
|
|||
323
inference/main.m
323
inference/main.m
|
|
@ -1,29 +1,43 @@
|
|||
// main.m — Qwen2.5-0.5B inference on Apple Neural Engine
|
||||
// Compiles ANE kernels for all linear projections, runs autoregressive decode.
|
||||
// main.m -- Qwen2.5-0.5B inference on Apple Neural Engine
|
||||
// Supports three modes:
|
||||
// 1. Single-shot: ./qwen_ane weights.bin "token_ids" [max_tokens]
|
||||
// 2. Stdin server: ./qwen_ane weights.bin --server
|
||||
// 3. Socket server: ./qwen_ane weights.bin --server /tmp/qwen_ane.sock
|
||||
//
|
||||
// Build:
|
||||
// xcrun clang -O2 -framework Foundation -framework IOSurface \
|
||||
// -framework CoreML -framework Accelerate -ldl -lobjc \
|
||||
// -framework CoreML -framework Accelerate -ldl -lobjc -fobjc-arc \
|
||||
// -o qwen_ane main.m
|
||||
//
|
||||
// Run:
|
||||
// ./qwen_ane qwen05b.bin "Hello world"
|
||||
//
|
||||
#import <Foundation/Foundation.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/un.h>
|
||||
#include <unistd.h>
|
||||
#include <signal.h>
|
||||
#include "qwen_ane_infer.h"
|
||||
|
||||
int g_fp16_io = 0;
|
||||
static QwenModel g_model;
|
||||
static const char *g_sock_path = NULL;
|
||||
|
||||
static void cleanup_socket(void) {
|
||||
if (g_sock_path) unlink(g_sock_path);
|
||||
}
|
||||
|
||||
static void handle_signal(int sig) {
|
||||
(void)sig;
|
||||
cleanup_socket();
|
||||
_exit(0);
|
||||
}
|
||||
|
||||
static int load_weights(const char *path) {
|
||||
FILE *f = fopen(path, "rb");
|
||||
if (!f) { fprintf(stderr, "Cannot open %s\n", path); return -1; }
|
||||
|
||||
// Read config header
|
||||
int config[7];
|
||||
fread(config, sizeof(int), 7, f);
|
||||
int dim = config[0], hidden = config[1], n_layers = config[2];
|
||||
|
|
@ -34,11 +48,9 @@ static int load_weights(const char *path) {
|
|||
int q_dim = n_heads * QWEN_HEAD_DIM;
|
||||
int kv_dim = n_kv_heads * QWEN_HEAD_DIM;
|
||||
|
||||
// Embedding
|
||||
g_model.embed = (float*)malloc((size_t)vocab * dim * sizeof(float));
|
||||
fread(g_model.embed, sizeof(float), (size_t)vocab * dim, f);
|
||||
|
||||
// Per-layer
|
||||
for (int l = 0; l < n_layers; l++) {
|
||||
g_model.rms_att[l] = (float*)malloc(dim * sizeof(float));
|
||||
fread(g_model.rms_att[l], sizeof(float), dim, f);
|
||||
|
|
@ -49,10 +61,9 @@ static int load_weights(const char *path) {
|
|||
fread(g_model.wk[l], sizeof(float), (size_t)kv_dim * dim, f);
|
||||
g_model.wv[l] = (float*)malloc((size_t)kv_dim * dim * sizeof(float));
|
||||
fread(g_model.wv[l], sizeof(float), (size_t)kv_dim * dim, f);
|
||||
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float)); // o_proj is [dim, q_dim]
|
||||
g_model.wo[l] = (float*)malloc((size_t)q_dim * dim * sizeof(float));
|
||||
fread(g_model.wo[l], sizeof(float), (size_t)dim * q_dim, f);
|
||||
|
||||
// Q/K/V biases
|
||||
g_model.q_bias[l] = (float*)malloc(q_dim * sizeof(float));
|
||||
g_model.k_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
||||
g_model.v_bias[l] = (float*)malloc(kv_dim * sizeof(float));
|
||||
|
|
@ -74,89 +85,273 @@ static int load_weights(const char *path) {
|
|||
g_model.rms_final = (float*)malloc(dim * sizeof(float));
|
||||
fread(g_model.rms_final, sizeof(float), dim, f);
|
||||
|
||||
long file_size = ftell(f);
|
||||
fclose(f);
|
||||
printf("Weights loaded (%.0f MB)\n",
|
||||
(float)ftell(f) / 1024 / 1024);
|
||||
printf("Weights loaded (%.0f MB)\n", (float)file_size / 1024 / 1024);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Parse space-separated token IDs from a string. Returns count.
|
||||
static int parse_tokens(const char *str, int *ids, int max_ids) {
|
||||
int n = 0;
|
||||
char *buf = strdup(str);
|
||||
char *saveptr;
|
||||
char *p = strtok_r(buf, " \t\n\r", &saveptr);
|
||||
while (p && n < max_ids) {
|
||||
ids[n++] = atoi(p);
|
||||
p = strtok_r(NULL, " \t\n\r", &saveptr);
|
||||
}
|
||||
free(buf);
|
||||
return n;
|
||||
}
|
||||
|
||||
static double timespec_diff(struct timespec *a, struct timespec *b) {
|
||||
return (b->tv_sec - a->tv_sec) + (b->tv_nsec - a->tv_nsec) / 1e9;
|
||||
}
|
||||
|
||||
// Run one generation pass. Writes output token IDs to out_ids, returns count.
|
||||
// If out_fd >= 0, writes formatted results there; otherwise prints to stdout.
|
||||
static int generate(int *prompt_ids, int n_prompt, int max_gen,
|
||||
int *out_ids, int max_out,
|
||||
double *prefill_tps, double *decode_tps) {
|
||||
struct timespec t0, t1, t_pre;
|
||||
clock_gettime(CLOCK_MONOTONIC, &t0);
|
||||
|
||||
int next = 0;
|
||||
for (int i = 0; i < n_prompt; i++)
|
||||
next = qwen_forward(&g_model, prompt_ids[i]);
|
||||
|
||||
clock_gettime(CLOCK_MONOTONIC, &t_pre);
|
||||
double ps = timespec_diff(&t0, &t_pre);
|
||||
*prefill_tps = ps > 0 ? n_prompt / ps : 0;
|
||||
|
||||
int eos = 151645, eos2 = 151643;
|
||||
int n_out = 0;
|
||||
for (int i = 0; i < max_gen && n_out < max_out; i++) {
|
||||
if (n_out < max_out) out_ids[n_out++] = next;
|
||||
if (next == eos || next == eos2) break;
|
||||
next = qwen_forward(&g_model, next);
|
||||
}
|
||||
|
||||
clock_gettime(CLOCK_MONOTONIC, &t1);
|
||||
double ds = timespec_diff(&t_pre, &t1);
|
||||
int gen_tokens = n_out > 1 ? n_out - 1 : 0;
|
||||
*decode_tps = ds > 0 ? gen_tokens / ds : 0;
|
||||
|
||||
return n_out;
|
||||
}
|
||||
|
||||
// --- Stdin server mode ---
|
||||
static void run_stdin_server(void) {
|
||||
printf("READY\n");
|
||||
fflush(stdout);
|
||||
|
||||
char line[65536];
|
||||
while (fgets(line, sizeof(line), stdin)) {
|
||||
// Format: "token_id token_id ... [|max_tokens]"
|
||||
int max_gen = 50;
|
||||
char *pipe = strchr(line, '|');
|
||||
if (pipe) {
|
||||
max_gen = atoi(pipe + 1);
|
||||
*pipe = '\0';
|
||||
}
|
||||
|
||||
int prompt_ids[2048];
|
||||
int n_prompt = parse_tokens(line, prompt_ids, 2048);
|
||||
if (n_prompt == 0) {
|
||||
printf("ERR: empty prompt\n");
|
||||
fflush(stdout);
|
||||
continue;
|
||||
}
|
||||
|
||||
int out_ids[4096];
|
||||
double p_tps, d_tps;
|
||||
int n_out = generate(prompt_ids, n_prompt, max_gen, out_ids, 4096, &p_tps, &d_tps);
|
||||
|
||||
printf("OUT:");
|
||||
for (int i = 0; i < n_out; i++) printf(" %d", out_ids[i]);
|
||||
printf("\n");
|
||||
printf("PERF: prefill=%.1f decode=%.1f prompt=%d gen=%d\n",
|
||||
p_tps, d_tps, n_prompt, n_out);
|
||||
fflush(stdout);
|
||||
|
||||
qwen_reset(&g_model);
|
||||
}
|
||||
}
|
||||
|
||||
// --- Socket server mode ---
|
||||
static void run_socket_server(const char *sock_path) {
|
||||
g_sock_path = sock_path;
|
||||
signal(SIGINT, handle_signal);
|
||||
signal(SIGTERM, handle_signal);
|
||||
atexit(cleanup_socket);
|
||||
|
||||
unlink(sock_path);
|
||||
|
||||
int srv = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
if (srv < 0) { perror("socket"); return; }
|
||||
|
||||
struct sockaddr_un addr;
|
||||
memset(&addr, 0, sizeof(addr));
|
||||
addr.sun_family = AF_UNIX;
|
||||
strncpy(addr.sun_path, sock_path, sizeof(addr.sun_path) - 1);
|
||||
|
||||
if (bind(srv, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
|
||||
perror("bind"); close(srv); return;
|
||||
}
|
||||
if (listen(srv, 4) < 0) {
|
||||
perror("listen"); close(srv); return;
|
||||
}
|
||||
|
||||
printf("Listening on %s\n", sock_path);
|
||||
printf("READY\n");
|
||||
fflush(stdout);
|
||||
|
||||
while (1) {
|
||||
int client = accept(srv, NULL, NULL);
|
||||
if (client < 0) { perror("accept"); continue; }
|
||||
|
||||
// Read request: {"tokens": [1,2,3], "max_tokens": 50}
|
||||
char buf[131072];
|
||||
ssize_t total = 0;
|
||||
while (total < (ssize_t)sizeof(buf) - 1) {
|
||||
ssize_t n = read(client, buf + total, sizeof(buf) - 1 - total);
|
||||
if (n <= 0) break;
|
||||
total += n;
|
||||
if (memchr(buf, '\n', total) || memchr(buf, '}', total)) break;
|
||||
}
|
||||
buf[total] = '\0';
|
||||
|
||||
// Minimal JSON parsing for {"tokens": [...], "max_tokens": N}
|
||||
int prompt_ids[2048];
|
||||
int n_prompt = 0;
|
||||
int max_gen = 50;
|
||||
|
||||
char *tok_start = strstr(buf, "\"tokens\"");
|
||||
if (tok_start) {
|
||||
char *bracket = strchr(tok_start, '[');
|
||||
if (bracket) {
|
||||
char *p = bracket + 1;
|
||||
while (*p && *p != ']' && n_prompt < 2048) {
|
||||
while (*p && (*p == ' ' || *p == ',')) p++;
|
||||
if (*p == ']') break;
|
||||
prompt_ids[n_prompt++] = (int)strtol(p, &p, 10);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
char *mt = strstr(buf, "\"max_tokens\"");
|
||||
if (mt) {
|
||||
char *colon = strchr(mt, ':');
|
||||
if (colon) max_gen = (int)strtol(colon + 1, NULL, 10);
|
||||
}
|
||||
|
||||
if (n_prompt == 0) {
|
||||
const char *err = "{\"error\": \"no tokens\"}\n";
|
||||
write(client, err, strlen(err));
|
||||
close(client);
|
||||
continue;
|
||||
}
|
||||
|
||||
int out_ids[4096];
|
||||
double p_tps, d_tps;
|
||||
int n_out = generate(prompt_ids, n_prompt, max_gen, out_ids, 4096, &p_tps, &d_tps);
|
||||
|
||||
// Build JSON response
|
||||
char resp[131072];
|
||||
int off = snprintf(resp, sizeof(resp),
|
||||
"{\"output\": [");
|
||||
for (int i = 0; i < n_out; i++)
|
||||
off += snprintf(resp + off, sizeof(resp) - off,
|
||||
"%s%d", i ? ", " : "", out_ids[i]);
|
||||
off += snprintf(resp + off, sizeof(resp) - off,
|
||||
"], \"prefill_tps\": %.1f, \"decode_tps\": %.1f, "
|
||||
"\"prompt_tokens\": %d, \"gen_tokens\": %d}\n",
|
||||
p_tps, d_tps, n_prompt, n_out);
|
||||
|
||||
write(client, resp, off);
|
||||
close(client);
|
||||
|
||||
printf("[socket] prompt=%d gen=%d prefill=%.1f decode=%.1f t/s\n",
|
||||
n_prompt, n_out, p_tps, d_tps);
|
||||
fflush(stdout);
|
||||
|
||||
qwen_reset(&g_model);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
@autoreleasepool {
|
||||
if (argc < 3) {
|
||||
fprintf(stderr, "Usage: %s <weights.bin> <prompt>\n", argv[0]);
|
||||
if (argc < 2) {
|
||||
fprintf(stderr,
|
||||
"Usage:\n"
|
||||
" %s <weights.bin> \"token_ids\" [max_tokens] (single-shot)\n"
|
||||
" %s <weights.bin> --server (stdin loop)\n"
|
||||
" %s <weights.bin> --server /tmp/qwen_ane.sock (socket server)\n",
|
||||
argv[0], argv[0], argv[0]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("=== Qwen2.5-0.5B ANE Inference ===\n\n");
|
||||
|
||||
// Load weights
|
||||
setbuf(stdout, NULL);
|
||||
|
||||
printf("Loading weights...\n");
|
||||
if (load_weights(argv[1]) != 0) return 1;
|
||||
|
||||
// Allocate buffers
|
||||
qwen_alloc(&g_model);
|
||||
|
||||
// Compile ANE kernels
|
||||
printf("Compiling ANE kernels (169 total)...\n");
|
||||
struct timespec t0, t1;
|
||||
clock_gettime(CLOCK_MONOTONIC, &t0);
|
||||
qwen_compile_kernels(&g_model);
|
||||
clock_gettime(CLOCK_MONOTONIC, &t1);
|
||||
double compile_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9;
|
||||
double compile_sec = timespec_diff(&t0, &t1);
|
||||
printf("Compile time: %.1fs\n\n", compile_sec);
|
||||
|
||||
// Parse token IDs from argv[2] (space-separated)
|
||||
// argv[3] = max generation tokens
|
||||
int max_gen = 50;
|
||||
if (argc >= 4) max_gen = atoi(argv[3]);
|
||||
|
||||
// Parse input token IDs
|
||||
int prompt_ids[2048];
|
||||
int n_prompt = 0;
|
||||
char *tok_str = strdup(argv[2]);
|
||||
char *saveptr;
|
||||
char *p = strtok_r(tok_str, " ", &saveptr);
|
||||
while (p && n_prompt < 2048) {
|
||||
prompt_ids[n_prompt++] = atoi(p);
|
||||
p = strtok_r(NULL, " ", &saveptr);
|
||||
// Check for --server flag
|
||||
int server_mode = 0;
|
||||
const char *sock_path = NULL;
|
||||
for (int i = 2; i < argc; i++) {
|
||||
if (strcmp(argv[i], "--server") == 0) {
|
||||
server_mode = 1;
|
||||
if (i + 1 < argc && argv[i+1][0] != '-')
|
||||
sock_path = argv[++i];
|
||||
}
|
||||
}
|
||||
free(tok_str);
|
||||
|
||||
if (server_mode) {
|
||||
if (sock_path)
|
||||
run_socket_server(sock_path);
|
||||
else
|
||||
run_stdin_server();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Single-shot mode (original behavior)
|
||||
if (argc < 3) {
|
||||
fprintf(stderr, "Error: provide token IDs or --server\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
int max_gen = 50;
|
||||
if (argc >= 4 && strcmp(argv[3], "--server") != 0)
|
||||
max_gen = atoi(argv[3]);
|
||||
|
||||
int prompt_ids[2048];
|
||||
int n_prompt = parse_tokens(argv[2], prompt_ids, 2048);
|
||||
printf("Prompt: %d tokens, generating up to %d\n", n_prompt, max_gen);
|
||||
|
||||
clock_gettime(CLOCK_MONOTONIC, &t0);
|
||||
int out_ids[4096];
|
||||
double p_tps, d_tps;
|
||||
int n_out = generate(prompt_ids, n_prompt, max_gen, out_ids, 4096, &p_tps, &d_tps);
|
||||
|
||||
// Prefill: feed all prompt tokens
|
||||
int next = 0;
|
||||
for (int i = 0; i < n_prompt; i++) {
|
||||
next = qwen_forward(&g_model, prompt_ids[i]);
|
||||
}
|
||||
|
||||
struct timespec t_prefill;
|
||||
clock_gettime(CLOCK_MONOTONIC, &t_prefill);
|
||||
double prefill_sec = (t_prefill.tv_sec - t0.tv_sec) + (t_prefill.tv_nsec - t0.tv_nsec) / 1e9;
|
||||
printf("Prefill: %d tokens in %.2fs (%.1f t/s)\n", n_prompt, prefill_sec, n_prompt / prefill_sec);
|
||||
|
||||
// Generate
|
||||
int eos = 151645; // <|im_end|>
|
||||
int eos2 = 151643; // <|endoftext|>
|
||||
printf("OUT:");
|
||||
for (int i = 0; i < max_gen; i++) {
|
||||
printf(" %d", next);
|
||||
fflush(stdout);
|
||||
if (next == eos || next == eos2) break;
|
||||
next = qwen_forward(&g_model, next);
|
||||
}
|
||||
for (int i = 0; i < n_out; i++) printf(" %d", out_ids[i]);
|
||||
printf("\n");
|
||||
|
||||
clock_gettime(CLOCK_MONOTONIC, &t1);
|
||||
double gen_sec = (t1.tv_sec - t0.tv_sec) + (t1.tv_nsec - t0.tv_nsec) / 1e9;
|
||||
int total_tokens = g_model.pos;
|
||||
int gen_tokens = total_tokens - n_prompt;
|
||||
double decode_sec = gen_sec - prefill_sec;
|
||||
printf("\nTotal: %d tokens in %.2fs\n", total_tokens, gen_sec);
|
||||
printf("Prefill: %.1f t/s (%d tokens)\n", n_prompt / prefill_sec, n_prompt);
|
||||
printf("Decode: %.1f t/s (%d tokens)\n",
|
||||
decode_sec > 0 ? gen_tokens / decode_sec : 0, gen_tokens);
|
||||
printf("\nPrefill: %.1f t/s (%d tokens)\n", p_tps, n_prompt);
|
||||
printf("Decode: %.1f t/s (%d tokens)\n", d_tps, n_out > 1 ? n_out - 1 : 0);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -433,3 +433,11 @@ static void qwen_alloc(QwenModel *m) {
|
|||
}
|
||||
m->pos = 0;
|
||||
}
|
||||
|
||||
static void qwen_reset(QwenModel *m) {
|
||||
for (int l = 0; l < QWEN_LAYERS; l++) {
|
||||
memset(m->kv_cache_k[l], 0, QWEN_MAX_SEQ * QWEN_KV_DIM * sizeof(float));
|
||||
memset(m->kv_cache_v[l], 0, QWEN_MAX_SEQ * QWEN_KV_DIM * sizeof(float));
|
||||
}
|
||||
m->pos = 0;
|
||||
}
|
||||
|
|
|
|||
121
inference/run.py
121
inference/run.py
|
|
@ -1,12 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Run Qwen2.5-0.5B on ANE with proper tokenization.
|
||||
|
||||
Auto-connects to a running socket server for instant responses (~0ms startup).
|
||||
Falls back to subprocess mode if no server is running (~6s startup per call).
|
||||
|
||||
Usage:
|
||||
python3 run.py "Your prompt here" [--max-tokens 50]
|
||||
|
||||
Server mode (start server first in another terminal):
|
||||
./qwen_ane qwen05b.bin --server /tmp/qwen_ane.sock
|
||||
python3 run.py "Your prompt here"
|
||||
"""
|
||||
import argparse
|
||||
import ctypes
|
||||
import struct
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
|
@ -14,12 +23,66 @@ from pathlib import Path
|
|||
INFERENCE_DIR = Path(__file__).parent
|
||||
WEIGHTS_PATH = INFERENCE_DIR / "qwen05b.bin"
|
||||
MODEL_DIR = Path.home() / "models" / "Qwen2.5-0.5B-Instruct"
|
||||
DEFAULT_SOCK = "/tmp/qwen_ane.sock"
|
||||
|
||||
|
||||
def query_socket(token_ids: list[int], max_tokens: int, sock_path: str = DEFAULT_SOCK) -> dict | None:
|
||||
"""Send a request to the socket server. Returns parsed JSON or None on failure."""
|
||||
try:
|
||||
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
s.settimeout(120)
|
||||
s.connect(sock_path)
|
||||
req = json.dumps({"tokens": token_ids, "max_tokens": max_tokens}) + "\n"
|
||||
s.sendall(req.encode())
|
||||
|
||||
data = b""
|
||||
while True:
|
||||
chunk = s.recv(131072)
|
||||
if not chunk:
|
||||
break
|
||||
data += chunk
|
||||
if b"\n" in data:
|
||||
break
|
||||
s.close()
|
||||
return json.loads(data.decode().strip())
|
||||
except (ConnectionRefusedError, FileNotFoundError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def query_subprocess(token_ids: list[int], max_tokens: int) -> dict | None:
|
||||
"""Fall back to spawning the binary as a subprocess."""
|
||||
binary = str(INFERENCE_DIR / "qwen_ane")
|
||||
if not os.path.exists(binary):
|
||||
print(f"Binary not found: {binary}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
result = subprocess.run(
|
||||
[binary, str(WEIGHTS_PATH),
|
||||
" ".join(str(t) for t in token_ids),
|
||||
str(max_tokens)],
|
||||
capture_output=True, text=True, timeout=120,
|
||||
)
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print(result.stderr[:500], file=sys.stderr)
|
||||
|
||||
output_ids = []
|
||||
for line in result.stdout.split("\n"):
|
||||
if line.startswith("OUT:"):
|
||||
ids = [int(x) for x in line[4:].split() if x.lstrip("-").isdigit()]
|
||||
output_ids.extend(ids)
|
||||
|
||||
return {"output": output_ids} if output_ids else None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser(description="Qwen2.5-0.5B ANE inference")
|
||||
parser.add_argument("prompt", type=str)
|
||||
parser.add_argument("--max-tokens", type=int, default=50)
|
||||
parser.add_argument("--no-server", action="store_true",
|
||||
help="Force subprocess mode even if server is running")
|
||||
parser.add_argument("--sock", type=str, default=DEFAULT_SOCK,
|
||||
help="Socket path for server mode")
|
||||
args = parser.parse_args()
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
|
@ -27,47 +90,43 @@ def main():
|
|||
print("Loading tokenizer...")
|
||||
tok = AutoTokenizer.from_pretrained(str(MODEL_DIR), trust_remote_code=True)
|
||||
|
||||
# Build chat template
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant. Be concise."},
|
||||
{"role": "user", "content": args.prompt},
|
||||
]
|
||||
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
input_ids = tok.encode(text)
|
||||
print(f"Prompt tokens: {len(input_ids)}")
|
||||
print(f"Prompt: {len(input_ids)} tokens")
|
||||
|
||||
# Run the C binary — pass token IDs as arguments
|
||||
import subprocess
|
||||
binary = str(INFERENCE_DIR / "qwen_ane")
|
||||
# Try socket server first (instant response)
|
||||
result = None
|
||||
if not args.no_server and os.path.exists(args.sock):
|
||||
print(f"Connecting to server at {args.sock}...")
|
||||
t0 = time.time()
|
||||
result = query_socket(input_ids, args.max_tokens, args.sock)
|
||||
elapsed = time.time() - t0
|
||||
if result:
|
||||
print(f"Server responded in {elapsed:.3f}s")
|
||||
else:
|
||||
print("Server not responding, falling back to subprocess...")
|
||||
|
||||
# We need to modify the binary to accept token IDs as input
|
||||
# For now, print the token IDs so we can verify tokenization
|
||||
print(f"First 10 tokens: {input_ids[:10]}")
|
||||
print(f"Token text: {[tok.decode([t]) for t in input_ids[:10]]}")
|
||||
print(f"\nRunning ANE inference with {len(input_ids)} prompt tokens + {args.max_tokens} generation...")
|
||||
# Fall back to subprocess
|
||||
if result is None:
|
||||
print("Running inference (subprocess mode, ~6s startup)...")
|
||||
result = query_subprocess(input_ids, args.max_tokens)
|
||||
|
||||
# Call binary with token IDs piped via stdin
|
||||
result = subprocess.run(
|
||||
[binary, str(WEIGHTS_PATH), " ".join(str(t) for t in input_ids),
|
||||
str(args.max_tokens)],
|
||||
capture_output=True, text=True, timeout=120,
|
||||
)
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print(result.stderr[:500], file=sys.stderr)
|
||||
|
||||
# Parse output token IDs from binary stdout
|
||||
output_ids = []
|
||||
for line in result.stdout.split("\n"):
|
||||
if line.startswith("OUT:"):
|
||||
ids = [int(x) for x in line[4:].split() if x.isdigit()]
|
||||
output_ids.extend(ids)
|
||||
if not result or "output" not in result:
|
||||
print("(No output received)", file=sys.stderr)
|
||||
return
|
||||
|
||||
output_ids = result["output"]
|
||||
if output_ids:
|
||||
decoded = tok.decode(output_ids, skip_special_tokens=True)
|
||||
print(f"\n=== Response ===\n{decoded}")
|
||||
else:
|
||||
print("\n(No output tokens parsed — binary may need token ID input mode)")
|
||||
|
||||
if "prefill_tps" in result:
|
||||
print(f"\nPrefill: {result['prefill_tps']:.1f} t/s | "
|
||||
f"Decode: {result['decode_tps']:.1f} t/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue