mirror of https://github.com/maderix/ANE.git
124 lines
4.6 KiB
Objective-C
124 lines
4.6 KiB
Objective-C
// test_rmsnorm_bwd.m — Test RMSNorm backward ANE kernel vs CPU reference
|
|
// Build: xcrun clang -O2 -framework Foundation -framework IOSurface \
|
|
// -framework CoreML -framework Accelerate -ldl -lobjc \
|
|
// -o test_rmsnorm_bwd test_rmsnorm_bwd.m
|
|
#include <math.h>

#include "ane_rmsnorm_bwd.h"
#include "stories_cpu_ops.h"
|
|
|
|
/// Releases every host buffer allocated by main(). free(NULL) is a no-op, so
/// this is safe on a partially-allocated set and on every early-exit path.
static void rmsnorm_test_free_buffers(float *x, float *dy, float *w,
                                      float *dx_cpu, float *dw_cpu, float *dx_ane) {
    free(x);
    free(dy);
    free(w);
    free(dx_cpu);
    free(dw_cpu);
    free(dx_ane);
}

/// Compares the ANE RMSNorm-backward kernel's dx output against the CPU
/// reference rmsnorm_bwd() on seeded random data. Prints error statistics,
/// sample values and an eval benchmark.
///
/// @return 0 on PASS; 1 on FAIL, allocation failure, or kernel compile failure.
///
/// Only dx is checked: the kernel's output surface is sized for DIM*SEQ fp16
/// values (dx), so the CPU reference's dw has no ANE counterpart here.
int main(void) {
    @autoreleasepool {
        setbuf(stdout, NULL);
        ane_init();
        mach_timebase_info(&g_tb);

        printf("=== Test: RMSNorm Backward on ANE ===\n");
        printf("DIM=%d SEQ=%d\n\n", DIM, SEQ);

        // Allocate test data (channel-first [DIM, SEQ]). Sizes are computed in
        // size_t with sizeof(float) rather than int arithmetic with a literal 4.
        const size_t n_act = (size_t)DIM * (size_t)SEQ;
        float *x      = (float *)malloc(n_act * sizeof(float));
        float *dy     = (float *)malloc(n_act * sizeof(float));
        float *w      = (float *)malloc((size_t)DIM * sizeof(float));
        float *dx_cpu = (float *)calloc(n_act, sizeof(float));
        float *dw_cpu = (float *)calloc((size_t)DIM, sizeof(float));
        float *dx_ane = (float *)malloc(n_act * sizeof(float));
        if (!x || !dy || !w || !dx_cpu || !dw_cpu || !dx_ane) {
            printf("FAIL: host buffer allocation failed!\n");
            rmsnorm_test_free_buffers(x, dy, w, dx_cpu, dw_cpu, dx_ane);
            return 1;
        }

        // Random init with a fixed seed. x and dy are drawn interleaved per
        // element, so the draw order (and thus the test data) depends on this
        // exact loop shape.
        srand48(42);
        for (size_t i = 0; i < n_act; i++) {
            x[i]  = (float)(drand48() * 2 - 1) * 0.5f;
            dy[i] = (float)(drand48() * 2 - 1) * 0.1f;
        }
        for (int i = 0; i < DIM; i++) {
            w[i] = (float)(drand48() * 0.5 + 0.75);  // in [0.75, 1.25): close to 1.0
        }

        // === CPU Reference ===
        uint64_t t0 = mach_absolute_time();
        rmsnorm_bwd(dx_cpu, dw_cpu, dy, x, w, DIM, SEQ);
        uint64_t t1 = mach_absolute_time();
        printf("CPU rmsnorm_bwd: %.2f ms\n", tb_ms(t1 - t0));

        // === ANE Kernel ===
        printf("Compiling ANE rmsnorm_bwd kernel...\n");
        NSString *mil = gen_rmsnorm_bwd();

        // Build weight blob for RMSNorm weights
        NSData *rms_blob = build_blob(w, 1, DIM);

        int in_bytes = 2 * DIM * SEQ * 2;  // concat(dy, x) in fp16
        int out_bytes = DIM * SEQ * 2;     // dx in fp16

        Kern *kern = compile_kern_mil_w(mil, (@{
            @"@model_path/weights/rms_w.bin": @{@"offset":@0, @"data":rms_blob},
        }), in_bytes, out_bytes);

        if (!kern) {
            printf("FAIL: ANE kernel compilation failed!\n");
            // Release host buffers on this early exit as well.
            rmsnorm_test_free_buffers(x, dy, w, dx_cpu, dw_cpu, dx_ane);
            return 1;
        }
        printf("Compile OK (compiles=%d)\n", g_compile_count);

        // Write input: concat(dy, x) into ioIn
        // dy goes at channel offset 0, x goes at channel offset DIM
        io_write_fp16_at(kern->ioIn, 0, dy, DIM, SEQ);
        io_write_fp16_at(kern->ioIn, DIM, x, DIM, SEQ);

        // Evaluate
        t0 = mach_absolute_time();
        ane_eval(kern);
        t1 = mach_absolute_time();
        printf("ANE eval: %.3f ms\n", tb_ms(t1 - t0));

        // Read output
        io_read_fp16(kern->ioOut, dx_ane, 0, DIM, SEQ);

        // === Compare ===
        // Non-finite differences (NaN/Inf in either result) are counted
        // separately: a NaN err compares false against max_err, so without this
        // it would be invisible in the max statistic and its location.
        float max_err = 0;
        double sum_err = 0;  // double: summing DIM*SEQ small floats in float loses precision
        int max_i = 0, max_j = 0;
        int n_nonfinite = 0;
        for (int i = 0; i < DIM; i++) {
            for (int j = 0; j < SEQ; j++) {
                int idx = i * SEQ + j;
                float err = fabsf(dx_cpu[idx] - dx_ane[idx]);
                if (!isfinite(err)) {
                    n_nonfinite++;
                    continue;
                }
                sum_err += err;
                if (err > max_err) {
                    max_err = err;
                    max_i = i;
                    max_j = j;
                }
            }
        }
        float mean_err = (float)(sum_err / (double)n_act);

        printf("\n=== Results ===\n");
        printf("Max absolute error: %.6f at [%d,%d] (CPU=%.6f ANE=%.6f)\n",
               max_err, max_i, max_j, dx_cpu[max_i*SEQ+max_j], dx_ane[max_i*SEQ+max_j]);
        printf("Mean absolute error: %.6f\n", mean_err);
        if (n_nonfinite > 0) {
            printf("Non-finite errors (NaN/Inf): %d of %zu\n", n_nonfinite, n_act);
        }

        // Sample outputs
        printf("\nSample dx values (first 4 channels, first 4 positions):\n");
        printf("%-6s %-12s %-12s %-10s\n", "Idx", "CPU", "ANE", "Error");
        for (int i = 0; i < 4 && i < DIM; i++) {
            for (int j = 0; j < 4 && j < SEQ; j++) {
                int idx = i * SEQ + j;
                printf("[%d,%d] %-12.6f %-12.6f %-10.6f\n",
                       i, j, dx_cpu[idx], dx_ane[idx], fabsf(dx_cpu[idx] - dx_ane[idx]));
            }
        }

        // Benchmark: multiple evals
        int N = 100;
        t0 = mach_absolute_time();
        for (int i = 0; i < N; i++) ane_eval(kern);
        t1 = mach_absolute_time();
        printf("\nBenchmark: %d evals in %.2f ms (%.3f ms/eval)\n",
               N, tb_ms(t1-t0), tb_ms(t1-t0)/N);

        // Pass/fail: any non-finite difference is an automatic failure.
        bool pass = n_nonfinite == 0 && max_err < 0.05f && mean_err < 0.01f;
        printf("\n%s (threshold: max<0.05, mean<0.01)\n", pass ? "PASS ✅" : "FAIL ❌");

        free_kern(kern);
        rmsnorm_test_free_buffers(x, dy, w, dx_cpu, dw_cpu, dx_ane);
        return pass ? 0 : 1;
    }
}
|