ANE/training/test_classifier.m

256 lines
12 KiB
Matlab

// test_classifier.mTest classifier matmul (32000 channels) and softmax on ANE
// This tests the riskiest operations: VOCAB-sized conv and softmax
// Build: xcrun clang -O2 -framework Foundation -framework IOSurface \
// -framework CoreML -framework Accelerate -ldl -lobjc \
// -o test_classifier test_classifier.m
#include "ane_classifier.h"
#include "stories_cpu_ops.h"
int main(void) {
@autoreleasepool {
setbuf(stdout, NULL);
ane_init();
mach_timebase_info(&g_tb);
printf("=== Test: Classifier + Softmax on ANE ===\n");
printf("DIM=%d SEQ=%d VOCAB=%d\n\n", DIM, SEQ, VOCAB);
// ======== Test 1: Final RMSNorm ========
printf("--- Test 1: Final RMSNorm on ANE ---\n");
{
float *x = (float*)malloc(DIM * SEQ * 4);
float *w = (float*)malloc(DIM * 4);
float *out_cpu = (float*)malloc(DIM * SEQ * 4);
float *out_ane = (float*)malloc(DIM * SEQ * 4);
srand48(42);
for (int i = 0; i < DIM * SEQ; i++) x[i] = (float)(drand48() * 2 - 1);
for (int i = 0; i < DIM; i++) w[i] = (float)(drand48() * 0.5 + 0.75);
rmsnorm(out_cpu, x, w, DIM, SEQ);
Kern *kern = compile_kern_mil_w(gen_final_rmsnorm(), (@{
@"@model_path/weights/rms_w.bin": @{@"offset":@0, @"data":build_blob(w, 1, DIM)},
}), DIM*SEQ*2, DIM*SEQ*2);
if (!kern) { printf("FAIL: Final RMSNorm compile failed\n"); return 1; }
printf("Compile OK\n");
io_write_fp16(kern->ioIn, x, DIM, SEQ);
ane_eval(kern);
io_read_fp16(kern->ioOut, out_ane, 0, DIM, SEQ);
float max_err = 0;
for (int i = 0; i < DIM*SEQ; i++) {
float e = fabsf(out_cpu[i] - out_ane[i]);
if (e > max_err) max_err = e;
}
printf("Max error: %.6f %s\n\n", max_err, max_err < 0.05 ? "PASS ✅" : "FAIL ❌");
free_kern(kern);
free(x); free(w); free(out_cpu); free(out_ane);
}
// ======== Test 2: Classifier forward (32000-channel conv) ========
printf("--- Test 2: Classifier Forward (VOCAB=%d channel conv) ---\n", VOCAB);
{
float *x_final = (float*)malloc(DIM * SEQ * 4);
float *embed = (float*)malloc((size_t)VOCAB * DIM * 4);
float *logits_cpu = (float*)malloc((size_t)VOCAB * SEQ * 4);
float *logits_ane = (float*)malloc((size_t)VOCAB * SEQ * 4);
srand48(123);
for (int i = 0; i < DIM * SEQ; i++) x_final[i] = (float)(drand48() * 2 - 1) * 0.1f;
for (size_t i = 0; i < (size_t)VOCAB * DIM; i++) embed[i] = (float)(drand48() * 2 - 1) * 0.02f;
// CPU reference: logits = embed @ x_final
// logits[v, t] = sum_d embed[v,d] * x_final[d,t]
// embed is [VOCAB, DIM] row-major, x_final is [DIM, SEQ] channel-first
uint64_t t0 = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
VOCAB, SEQ, DIM, 1.0f,
embed, DIM, x_final, SEQ, 0.0f, logits_cpu, SEQ);
uint64_t t1 = mach_absolute_time();
printf("CPU cblas_sgemm: %.2f ms\n", tb_ms(t1-t0));
// ANE: build weight blob for embed [VOCAB, DIM]
printf("Building embed blob (%.1f MB fp16)...\n", (float)VOCAB*DIM*2/1e6);
NSData *embed_blob = build_blob(embed, VOCAB, DIM);
printf("Compiling classifier kernel...\n");
t0 = mach_absolute_time();
Kern *cls = compile_kern_mil_w(gen_classifier_fwd(), (@{
@"@model_path/weights/embed.bin": @{@"offset":@0, @"data":embed_blob},
}), DIM*SEQ*2, VOCAB*SEQ*2);
t1 = mach_absolute_time();
if (!cls) {
printf("FAIL: Classifier compile failed (32000 channels too large for ANE)\n");
printf("This confirms tiling is needed.\n\n");
} else {
printf("Compile OK in %.0f ms (compiles=%d)\n", tb_ms(t1-t0), g_compile_count);
io_write_fp16(cls->ioIn, x_final, DIM, SEQ);
t0 = mach_absolute_time();
ane_eval(cls);
t1 = mach_absolute_time();
printf("ANE eval: %.2f ms\n", tb_ms(t1-t0));
// Read back and compare (samplefull read would be 32000*256*4 = 32MB)
io_read_fp16(cls->ioOut, logits_ane, 0, VOCAB, SEQ);
float max_err = 0, sum_err = 0;
int cnt = 0;
for (int v = 0; v < VOCAB; v++) {
for (int t = 0; t < SEQ; t++) {
int idx = v*SEQ + t;
float e = fabsf(logits_cpu[idx] - logits_ane[idx]);
sum_err += e;
cnt++;
if (e > max_err) max_err = e;
}
}
printf("Max error: %.6f Mean error: %.6f %s\n",
max_err, sum_err/cnt, max_err < 1.0 ? "PASS ✅" : "FAIL ❌");
// Benchmark
int N = 10;
t0 = mach_absolute_time();
for (int i = 0; i < N; i++) ane_eval(cls);
t1 = mach_absolute_time();
printf("Benchmark: %d evals in %.2f ms (%.2f ms/eval)\n\n", N, tb_ms(t1-t0), tb_ms(t1-t0)/N);
free_kern(cls);
}
free(x_final); free(embed); free(logits_cpu); free(logits_ane);
}
// ======== Test 3: Softmax over VOCAB dimension ========
printf("--- Test 3: Softmax over VOCAB=%d ---\n", VOCAB);
{
float *logits = (float*)malloc((size_t)VOCAB * SEQ * 4);
float *probs_cpu = (float*)malloc((size_t)VOCAB * SEQ * 4);
float *probs_ane = (float*)malloc((size_t)VOCAB * SEQ * 4);
srand48(999);
for (size_t i = 0; i < (size_t)VOCAB * SEQ; i++)
logits[i] = (float)(drand48() * 10 - 5);
// CPU reference softmax (per position, over vocab)
// logits is [VOCAB, SEQ] channel-first
uint64_t t0 = mach_absolute_time();
for (int t = 0; t < SEQ; t++) {
float maxv = -1e30f;
for (int v = 0; v < VOCAB; v++) {
float val = logits[v*SEQ+t];
if (val > maxv) maxv = val;
}
float sum = 0;
for (int v = 0; v < VOCAB; v++) {
probs_cpu[v*SEQ+t] = expf(logits[v*SEQ+t] - maxv);
sum += probs_cpu[v*SEQ+t];
}
for (int v = 0; v < VOCAB; v++) probs_cpu[v*SEQ+t] /= sum;
}
uint64_t t1 = mach_absolute_time();
printf("CPU softmax: %.2f ms\n", tb_ms(t1-t0));
printf("Compiling softmax kernel...\n");
int sm_bytes = VOCAB * SEQ * 2;
Kern *sm = compile_kern_mil_w(gen_softmax_vocab(), @{}, sm_bytes, sm_bytes);
if (!sm) {
printf("FAIL: Softmax compile failed\n\n");
} else {
printf("Compile OK\n");
io_write_fp16(sm->ioIn, logits, VOCAB, SEQ);
t0 = mach_absolute_time();
ane_eval(sm);
t1 = mach_absolute_time();
printf("ANE eval: %.2f ms\n", tb_ms(t1-t0));
io_read_fp16(sm->ioOut, probs_ane, 0, VOCAB, SEQ);
// Check: probs should sum to ~1.0 per position
float max_err = 0;
for (int t = 0; t < 4; t++) {
float sum_cpu = 0, sum_ane = 0;
for (int v = 0; v < VOCAB; v++) {
sum_cpu += probs_cpu[v*SEQ+t];
sum_ane += probs_ane[v*SEQ+t];
float e = fabsf(probs_cpu[v*SEQ+t] - probs_ane[v*SEQ+t]);
if (e > max_err) max_err = e;
}
printf(" pos %d: CPU sum=%.4f ANE sum=%.4f\n", t, sum_cpu, sum_ane);
}
printf("Max error (first 4 positions): %.6f %s\n",
max_err, max_err < 0.01 ? "PASS ✅" : "FAIL ❌");
int N = 10;
t0 = mach_absolute_time();
for (int i = 0; i < N; i++) ane_eval(sm);
t1 = mach_absolute_time();
printf("Benchmark: %d evals in %.2f ms (%.2f ms/eval)\n\n", N, tb_ms(t1-t0), tb_ms(t1-t0)/N);
free_kern(sm);
}
free(logits); free(probs_cpu); free(probs_ane);
}
// ======== Test 4: Classifier backward ========
printf("--- Test 4: Classifier Backward (DIM=%d from VOCAB=%d) ---\n", DIM, VOCAB);
{
float *dlogits = (float*)malloc((size_t)VOCAB * SEQ * 4);
float *embed = (float*)malloc((size_t)VOCAB * DIM * 4);
float *dx_cpu = (float*)malloc(DIM * SEQ * 4);
float *dx_ane = (float*)malloc(DIM * SEQ * 4);
srand48(456);
for (size_t i = 0; i < (size_t)VOCAB * SEQ; i++) dlogits[i] = (float)(drand48() * 2 - 1) * 0.01f;
for (size_t i = 0; i < (size_t)VOCAB * DIM; i++) embed[i] = (float)(drand48() * 2 - 1) * 0.02f;
// CPU: dx = embed^T @ dlogits
uint64_t t0 = mach_absolute_time();
cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
DIM, SEQ, VOCAB, 1.0f,
embed, DIM, dlogits, SEQ, 0.0f, dx_cpu, SEQ);
uint64_t t1 = mach_absolute_time();
printf("CPU cblas_sgemm: %.2f ms\n", tb_ms(t1-t0));
// Build transposed embed blob
NSData *embed_t_blob = build_blob_t(embed, VOCAB, DIM);
printf("Compiling classifier backward...\n");
Kern *clsb = compile_kern_mil_w(gen_classifier_bwd(), (@{
@"@model_path/weights/embed_t.bin": @{@"offset":@0, @"data":embed_t_blob},
}), VOCAB*SEQ*2, DIM*SEQ*2);
if (!clsb) {
printf("FAIL: Classifier backward compile failed\n\n");
} else {
printf("Compile OK\n");
io_write_fp16(clsb->ioIn, dlogits, VOCAB, SEQ);
t0 = mach_absolute_time();
ane_eval(clsb);
t1 = mach_absolute_time();
printf("ANE eval: %.2f ms\n", tb_ms(t1-t0));
io_read_fp16(clsb->ioOut, dx_ane, 0, DIM, SEQ);
float max_err = 0, sum_err = 0;
for (int i = 0; i < DIM*SEQ; i++) {
float e = fabsf(dx_cpu[i] - dx_ane[i]);
sum_err += e;
if (e > max_err) max_err = e;
}
printf("Max error: %.6f Mean error: %.6f %s\n\n",
max_err, sum_err/(DIM*SEQ), max_err < 1.0 ? "PASS ✅" : "FAIL ❌");
free_kern(clsb);
}
free(dlogits); free(embed); free(dx_cpu); free(dx_ane);
}
printf("=== All tests complete ===\n");
printf("Total ANE compiles used: %d\n", g_compile_count);
return 0;
}
}