ANE/training/test_rmsnorm_bwd.m

124 lines
4.6 KiB
Objective-C

// test_rmsnorm_bwd.m — Test RMSNorm backward ANE kernel vs CPU reference
// Build: xcrun clang -O2 -framework Foundation -framework IOSurface \
// -framework CoreML -framework Accelerate -ldl -lobjc \
// -o test_rmsnorm_bwd test_rmsnorm_bwd.m
#include "ane_rmsnorm_bwd.h"
#include "stories_cpu_ops.h"
int main(void) {
@autoreleasepool {
setbuf(stdout, NULL);
ane_init();
mach_timebase_info(&g_tb);
printf("=== Test: RMSNorm Backward on ANE ===\n");
printf("DIM=%d SEQ=%d\n\n", DIM, SEQ);
// Allocate test data
float *x = (float*)malloc(DIM * SEQ * 4);
float *dy = (float*)malloc(DIM * SEQ * 4);
float *w = (float*)malloc(DIM * 4);
float *dx_cpu = (float*)calloc(DIM * SEQ, 4);
float *dw_cpu = (float*)calloc(DIM, 4);
float *dx_ane = (float*)malloc(DIM * SEQ * 4);
// Random init (channel-first [DIM, SEQ])
srand48(42);
for (int i = 0; i < DIM * SEQ; i++) {
x[i] = (float)(drand48() * 2 - 1) * 0.5f;
dy[i] = (float)(drand48() * 2 - 1) * 0.1f;
}
for (int i = 0; i < DIM; i++) {
w[i] = (float)(drand48() * 0.5 + 0.75); // close to 1.0
}
// === CPU Reference ===
uint64_t t0 = mach_absolute_time();
rmsnorm_bwd(dx_cpu, dw_cpu, dy, x, w, DIM, SEQ);
uint64_t t1 = mach_absolute_time();
printf("CPU rmsnorm_bwd: %.2f ms\n", tb_ms(t1 - t0));
// === ANE Kernel ===
printf("Compiling ANE rmsnorm_bwd kernel...\n");
NSString *mil = gen_rmsnorm_bwd();
// Build weight blob for RMSNorm weights
NSData *rms_blob = build_blob(w, 1, DIM);
int in_bytes = 2 * DIM * SEQ * 2; // concat(dy, x) in fp16
int out_bytes = DIM * SEQ * 2; // dx in fp16
Kern *kern = compile_kern_mil_w(mil, (@{
@"@model_path/weights/rms_w.bin": @{@"offset":@0, @"data":rms_blob},
}), in_bytes, out_bytes);
if (!kern) {
printf("FAIL: ANE kernel compilation failed!\n");
return 1;
}
printf("Compile OK (compiles=%d)\n", g_compile_count);
// Write input: concat(dy, x) into ioIn
// dy goes at channel offset 0, x goes at channel offset DIM
io_write_fp16_at(kern->ioIn, 0, dy, DIM, SEQ);
io_write_fp16_at(kern->ioIn, DIM, x, DIM, SEQ);
// Evaluate
t0 = mach_absolute_time();
ane_eval(kern);
t1 = mach_absolute_time();
printf("ANE eval: %.3f ms\n", tb_ms(t1 - t0));
// Read output
io_read_fp16(kern->ioOut, dx_ane, 0, DIM, SEQ);
// === Compare ===
float max_err = 0, sum_err = 0;
int max_i = 0, max_j = 0;
for (int i = 0; i < DIM; i++) {
for (int j = 0; j < SEQ; j++) {
int idx = i * SEQ + j;
float err = fabsf(dx_cpu[idx] - dx_ane[idx]);
sum_err += err;
if (err > max_err) {
max_err = err;
max_i = i; max_j = j;
}
}
}
float mean_err = sum_err / (DIM * SEQ);
printf("\n=== Results ===\n");
printf("Max absolute error: %.6f at [%d,%d] (CPU=%.6f ANE=%.6f)\n",
max_err, max_i, max_j, dx_cpu[max_i*SEQ+max_j], dx_ane[max_i*SEQ+max_j]);
printf("Mean absolute error: %.6f\n", mean_err);
// Sample outputs
printf("\nSample dx values (first 4 channels, first 4 positions):\n");
printf("%-6s %-12s %-12s %-10s\n", "Idx", "CPU", "ANE", "Error");
for (int i = 0; i < 4 && i < DIM; i++) {
for (int j = 0; j < 4 && j < SEQ; j++) {
int idx = i * SEQ + j;
printf("[%d,%d] %-12.6f %-12.6f %-10.6f\n",
i, j, dx_cpu[idx], dx_ane[idx], fabsf(dx_cpu[idx] - dx_ane[idx]));
}
}
// Benchmark: multiple evals
int N = 100;
t0 = mach_absolute_time();
for (int i = 0; i < N; i++) ane_eval(kern);
t1 = mach_absolute_time();
printf("\nBenchmark: %d evals in %.2f ms (%.3f ms/eval)\n",
N, tb_ms(t1-t0), tb_ms(t1-t0)/N);
// Pass/fail
bool pass = max_err < 0.05f && mean_err < 0.01f;
printf("\n%s (threshold: max<0.05, mean<0.01)\n", pass ? "PASS ✅" : "FAIL ❌");
free_kern(kern);
free(x); free(dy); free(w); free(dx_cpu); free(dw_cpu); free(dx_ane);
return pass ? 0 : 1;
}
}