mirror of https://github.com/maderix/ANE.git
wire up fp16 I/O retry in train.m forward path
This commit is contained in:
parent
0cf13e2b84
commit
2d2adacf09
|
|
@@ -9,22 +9,43 @@
|
|||
// Transpose back to [S, out_dim] row-major
|
||||
// Evaluate one compiled conv kernel on x and store the result in y.
//
// x is read as x[t*in_dim + i] for t in [0, S) and i in [0, in_dim);
// y is written as y[t*out_dim + i] for t in [0, S) and i in [0, out_dim).
// The staging buffers passed to ane_write_input / ane_read_output are indexed
// as [i*S + t], i.e. transposed relative to x and y, so the input is
// transposed before the write and the output is transposed back after the read.
// When g_fp16_io is non-zero the staging buffers hold _Float16 elements and the
// float <-> _Float16 conversion happens inside the transpose loops; otherwise
// they hold float elements.
//
// NOTE(review): this span looks like a rendered diff in which the pre-change
// and post-change versions of the function are interleaved (statements appear
// twice, once sized with sizeof(float) and once with sizeof(_Float16), and
// free() appears twice in a row). Confirm against the repository before
// treating it as compilable as-is.
// NOTE(review): malloc results are used without a NULL check, and S * in_dim /
// S * out_dim are multiplied in int before being widened by sizeof — flagged
// here, not changed.
static void ane_conv_eval(ANEKernel *kernel, const float *x, float *y,
|
||||
int S, int in_dim, int out_dim) {
|
||||
float *x_t = (float*)malloc(S * in_dim * sizeof(float));
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int i = 0; i < in_dim; i++)
|
||||
x_t[i*S + t] = x[t*in_dim + i];
|
||||
if (g_fp16_io) {
|
||||
// fp16 I/O path: transpose + convert float→fp16, write, eval, read fp16→float + transpose
|
||||
_Float16 *x_t = (_Float16*)malloc(S * in_dim * sizeof(_Float16));
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int i = 0; i < in_dim; i++)
|
||||
// transpose and narrow float -> _Float16 in the same pass
|
||||
x_t[i*S + t] = (_Float16)x[t*in_dim + i];
|
||||
|
||||
ane_write_input(kernel, 0, x_t, S * in_dim * sizeof(float));
|
||||
ane_eval(kernel);
|
||||
// byte count below uses the _Float16 element size, matching the x_t allocation in this branch
|
||||
ane_write_input(kernel, 0, x_t, S * in_dim * sizeof(_Float16));
|
||||
ane_eval(kernel);
|
||||
|
||||
float *y_t = (float*)malloc(S * out_dim * sizeof(float));
|
||||
ane_read_output(kernel, 0, y_t, S * out_dim * sizeof(float));
|
||||
_Float16 *y_t = (_Float16*)malloc(S * out_dim * sizeof(_Float16));
|
||||
ane_read_output(kernel, 0, y_t, S * out_dim * sizeof(_Float16));
|
||||
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int i = 0; i < out_dim; i++)
|
||||
y[t*out_dim + i] = y_t[i*S + t];
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int i = 0; i < out_dim; i++)
|
||||
// undo the transpose and widen _Float16 -> float in the same pass
|
||||
y[t*out_dim + i] = (float)y_t[i*S + t];
|
||||
|
||||
free(x_t); free(y_t);
|
||||
free(x_t); free(y_t);
|
||||
} else {
|
||||
// fp32 I/O path: transpose, write, eval, read, transpose back
|
||||
float *x_t = (float*)malloc(S * in_dim * sizeof(float));
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int i = 0; i < in_dim; i++)
|
||||
x_t[i*S + t] = x[t*in_dim + i];
|
||||
|
||||
ane_write_input(kernel, 0, x_t, S * in_dim * sizeof(float));
|
||||
ane_eval(kernel);
|
||||
|
||||
float *y_t = (float*)malloc(S * out_dim * sizeof(float));
|
||||
ane_read_output(kernel, 0, y_t, S * out_dim * sizeof(float));
|
||||
|
||||
for (int t = 0; t < S; t++)
|
||||
for (int i = 0; i < out_dim; i++)
|
||||
y[t*out_dim + i] = y_t[i*S + t];
|
||||
|
||||
// both staging buffers are local to this call and released before returning
|
||||
free(x_t); free(y_t);
|
||||
}
|
||||
}
|
||||
|
||||
// CPU matmul fallback: y = W @ x, W[out_dim, in_dim], x[S, in_dim] → y[S, out_dim]
|
||||
|
|
|
|||
|
|
@@ -151,8 +151,9 @@ static int model_load_weights(Model *m, const char *path) {
|
|||
// Compile a conv kernel for the given weights.
//
// Builds a weight blob via mil_build_weight_blob(weights, out_ch, in_ch) and
// MIL text via mil_gen_conv(in_ch, out_ch, spatial), then hands both to
// ane_compile together with the input and output buffer sizes in bytes.
// Returns whatever ane_compile returns; callers in this file treat a NULL
// result as a compile failure.
//
// NOTE(review): inBytes / outBytes are each declared twice in this span. That
// looks like the pre-change ("* 4") and post-change ("* bpe") sides of a diff
// shown together — confirm against the repository.
static ANEKernel *compile_conv_kernel(const float *weights, int in_ch, int out_ch, int spatial) {
|
||||
NSData *wb = mil_build_weight_blob(weights, out_ch, in_ch);
|
||||
NSString *mil = mil_gen_conv(in_ch, out_ch, spatial);
|
||||
size_t inBytes = (size_t)in_ch * spatial * 4;
|
||||
size_t outBytes = (size_t)out_ch * spatial * 4;
|
||||
// bytes per element: 2 when g_fp16_io is set, otherwise 4
|
||||
size_t bpe = g_fp16_io ? 2 : 4;
|
||||
size_t inBytes = (size_t)in_ch * spatial * bpe;
|
||||
size_t outBytes = (size_t)out_ch * spatial * bpe;
|
||||
// presumably the two literal 1s are the input and output buffer counts — TODO confirm against ane_compile's declaration
|
||||
return ane_compile([mil dataUsingEncoding:NSUTF8StringEncoding], wb, 1, &inBytes, 1, &outBytes);
|
||||
}
|
||||
|
||||
|
|
@@ -161,9 +162,31 @@ static int model_compile_kernels(Model *m, int seq_len) {
|
|||
m->seq_len = seq_len;
|
||||
int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size;
|
||||
int S = seq_len;
|
||||
printf("Compiling %d ANE conv kernels (S=%d)...\n", N_LAYERS * 7 + 1, S);
|
||||
printf("Compiling %d ANE conv kernels (S=%d, %s I/O)...\n",
|
||||
N_LAYERS * 7 + 1, S, g_fp16_io ? "fp16" : "fp32");
|
||||
|
||||
for (int l = 0; l < N_LAYERS; l++) {
|
||||
// Try first layer as canary — if cast op fails, retry with fp16 I/O
|
||||
m->kern_q[0] = compile_conv_kernel(m->wq[0], d, d, S);
|
||||
if (!m->kern_q[0] && !g_fp16_io) {
|
||||
printf(" Compile failed, retrying with fp16 I/O (M1/M2 fallback)...\n");
|
||||
g_fp16_io = 1;
|
||||
m->kern_q[0] = compile_conv_kernel(m->wq[0], d, d, S);
|
||||
}
|
||||
if (!m->kern_q[0]) { fprintf(stderr, "L0 kern_q fail\n"); return -1; }
|
||||
|
||||
m->kern_k[0] = compile_conv_kernel(m->wk[0], d, d, S);
|
||||
m->kern_v[0] = compile_conv_kernel(m->wv[0], d, d, S);
|
||||
m->kern_o[0] = compile_conv_kernel(m->wo[0], d, d, S);
|
||||
m->kern_w1[0] = compile_conv_kernel(m->w1[0], d, hd, S);
|
||||
m->kern_w2[0] = compile_conv_kernel(m->w2[0], hd, d, S);
|
||||
m->kern_w3[0] = compile_conv_kernel(m->w3[0], d, hd, S);
|
||||
if (!m->kern_k[0] || !m->kern_v[0] || !m->kern_o[0] ||
|
||||
!m->kern_w1[0] || !m->kern_w2[0] || !m->kern_w3[0]) {
|
||||
fprintf(stderr, "L0 compile fail\n"); return -1;
|
||||
}
|
||||
printf(" Layer 0 OK\n");
|
||||
|
||||
for (int l = 1; l < N_LAYERS; l++) {
|
||||
m->kern_q[l] = compile_conv_kernel(m->wq[l], d, d, S);
|
||||
m->kern_k[l] = compile_conv_kernel(m->wk[l], d, d, S);
|
||||
m->kern_v[l] = compile_conv_kernel(m->wv[l], d, d, S);
|
||||
|
|
@@ -171,20 +194,18 @@ static int model_compile_kernels(Model *m, int seq_len) {
|
|||
m->kern_w1[l] = compile_conv_kernel(m->w1[l], d, hd, S);
|
||||
m->kern_w2[l] = compile_conv_kernel(m->w2[l], hd, d, S);
|
||||
m->kern_w3[l] = compile_conv_kernel(m->w3[l], d, hd, S);
|
||||
if (!m->kern_q[l]) { fprintf(stderr, "L%d kern_q fail\n",l); return -1; }
|
||||
if (!m->kern_k[l]) { fprintf(stderr, "L%d kern_k fail\n",l); return -1; }
|
||||
if (!m->kern_v[l]) { fprintf(stderr, "L%d kern_v fail\n",l); return -1; }
|
||||
if (!m->kern_o[l]) { fprintf(stderr, "L%d kern_o fail\n",l); return -1; }
|
||||
if (!m->kern_w1[l]) { fprintf(stderr, "L%d kern_w1 fail\n",l); return -1; }
|
||||
if (!m->kern_w2[l]) { fprintf(stderr, "L%d kern_w2 fail\n",l); return -1; }
|
||||
if (!m->kern_w3[l]) { fprintf(stderr, "L%d kern_w3 fail\n",l); return -1; }
|
||||
if (!m->kern_q[l] || !m->kern_k[l] || !m->kern_v[l] || !m->kern_o[l] ||
|
||||
!m->kern_w1[l] || !m->kern_w2[l] || !m->kern_w3[l]) {
|
||||
fprintf(stderr, "L%d compile fail\n", l); return -1;
|
||||
}
|
||||
printf(" Layer %d OK\n", l);
|
||||
}
|
||||
m->kern_cls = compile_conv_kernel(m->wcls, d, vs, S);
|
||||
if (!m->kern_cls) {
|
||||
fprintf(stderr, "Classifier kernel compile failed (dim=%d→vocab=%d too large?), using CPU for cls\n", d, vs);
|
||||
}
|
||||
printf(" All kernels compiled (%d conv + %s)\n", N_LAYERS * 7, m->kern_cls ? "cls" : "cls=CPU");
|
||||
printf(" All kernels compiled (%d conv + %s, %s I/O)\n",
|
||||
N_LAYERS * 7, m->kern_cls ? "cls" : "cls=CPU", g_fp16_io ? "fp16" : "fp32");
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue