ANE/training/train_double_buffer.m

783 lines
42 KiB
Mathematica
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// train_double_buffer.m — Double-buffered async ANE training for stories110M
// Based on train_large.m with the key innovation: compile and eval overlap via GCD
// Discovery: probe_v2.m proved ANE compile and eval can run in parallel
// Architecture: two kernel sets (A/B), background compile while active set runs
// 5 weight-bearing ANE kernels per layer × 12 layers = 60 per compile batch
#include <stdatomic.h>
#include "stories_io.h"
#include "stories_mil.h"
#include "stories_cpu_ops.h"
// Double-buffer needs more compile budget than single-buffer
// The original MAX_COMPILES=100 only allows 1 batch per exec() restart
// We push higher to allow initial compile + at least 1 background compile
// If ANE rejects at ~119, the exec() restart will handle it gracefully
#define DB_MAX_COMPILES 250
#define CKPT_PATH "ane_db_ckpt.bin"
#define MODEL_PATH "../../assets/models/stories110M.bin"
#define DATA_PATH "tinystories_data00.bin"
// ===== Weight loading from llama2.c format =====
// Reads `count` 4-byte elements from `f` into `dst`.
// Returns true only when every element was read (false on EOF or I/O error).
static bool pretrained_read_f32(FILE *f, void *dst, size_t count) {
    return fread(dst, 4, count, f) == count;
}

// Loads pretrained weights stored in llama2.c layout from `path`.
//
//   lw        - array of NLAYERS per-layer weight structs to fill
//   rms_final - receives DIM final-RMSNorm weights
//   embed     - receives abs(cfg.vocab_size) * DIM embedding weights
//   path      - model file to read
//
// Returns true only when the header matches DIM/HIDDEN/NLAYERS and every tensor
// was read completely. On false the destination buffers may be partially
// overwritten; the caller is expected to re-initialise them.
//
// NOTE(review): the vocab size comes from the file and is not compared against
// the size the caller allocated for `embed` — confirm the caller's buffer holds
// abs(cfg.vocab_size) * DIM floats, or add a vocab check alongside the dim check.
static bool load_pretrained(LayerWeights *lw, float *rms_final, float *embed, const char *path) {
    FILE *f = fopen(path, "rb");
    if (!f) { printf("Cannot open %s\n", path); return false; }

    Llama2Config cfg;
    if (fread(&cfg, sizeof(cfg), 1, f) != 1) {
        // A short header would otherwise leave cfg uninitialised.
        printf(" ERROR: Cannot read model header from %s\n", path);
        fclose(f); return false;
    }
    printf(" Model config: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
           cfg.dim, cfg.hidden_dim, cfg.n_layers, cfg.n_heads, abs(cfg.vocab_size), cfg.seq_len);
    if (cfg.dim != DIM || cfg.hidden_dim != HIDDEN || cfg.n_layers != NLAYERS) {
        printf(" ERROR: Config mismatch! Expected dim=%d hidden=%d layers=%d\n", DIM, HIDDEN, NLAYERS);
        fclose(f); return false;
    }
    int V = abs(cfg.vocab_size);
    // The sign of vocab_size encodes whether the classifier shares the embedding table.
    bool shared = cfg.vocab_size > 0;

    // Read in llama2.c order: embed, rms_att[all], wq[all], wk[all], wv[all], wo[all],
    // rms_ffn[all], w1[all], w2[all], w3[all], rms_final, [wcls].
    // Each tensor is stored contiguously for all layers; stop at the first short read.
    bool ok = pretrained_read_f32(f, embed, (size_t)V * DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].rms_att, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wq, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wk, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wv, WQ_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].Wo, WO_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].rms_ffn, DIM);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].W1, W1_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].W2, W2_SZ);
    for (int L = 0; ok && L < NLAYERS; L++) ok = pretrained_read_f32(f, lw[L].W3, W3_SZ);
    ok = ok && pretrained_read_f32(f, rms_final, DIM);
    // wcls = embed if shared (we just use embed pointer), so nothing more is read.
    fclose(f);

    if (!ok) {
        printf(" ERROR: %s is truncated or unreadable\n", path);
        return false;
    }
    printf(" Loaded pretrained weights (%s)\n", shared ? "shared embed/cls" : "separate cls");
    return true;
}
// ===== Compile one layer's kernels =====
// Compiles the five weight-bearing kernels of one layer into `lk`, baking the
// weights from `w` into each kernel as named blobs.
// All five compiles are always attempted; returns true only if every one
// produced a non-NULL kernel.
// NOTE(review): the two trailing numeric arguments look like input/output buffer
// sizes in bytes (channels * SEQ * 2) — confirm against compile_kern_mil_w.
static bool compile_layer_kernels(LayerKernels *lk, LayerWeights *w) {
    bool all_ok = true;

    // Attention forward: both norm-1 weights, Q/K/V/O projections and the mask.
    lk->fwdAttn = compile_kern_mil_w(gen_sdpa_fwd_taps(), (@{
        @"@model_path/weights/rms1.bin": @{@"offset":@0, @"data":build_blob(w->rms_att,1,DIM)},
        @"@model_path/weights/wq.bin": @{@"offset":@0, @"data":build_blob(w->Wq,DIM,DIM)},
        @"@model_path/weights/wk.bin": @{@"offset":@0, @"data":build_blob(w->Wk,DIM,DIM)},
        @"@model_path/weights/wv.bin": @{@"offset":@0, @"data":build_blob(w->Wv,DIM,DIM)},
        @"@model_path/weights/wo.bin": @{@"offset":@0, @"data":build_blob(w->Wo,DIM,DIM)},
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
    }), DIM*SEQ*2, 6*DIM*SEQ*2);
    if (!lk->fwdAttn) all_ok = false;

    // FFN forward: norm-2 weights plus W1/W3/W2.
    lk->fwdFFN = compile_kern_mil_w(gen_ffn_fwd_taps(), (@{
        @"@model_path/weights/rms2.bin": @{@"offset":@0, @"data":build_blob(w->rms_ffn,1,DIM)},
        @"@model_path/weights/w1.bin": @{@"offset":@0, @"data":build_blob(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3.bin": @{@"offset":@0, @"data":build_blob(w->W3,HIDDEN,DIM)},
        @"@model_path/weights/w2.bin": @{@"offset":@0, @"data":build_blob(w->W2,DIM,HIDDEN)},
    }), DIM*SEQ*2, (2*DIM+3*HIDDEN)*SEQ*2);
    if (!lk->fwdFFN) all_ok = false;

    // FFN backward: the same three matrices, built via build_blob_t.
    lk->ffnBwd = compile_kern_mil_w(gen_ffn_bwd(), (@{
        @"@model_path/weights/w2t.bin": @{@"offset":@0, @"data":build_blob_t(w->W2,DIM,HIDDEN)},
        @"@model_path/weights/w1t.bin": @{@"offset":@0, @"data":build_blob_t(w->W1,HIDDEN,DIM)},
        @"@model_path/weights/w3t.bin": @{@"offset":@0, @"data":build_blob_t(w->W3,HIDDEN,DIM)},
    }), (DIM+2*HIDDEN)*SEQ*2, (DIM+2*HIDDEN)*SEQ*2);
    if (!lk->ffnBwd) all_ok = false;

    // SDPA backward, part 1: mask plus Wo via build_blob_t.
    lk->sdpaBwd1 = compile_kern_mil_w(gen_sdpa_bwd1(), (@{
        @"@model_path/weights/mask.bin": @{@"offset":@0, @"data":get_mask_blob()},
        @"@model_path/weights/wot.bin": @{@"offset":@0, @"data":build_blob_t(w->Wo,DIM,DIM)},
    }), 4*DIM*SEQ*2, (DIM+2*SCORE_CH)*SEQ*2);
    if (!lk->sdpaBwd1) all_ok = false;

    // QKV backward: Wq/Wk/Wv via build_blob_t.
    lk->qkvBwd = compile_kern_mil_w(gen_qkvb(), (@{
        @"@model_path/weights/wqt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wq,DIM,DIM)},
        @"@model_path/weights/wkt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wk,DIM,DIM)},
        @"@model_path/weights/wvt.bin": @{@"offset":@0, @"data":build_blob_t(w->Wv,DIM,DIM)},
    }), 3*DIM*SEQ*2, DIM*SEQ*2);
    if (!lk->qkvBwd) all_ok = false;

    return all_ok;
}
// Compiles the sdpaBwd2 kernel. It is given an empty weight dictionary, so it
// does not depend on the current weights and never needs recompiling after an
// optimizer step. Returns whatever compile_kern_mil_w returns (callers treat
// NULL as failure).
static Kern *compile_sdpa_bwd2(void) {
    Kern *bwd2 = compile_kern_mil_w(
        gen_sdpa_bwd2(),
        @{},
        (2*SCORE_CH+2*DIM)*SEQ*2,
        2*DIM*SEQ*2);
    return bwd2;
}
// Releases the five weight-bearing kernels of one layer and clears each slot
// so the struct can be reused for a later compile.
// The shared sdpaBwd2 kernel is not owned here; it is freed separately.
static void free_layer_kernels(LayerKernels *lk) {
    free_kern(lk->fwdAttn);
    lk->fwdAttn = NULL;

    free_kern(lk->fwdFFN);
    lk->fwdFFN = NULL;

    free_kern(lk->ffnBwd);
    lk->ffnBwd = NULL;

    free_kern(lk->sdpaBwd1);
    lk->sdpaBwd1 = NULL;

    free_kern(lk->qkvBwd);
    lk->qkvBwd = NULL;
}
// ===== Checkpoint save/load =====
// Writes `count` 4-byte elements from `src` to `f`.
// Returns true only when every element was written.
static bool ckpt_write_f32(FILE *f, const void *src, size_t count) {
    return fwrite(src, 4, count, f) == count;
}

// Saves the full training state to `path`.
//
// Layout: CkptHdr, then per layer {weights, Adam m/v for each weight}, then
// rms_final (+Adam), then embed (+Adam). load_checkpoint reads the same order.
//
//   step, total_steps, lr, loss - training progress stored in the header
//   cc, ct, cw                  - cumulative compile / train / wall values
//   cs, cb                      - cumulative step and batch counts
//   adam_t                      - optimizer step counter
//   lw, la                      - NLAYERS weight structs and their Adam state
//   rms_final / arms_final      - DIM final-norm weights and Adam state
//   embed / aembed              - VOCAB*DIM embedding weights and Adam state
//
// The data is written to "<path>.tmp" and renamed over `path` only after every
// write and the close succeeded, so a failed save (disk full, unwritable
// directory) leaves the previous checkpoint intact. Failures are reported on
// stdout; the function returns nothing, matching its original interface.
static void save_checkpoint(const char *path, int step, int total_steps, float lr, float loss,
                            double cc, double ct, double cw, int cs, int cb, int adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    char tmp_path[1024];
    int plen = snprintf(tmp_path, sizeof(tmp_path), "%s.tmp", path);
    if (plen < 0 || (size_t)plen >= sizeof(tmp_path)) {
        printf("Checkpoint path too long: %s\n", path);
        return;
    }

    FILE *f = fopen(tmp_path, "wb");
    if (!f) {
        // Previously a NULL stream was passed straight to fwrite.
        printf("Cannot open %s for writing\n", tmp_path);
        return;
    }

    CkptHdr h = {0};
    h.magic = 0x424C5A54; h.version = 2;
    h.step = step; h.total_steps = total_steps;
    h.n_layers = NLAYERS; h.vocab_size = VOCAB; h.dim = DIM;
    h.hidden_dim = HIDDEN; h.n_heads = HEADS; h.seq_len = SEQ;
    h.lr = lr; h.loss = loss;
    h.cum_compile = cc; h.cum_train = ct; h.cum_wall = cw;
    h.cum_steps = cs; h.cum_batches = cb; h.adam_t = adam_t;
    bool ok = fwrite(&h, sizeof(h), 1, f) == 1;

    // Per-layer weights + adam
    for (int L = 0; ok && L < NLAYERS; L++) {
        ok = ckpt_write_f32(f, lw[L].Wq, WQ_SZ) && ckpt_write_f32(f, lw[L].Wk, WQ_SZ)
          && ckpt_write_f32(f, lw[L].Wv, WQ_SZ) && ckpt_write_f32(f, lw[L].Wo, WO_SZ)
          && ckpt_write_f32(f, lw[L].W1, W1_SZ) && ckpt_write_f32(f, lw[L].W2, W2_SZ)
          && ckpt_write_f32(f, lw[L].W3, W3_SZ)
          && ckpt_write_f32(f, lw[L].rms_att, DIM) && ckpt_write_f32(f, lw[L].rms_ffn, DIM)
          // Adam state
          && ckpt_write_f32(f, la[L].Wq.m, WQ_SZ) && ckpt_write_f32(f, la[L].Wq.v, WQ_SZ)
          && ckpt_write_f32(f, la[L].Wk.m, WQ_SZ) && ckpt_write_f32(f, la[L].Wk.v, WQ_SZ)
          && ckpt_write_f32(f, la[L].Wv.m, WQ_SZ) && ckpt_write_f32(f, la[L].Wv.v, WQ_SZ)
          && ckpt_write_f32(f, la[L].Wo.m, WO_SZ) && ckpt_write_f32(f, la[L].Wo.v, WO_SZ)
          && ckpt_write_f32(f, la[L].W1.m, W1_SZ) && ckpt_write_f32(f, la[L].W1.v, W1_SZ)
          && ckpt_write_f32(f, la[L].W2.m, W2_SZ) && ckpt_write_f32(f, la[L].W2.v, W2_SZ)
          && ckpt_write_f32(f, la[L].W3.m, W3_SZ) && ckpt_write_f32(f, la[L].W3.v, W3_SZ)
          && ckpt_write_f32(f, la[L].rms_att.m, DIM) && ckpt_write_f32(f, la[L].rms_att.v, DIM)
          && ckpt_write_f32(f, la[L].rms_ffn.m, DIM) && ckpt_write_f32(f, la[L].rms_ffn.v, DIM);
    }

    const size_t embed_count = (size_t)VOCAB * DIM;
    ok = ok && ckpt_write_f32(f, rms_final, DIM)
            && ckpt_write_f32(f, arms_final->m, DIM) && ckpt_write_f32(f, arms_final->v, DIM)
            && ckpt_write_f32(f, embed, embed_count)
            && ckpt_write_f32(f, aembed->m, embed_count) && ckpt_write_f32(f, aembed->v, embed_count);

    // fclose flushes buffered data, so its result is part of the write outcome.
    if (fclose(f) != 0) ok = false;

    if (!ok) {
        printf("Checkpoint write to %s failed; keeping previous checkpoint\n", tmp_path);
        remove(tmp_path);
        return;
    }
    if (rename(tmp_path, path) != 0) {
        printf("Cannot rename %s to %s\n", tmp_path, path);
        remove(tmp_path);
    }
}
// Reads `count` 4-byte elements from `f` into `dst`.
// Returns true only when every element was read.
static bool ckpt_read_f32(FILE *f, void *dst, size_t count) {
    return fread(dst, 4, count, f) == count;
}

// Restores training state written by save_checkpoint from `path`.
//
// Returns true only when the file exists, carries the expected magic/version,
// was written for the same layer count, model dim and vocab size as this
// build, and every tensor was read completely.
//
// The scalar out-params (step, total_steps, lr, loss, cc, ct, cw, cs, cb,
// adam_t) are written only on success. The weight and Adam buffers are filled
// as the file is read, so on a false return they may be partially overwritten
// and the caller must re-initialise them.
static bool load_checkpoint(const char *path, int *step, int *total_steps, float *lr, float *loss,
                            double *cc, double *ct, double *cw, int *cs, int *cb, int *adam_t,
                            LayerWeights *lw, LayerAdam *la, float *rms_final, AdamState *arms_final,
                            float *embed, AdamState *aembed) {
    FILE *f = fopen(path, "rb");
    if (!f) return false;

    CkptHdr h;
    if (fread(&h, sizeof(h), 1, f) != 1 || h.magic != 0x424C5A54 || h.version != 2) {
        fclose(f); return false;
    }
    // Every payload size below derives from these constants; a checkpoint from a
    // differently-sized build would be read at the wrong offsets.
    if (h.n_layers != NLAYERS || h.dim != DIM || h.vocab_size != VOCAB) {
        printf("Checkpoint %s was written for a different model shape; ignoring\n", path);
        fclose(f); return false;
    }

    bool ok = true;
    for (int L = 0; ok && L < NLAYERS; L++) {
        ok = ckpt_read_f32(f, lw[L].Wq, WQ_SZ) && ckpt_read_f32(f, lw[L].Wk, WQ_SZ)
          && ckpt_read_f32(f, lw[L].Wv, WQ_SZ) && ckpt_read_f32(f, lw[L].Wo, WO_SZ)
          && ckpt_read_f32(f, lw[L].W1, W1_SZ) && ckpt_read_f32(f, lw[L].W2, W2_SZ)
          && ckpt_read_f32(f, lw[L].W3, W3_SZ)
          && ckpt_read_f32(f, lw[L].rms_att, DIM) && ckpt_read_f32(f, lw[L].rms_ffn, DIM)
          // Adam state
          && ckpt_read_f32(f, la[L].Wq.m, WQ_SZ) && ckpt_read_f32(f, la[L].Wq.v, WQ_SZ)
          && ckpt_read_f32(f, la[L].Wk.m, WQ_SZ) && ckpt_read_f32(f, la[L].Wk.v, WQ_SZ)
          && ckpt_read_f32(f, la[L].Wv.m, WQ_SZ) && ckpt_read_f32(f, la[L].Wv.v, WQ_SZ)
          && ckpt_read_f32(f, la[L].Wo.m, WO_SZ) && ckpt_read_f32(f, la[L].Wo.v, WO_SZ)
          && ckpt_read_f32(f, la[L].W1.m, W1_SZ) && ckpt_read_f32(f, la[L].W1.v, W1_SZ)
          && ckpt_read_f32(f, la[L].W2.m, W2_SZ) && ckpt_read_f32(f, la[L].W2.v, W2_SZ)
          && ckpt_read_f32(f, la[L].W3.m, W3_SZ) && ckpt_read_f32(f, la[L].W3.v, W3_SZ)
          && ckpt_read_f32(f, la[L].rms_att.m, DIM) && ckpt_read_f32(f, la[L].rms_att.v, DIM)
          && ckpt_read_f32(f, la[L].rms_ffn.m, DIM) && ckpt_read_f32(f, la[L].rms_ffn.v, DIM);
    }

    const size_t embed_count = (size_t)VOCAB * DIM;
    ok = ok && ckpt_read_f32(f, rms_final, DIM)
            && ckpt_read_f32(f, arms_final->m, DIM) && ckpt_read_f32(f, arms_final->v, DIM)
            && ckpt_read_f32(f, embed, embed_count)
            && ckpt_read_f32(f, aembed->m, embed_count) && ckpt_read_f32(f, aembed->v, embed_count);
    fclose(f);

    if (!ok) {
        printf("Checkpoint %s is truncated or unreadable; ignoring\n", path);
        return false;
    }

    // Publish the header values only once the whole payload is known to be good.
    *step = h.step; *total_steps = h.total_steps; *lr = h.lr; *loss = h.loss;
    *cc = h.cum_compile; *ct = h.cum_train; *cw = h.cum_wall;
    *cs = h.cum_steps; *cb = h.cum_batches; *adam_t = h.adam_t;
    return true;
}
// ===== Main =====
// Entry point: double-buffered training loop.
//
// Flags: --resume (load CKPT_PATH), --steps N, --lr X.
//
// Flow per batch:
//   1. If the compile budget (DB_MAX_COMPILES) would be exceeded, save a
//      checkpoint and exec() this binary again with --resume.
//   2. Run up to ACCUM_STEPS forward/backward steps on the active kernel set,
//      accumulating weight gradients on a serial GCD queue.
//   3. Drain the background compile queue, apply Adam, swap in the freshly
//      compiled kernel set if one is ready, then start compiling the next set
//      from the updated weights in the background.
//
// Invariants kept by step 3:
//   - lw[] is never modified while a background compile is reading it;
//   - a kernel set is only published (pending_ready) when every layer compiled;
//   - a compile is never started into a set that is ready but not yet swapped.
//
// Returns 0 on normal completion, 1 on a fatal setup/compile/exec error.
int main(int argc, char *argv[]) {
    @autoreleasepool {
        setbuf(stdout, NULL);
        ane_init();
        mach_timebase_info(&g_tb);
        int total_steps = 10000;
        float lr = 3e-4f;
        float adam_b1=0.9f, adam_b2=0.999f, adam_eps=1e-8f;
        int adam_t = 0, start_step = 0;
        // Parse args
        bool do_resume = false;
        for (int i=1; i<argc; i++) {
            if (strcmp(argv[i], "--resume") == 0) do_resume = true;
            else if (strcmp(argv[i], "--steps") == 0 && i+1<argc) total_steps = atoi(argv[++i]);
            else if (strcmp(argv[i], "--lr") == 0 && i+1<argc) lr = atof(argv[++i]);
        }
        // Allocate per-layer state
        LayerWeights lw[NLAYERS];
        LayerAdam la[NLAYERS];
        LayerActs acts[NLAYERS];
        LayerGrads grads[NLAYERS];
        // Double-buffer: two sets of kernels
        LayerKernels kern_A[NLAYERS], kern_B[NLAYERS];
        LayerKernels *kern_active = kern_A;   // currently running evals
        LayerKernels *kern_pending = kern_B;  // being compiled in background
        // Static so the background block can reach them without capturing.
        static _Atomic bool pending_ready = false;       // signal: pending compile done
        static _Atomic bool bg_compile_running = false;
        dispatch_queue_t compile_q = dispatch_queue_create("ane.compile.bg", DISPATCH_QUEUE_SERIAL);
        // Legacy alias for code that uses kern[L]
#define kern kern_active
        for (int L=0; L<NLAYERS; L++) {
            lw[L] = layer_weights_alloc();
            la[L] = layer_adam_alloc();
            acts[L] = layer_acts_alloc();
            grads[L] = layer_grads_alloc();
            memset(&kern_A[L], 0, sizeof(LayerKernels));
            memset(&kern_B[L], 0, sizeof(LayerKernels));
        }
        // Final RMSNorm + embedding + classifier
        float *rms_final = (float*)malloc(DIM*4);
        float *embed = (float*)malloc(VOCAB*DIM*4); // [VOCAB, DIM] row-major
        float *grms_final = (float*)calloc(DIM, 4);
        float *gembed = (float*)calloc(VOCAB*DIM, 4);
        AdamState arms_final = adam_alloc(DIM);
        AdamState aembed = adam_alloc((size_t)VOCAB*DIM);
        double cum_compile=0, cum_train=0, cum_wall=0;
        int cum_steps=0, cum_batches=0;
        float resume_loss = 0;
        bool resuming = false;
        if (do_resume) {
            resuming = load_checkpoint(CKPT_PATH, &start_step, &total_steps, &lr, &resume_loss,
                &cum_compile, &cum_train, &cum_wall, &cum_steps, &cum_batches, &adam_t,
                lw, la, rms_final, &arms_final, embed, &aembed);
            if (resuming) printf("[RESUMED step %d, loss=%.4f]\n", start_step, resume_loss);
        }
        if (!resuming) {
            printf("=== ANE Training: Stories110M (12 layers) ===\n");
            printf("dim=%d hidden=%d heads=%d seq=%d vocab=%d layers=%d\n", DIM, HIDDEN, HEADS, SEQ, VOCAB, NLAYERS);
            if (!load_pretrained(lw, rms_final, embed, MODEL_PATH)) {
                printf("Pretrained load failed, using random init\n");
                srand48(42);
                float scale_d=1.0f/sqrtf(DIM), scale_h=1.0f/sqrtf(HIDDEN);
                for (int L=0; L<NLAYERS; L++) {
                    for(size_t i=0;i<WQ_SZ;i++){lw[L].Wq[i]=scale_d*(2*drand48()-1);lw[L].Wk[i]=scale_d*(2*drand48()-1);}
                    for(size_t i=0;i<WQ_SZ;i++){lw[L].Wv[i]=scale_d*(2*drand48()-1);lw[L].Wo[i]=scale_d*(2*drand48()-1);}
                    for(size_t i=0;i<W1_SZ;i++) lw[L].W1[i]=scale_h*(2*drand48()-1);
                    for(size_t i=0;i<W2_SZ;i++) lw[L].W2[i]=scale_d*(2*drand48()-1);
                    for(size_t i=0;i<W3_SZ;i++) lw[L].W3[i]=scale_h*(2*drand48()-1);
                    for(int i=0;i<DIM;i++){lw[L].rms_att[i]=1.0f; lw[L].rms_ffn[i]=1.0f;}
                }
                for(int i=0;i<DIM;i++) rms_final[i]=1.0f;
                float escale = 0.02f;
                for(size_t i=0;i<(size_t)VOCAB*DIM;i++) embed[i]=escale*(2*drand48()-1);
            }
            size_t tp = (size_t)NLAYERS*LAYER_PARAMS + DIM + (size_t)VOCAB*DIM;
            double xfmr_params = (double)NLAYERS*LAYER_PARAMS;
            double embed_params = (double)VOCAB*DIM;
            printf("Params: %.2fM (transformer %.2fM + embed %.2fM)\n", tp/1e6, xfmr_params/1e6, embed_params/1e6);
            printf("Kernels: %d (%d weight-bearing + %d static sdpaBwd2)\n",
                TOTAL_WEIGHT_KERNELS+NLAYERS, TOTAL_WEIGHT_KERNELS, NLAYERS);
            printf("Accum %d steps per recompile | Adam LR=%.1e b1=%.1f b2=%.3f\n", ACCUM_STEPS, lr, adam_b1, adam_b2);
            double fwd_f = NLAYERS*(4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
            double bwd_dx_f = fwd_f, bwd_dw_f = fwd_f;
            double sdpa_f = NLAYERS*2.0*HEADS*5*SEQ*SEQ*HD;
            double cls_f = 2.0*VOCAB*DIM*SEQ;
            double total_f = fwd_f + bwd_dx_f + bwd_dw_f + sdpa_f + cls_f*3;
            double ane_f = fwd_f + bwd_dx_f + sdpa_f;
            printf("FLOPs/step: fwd=%.0fM bwd_dx=%.0fM bwd_dW=%.0fM sdpa_bwd=%.0fM total=%.0fM\n",
                fwd_f/1e6, bwd_dx_f/1e6, bwd_dw_f/1e6, sdpa_f/1e6, total_f/1e6);
            printf("ANE FLOPs/step: %.0fM (fwd+bwd_dx+sdpa_bwd) | CPU: dW+cls (cblas)\n\n", ane_f/1e6);
        }
        // mmap token data (or generate synthetic if not available)
        uint16_t *token_data = NULL;
        size_t n_tokens = 0;
        size_t data_len = 0;
        bool synthetic_data = false;
        int data_fd = open(DATA_PATH, O_RDONLY);
        if (data_fd >= 0) {
            struct stat st;
            if (fstat(data_fd, &st) != 0) { perror("fstat"); return 1; }
            data_len = st.st_size;
            token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0);
            if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; }
            n_tokens = data_len / 2;
            printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6);
        } else {
            // Synthetic data for double-buffer benchmark
            synthetic_data = true;
            n_tokens = 100000;
            data_len = n_tokens * 2;
            token_data = (uint16_t*)malloc(data_len);
            srand48(123);
            for (size_t i = 0; i < n_tokens; i++)
                token_data[i] = (uint16_t)(drand48() * (VOCAB - 1));
            printf("[DB] Using synthetic data: %zu tokens (benchmark mode)\n", n_tokens);
        }
        // Sampling below computes n_tokens - SEQ - 1 in size_t; a shorter file would
        // wrap that to a huge value and read far outside the mapping.
        if (n_tokens < (size_t)SEQ + 2) {
            printf("Token data too small: %zu tokens, need at least %d\n", n_tokens, SEQ + 2);
            return 1;
        }
        // Gradient buffers shared across layers (reused each step)
        float *dy = (float*)malloc(SEQ*DIM*4); // gradient flowing backward
        float *dffn = (float*)malloc(SEQ*DIM*4);
        float *dh1 = (float*)malloc(SEQ*HIDDEN*4);
        float *dh3 = (float*)malloc(SEQ*HIDDEN*4);
        float *dx_ffn = (float*)malloc(SEQ*DIM*4);
        float *dx2 = (float*)malloc(SEQ*DIM*4);
        float *do_out_buf = (float*)malloc(SEQ*DIM*4);
        float *dq = (float*)malloc(SEQ*DIM*4);
        float *dk = (float*)malloc(SEQ*DIM*4);
        float *dv = (float*)malloc(SEQ*DIM*4);
        float *dx_attn = (float*)malloc(SEQ*DIM*4);
        // x buffer for input to each layer (channel-first [DIM, SEQ])
        float *x_cur = (float*)malloc(SEQ*DIM*4);
        float *x_final = (float*)malloc(SEQ*DIM*4); // after final rmsnorm
        float *logits = (float*)malloc(SEQ*VOCAB*4); // [VOCAB, SEQ] for cross-entropy
        float *dlogits = (float*)malloc(SEQ*VOCAB*4);
        // Compile static sdpaBwd2 kernels (no weights, one per layer)
        Kern *sdpaBwd2[NLAYERS];
        for (int L=0; L<NLAYERS; L++) {
            sdpaBwd2[L] = compile_sdpa_bwd2();
            if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
        }
        dispatch_queue_t dw_q = dispatch_queue_create("dw_cblas", DISPATCH_QUEUE_SERIAL);
        dispatch_group_t dw_grp = dispatch_group_create();
        float last_loss = 999.0f;
        double total_compile_ms=0, total_train_ms=0;
        int total_steps_done=0, total_batches=0;
        uint64_t t_wall_start = mach_absolute_time();
        srand48(42 + start_step);
        // ===== DOUBLE-BUFFER: Initial synchronous compile (first batch only) =====
        printf(" [DB] Initial compile (synchronous)...\n");
        {
            uint64_t tc = mach_absolute_time();
            for (int L=0; L<NLAYERS; L++) {
                printf(" Compiling layer %d/%d... (%d compiles)\r", L+1, NLAYERS, g_compile_count);
                fflush(stdout);
                if (!compile_layer_kernels(&kern_active[L], &lw[L])) {
                    printf("\nInitial compile failed at layer %d\n", L);
                    return 1;
                }
            }
            // Compile static sdpaBwd2 kernels
            for (int L=0; L<NLAYERS; L++) {
                if (!sdpaBwd2[L]) {
                    sdpaBwd2[L] = compile_sdpa_bwd2();
                    if (!sdpaBwd2[L]) { printf("sdpaBwd2 compile failed\n"); return 1; }
                }
            }
            double cms = tb_ms(mach_absolute_time() - tc);
            total_compile_ms += cms;
            printf(" [DB] Initial compile: %d kernels in %.0fms\n", TOTAL_WEIGHT_KERNELS, cms);
        }
        // Helper block: compile all layers into a kernel set.
        // Captured by the GCD block for background compilation.
        // Returns false as soon as one layer fails, leaving the set incomplete.
        bool (^compile_into)(LayerKernels *, LayerWeights *) = ^bool(LayerKernels *target, LayerWeights *weights) {
            for (int L=0; L<NLAYERS; L++) {
                free_layer_kernels(&target[L]);
                if (!compile_layer_kernels(&target[L], &weights[L])) {
                    printf("\n [DB] Background compile failed at layer %d\n", L);
                    return false;
                }
            }
            return true;
        };
        int step = start_step;
        int batches_since_swap = 0;
        double total_stall_ms = 0;
        while (step < total_steps) {
            // Check compile budget
            if (g_compile_count + TOTAL_WEIGHT_KERNELS > DB_MAX_COMPILES) {
                // Wait for any in-flight background compile
                dispatch_sync(compile_q, ^{});
                for (int L=0; L<NLAYERS; L++) {
                    free_layer_kernels(&kern_A[L]);
                    free_layer_kernels(&kern_B[L]);
                    free_kern(sdpaBwd2[L]); sdpaBwd2[L] = NULL;
                }
#undef kern
                double wall = tb_ms(mach_absolute_time() - t_wall_start);
                save_checkpoint(CKPT_PATH, step, total_steps, lr, last_loss,
                    total_compile_ms+cum_compile, total_train_ms+cum_train, wall+cum_wall,
                    total_steps_done+cum_steps, total_batches+cum_batches, adam_t,
                    lw, la, rms_final, &arms_final, embed, &aembed);
                printf("[exec() restart step %d, %d compiles, loss=%.4f]\n", step, g_compile_count, last_loss);
                fflush(stdout);
                // The variadic sentinel must be a char* null pointer, not a bare NULL.
                execl(argv[0], argv[0], "--resume", (char *)NULL);
                perror("execl"); return 1;
#define kern kern_active
            }
            // Re-compile sdpaBwd2 if needed (after exec restart)
            for (int L=0; L<NLAYERS; L++) {
                if (!sdpaBwd2[L]) {
                    sdpaBwd2[L] = compile_sdpa_bwd2();
                    if (!sdpaBwd2[L]) { printf("sdpaBwd2 recompile failed\n"); return 1; }
                }
            }
            // Zero gradient accumulators
            for (int L=0; L<NLAYERS; L++) layer_grads_zero(&grads[L]);
            memset(grms_final, 0, DIM*4);
            memset(gembed, 0, (size_t)VOCAB*DIM*4);
            int steps_batch = 0;
            uint64_t tt = mach_absolute_time();
            double t_ane=0,t_io=0,t_elem=0,t_rms=0,t_cblas_wait=0,t_cls=0;
            for (int a=0; a<ACCUM_STEPS && step<total_steps; a++, step++) {
                uint64_t t0,t1;
                // Sample random position in token data
                size_t max_pos = n_tokens - SEQ - 1;
                size_t pos = (size_t)(drand48() * max_pos);
                uint16_t *input_tokens = token_data + pos;
                uint16_t *target_tokens = token_data + pos + 1;
                // Embedding lookup -> x_cur [DIM, SEQ] channel-first
                t0=mach_absolute_time();
                embed_lookup(x_cur, embed, input_tokens, DIM, SEQ);
                t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
                // ===== FORWARD (12 layers) =====
                for (int L=0; L<NLAYERS; L++) {
                    LayerActs *ac = &acts[L];
                    // Save layer input for rmsnorm1 backward
                    memcpy(ac->layer_in, x_cur, SEQ*DIM*4);
                    // Attention forward: x_cur -> o_out,Q,K,V,attn_out,xnorm
                    t0=mach_absolute_time();
                    // Drain async dW work from the previous step before reusing shared buffers.
                    dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
                    t1=mach_absolute_time(); t_cblas_wait+=tb_ms(t1-t0); t0=t1;
                    io_write_fp16(kern[L].fwdAttn->ioIn, x_cur, DIM, SEQ);
                    t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                    ane_eval(kern[L].fwdAttn);
                    t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                    io_read_fp16(kern[L].fwdAttn->ioOut, ac->o_out, 0, DIM, SEQ);
                    io_read_fp16(kern[L].fwdAttn->ioOut, ac->attn_out, 4*DIM, DIM, SEQ);
                    io_read_fp16(kern[L].fwdAttn->ioOut, ac->xnorm, 5*DIM, DIM, SEQ);
                    t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                    vDSP_vadd(x_cur, 1, ac->o_out, 1, ac->x2, 1, (vDSP_Length)(SEQ*DIM));
                    t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
                    // FFN forward
                    io_write_fp16(kern[L].fwdFFN->ioIn, ac->x2, DIM, SEQ);
                    t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                    ane_eval(kern[L].fwdFFN);
                    t1=mach_absolute_time(); t_ane+=tb_ms(t1-t0); t0=t1;
                    io_read_fp16(kern[L].fwdFFN->ioOut, ac->ffn_out, 0, DIM, SEQ);
                    io_read_fp16(kern[L].fwdFFN->ioOut, ac->h1, DIM, HIDDEN, SEQ);
                    io_read_fp16(kern[L].fwdFFN->ioOut, ac->h3, DIM+HIDDEN, HIDDEN, SEQ);
                    io_read_fp16(kern[L].fwdFFN->ioOut, ac->silu_out, DIM+2*HIDDEN, HIDDEN, SEQ);
                    io_read_fp16(kern[L].fwdFFN->ioOut, ac->x2norm, DIM+3*HIDDEN, DIM, SEQ);
                    t1=mach_absolute_time(); t_io+=tb_ms(t1-t0); t0=t1;
                    vDSP_vadd(ac->x2, 1, ac->ffn_out, 1, x_cur, 1, (vDSP_Length)(SEQ*DIM));
                    t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0);
                }
                // Final RMSNorm (CPU)
                t0=mach_absolute_time();
                rmsnorm(x_final, x_cur, rms_final, DIM, SEQ);
                t1=mach_absolute_time(); t_rms+=tb_ms(t1-t0); t0=t1;
                // Classifier: logits = embed @ x_final  ([VOCAB,DIM] x [DIM,SEQ])
                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                    VOCAB, SEQ, DIM, 1.0f,
                    embed, DIM, x_final, SEQ, 0.0f, logits, SEQ);
                t1=mach_absolute_time(); t_cls+=tb_ms(t1-t0); t0=t1;
                // Cross-entropy loss
                float loss = cross_entropy_loss(dlogits, logits, target_tokens, VOCAB, SEQ);
                last_loss = loss;
                t1=mach_absolute_time(); t_elem+=tb_ms(t1-t0); t0=t1;
                // ===== BACKWARD =====
                // dlogits already computed by cross_entropy_loss
                // Classifier backward: dx_final = embed^T @ dlogits, dembed += dlogits @ x_final^T
                // dx_final[DIM,SEQ] = embed^T[DIM,VOCAB] @ dlogits[VOCAB,SEQ]
                cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans,
                    DIM, SEQ, VOCAB, 1.0f,
                    embed, DIM, dlogits, SEQ, 0.0f, dy, SEQ);
                // dembed[VOCAB,DIM] += dlogits[VOCAB,SEQ] @ x_final^T[SEQ,DIM]
                dispatch_group_async(dw_grp, dw_q, ^{
                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                        VOCAB, DIM, SEQ, 1.0f,
                        dlogits, SEQ, x_final, SEQ, 1.0f, gembed, DIM);
                });
                // Final RMSNorm backward
                float *dx_rms_final = (float*)calloc(SEQ*DIM, 4);
                rmsnorm_bwd(dx_rms_final, grms_final, dy, x_cur, rms_final, DIM, SEQ);
                memcpy(dy, dx_rms_final, SEQ*DIM*4);
                free(dx_rms_final);
                // ===== BACKWARD (12 layers, reverse) =====
                for (int L=NLAYERS-1; L>=0; L--) {
                    LayerActs *ac = &acts[L];
                    LayerGrads *gr = &grads[L];
                    // dy is the gradient at the output of this layer
                    // dffn = dy (residual connection: d(x2 + ffn) = dy for both)
                    memcpy(dffn, dy, SEQ*DIM*4);
                    // FFN backward (ANE)
                    io_write_fp16_at(kern[L].ffnBwd->ioIn, 0, dffn, DIM, SEQ);
                    io_copy(kern[L].ffnBwd->ioIn, DIM, kern[L].fwdFFN->ioOut, DIM, 2*HIDDEN, SEQ);
                    ane_eval(kern[L].ffnBwd);
                    io_read_fp16(kern[L].ffnBwd->ioOut, dx_ffn, 0, DIM, SEQ);
                    io_read_fp16(kern[L].ffnBwd->ioOut, dh1, DIM, HIDDEN, SEQ);
                    io_read_fp16(kern[L].ffnBwd->ioOut, dh3, DIM+HIDDEN, HIDDEN, SEQ);
                    // dW FFN async: operands are copied so the shared buffers can be reused
                    // immediately; the block frees the copies.
                    float *capt_dffn = (float*)malloc(SEQ*DIM*4); memcpy(capt_dffn, dffn, SEQ*DIM*4);
                    float *capt_silu = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_silu, ac->silu_out, SEQ*HIDDEN*4);
                    float *capt_dh1 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh1, dh1, SEQ*HIDDEN*4);
                    float *capt_dh3 = (float*)malloc(SEQ*HIDDEN*4); memcpy(capt_dh3, dh3, SEQ*HIDDEN*4);
                    float *capt_x2n = (float*)malloc(SEQ*DIM*4); memcpy(capt_x2n, ac->x2norm, SEQ*DIM*4);
                    dispatch_group_async(dw_grp, dw_q, ^{
                        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, HIDDEN, SEQ,
                            1.0f, capt_dffn, SEQ, capt_silu, SEQ, 1.0f, gr->W2, HIDDEN);
                        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
                            1.0f, capt_dh1, SEQ, capt_x2n, SEQ, 1.0f, gr->W1, DIM);
                        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, HIDDEN, DIM, SEQ,
                            1.0f, capt_dh3, SEQ, capt_x2n, SEQ, 1.0f, gr->W3, DIM);
                        free(capt_dffn); free(capt_silu); free(capt_dh1); free(capt_dh3); free(capt_x2n);
                    });
                    // RMSNorm2 backward
                    memset(dx2, 0, SEQ*DIM*4);
                    rmsnorm_bwd(dx2, gr->rms_ffn, dx_ffn, ac->x2, lw[L].rms_ffn, DIM, SEQ);
                    // Add residual: dx2 += dy (from skip connection)
                    for(int i=0;i<SEQ*DIM;i++) dx2[i] += dy[i];
                    // dWo async
                    memcpy(do_out_buf, dx2, SEQ*DIM*4);
                    float *capt_do = (float*)malloc(SEQ*DIM*4); memcpy(capt_do, do_out_buf, SEQ*DIM*4);
                    float *capt_attn = (float*)malloc(SEQ*DIM*4); memcpy(capt_attn, ac->attn_out, SEQ*DIM*4);
                    dispatch_group_async(dw_grp, dw_q, ^{
                        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                            1.0f, capt_do, SEQ, capt_attn, SEQ, 1.0f, gr->Wo, DIM);
                        free(capt_do); free(capt_attn);
                    });
                    // SDPA backward (ANE)
                    io_copy(kern[L].sdpaBwd1->ioIn, 0, kern[L].fwdAttn->ioOut, DIM, 3*DIM, SEQ);
                    io_write_fp16_at(kern[L].sdpaBwd1->ioIn, 3*DIM, dx2, DIM, SEQ);
                    ane_eval(kern[L].sdpaBwd1);
                    io_copy(sdpaBwd2[L]->ioIn, 0, kern[L].sdpaBwd1->ioOut, DIM, 2*SCORE_CH, SEQ);
                    io_copy(sdpaBwd2[L]->ioIn, 2*SCORE_CH, kern[L].fwdAttn->ioOut, DIM, 2*DIM, SEQ);
                    ane_eval(sdpaBwd2[L]);
                    io_read_fp16(sdpaBwd2[L]->ioOut, dq, 0, DIM, SEQ);
                    io_read_fp16(sdpaBwd2[L]->ioOut, dk, DIM, DIM, SEQ);
                    io_read_fp16(kern[L].sdpaBwd1->ioOut, dv, 0, DIM, SEQ);
                    // dWq/dWk/dWv async
                    float *capt_dq = (float*)malloc(SEQ*DIM*4); memcpy(capt_dq, dq, SEQ*DIM*4);
                    float *capt_dk = (float*)malloc(SEQ*DIM*4); memcpy(capt_dk, dk, SEQ*DIM*4);
                    float *capt_dv = (float*)malloc(SEQ*DIM*4); memcpy(capt_dv, dv, SEQ*DIM*4);
                    float *capt_xn = (float*)malloc(SEQ*DIM*4); memcpy(capt_xn, ac->xnorm, SEQ*DIM*4);
                    dispatch_group_async(dw_grp, dw_q, ^{
                        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                            1.0f, capt_dq, SEQ, capt_xn, SEQ, 1.0f, gr->Wq, DIM);
                        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                            1.0f, capt_dk, SEQ, capt_xn, SEQ, 1.0f, gr->Wk, DIM);
                        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, DIM, DIM, SEQ,
                            1.0f, capt_dv, SEQ, capt_xn, SEQ, 1.0f, gr->Wv, DIM);
                        free(capt_dq); free(capt_dk); free(capt_dv); free(capt_xn);
                    });
                    // QKV backward (ANE)
                    io_copy(kern[L].qkvBwd->ioIn, 0, sdpaBwd2[L]->ioOut, 0, 2*DIM, SEQ);
                    io_copy(kern[L].qkvBwd->ioIn, 2*DIM, kern[L].sdpaBwd1->ioOut, 0, DIM, SEQ);
                    ane_eval(kern[L].qkvBwd);
                    io_read_fp16(kern[L].qkvBwd->ioOut, dx_attn, 0, DIM, SEQ);
                    // RMSNorm1 backward (using saved layer input)
                    float *dx_rms1 = (float*)calloc(SEQ*DIM, 4);
                    rmsnorm_bwd(dx_rms1, gr->rms_att, dx_attn, ac->layer_in, lw[L].rms_att, DIM, SEQ);
                    // Gradient flow through one layer:
                    //   dy -> (dffn, dy_skip)                 [dy_skip = dy due to residual]
                    //   dffn -> ffnBwd -> dx_ffn -> rmsnorm2_bwd -> dx_rms2
                    //   dx2 = dx_rms2 + dy                    [gradient at x2]
                    //   dx2 -> sdpaBwd (through Wo^T) -> qkvBwd -> dx_attn -> rmsnorm1_bwd -> dx_rms1
                    //   x2 = layer_in + o_out, so the skip path passes dx2 to layer_in unchanged
                    // => dy for the previous layer = dx_rms1 + dx2
                    for(int i=0;i<SEQ*DIM;i++) dy[i] = dx_rms1[i] + dx2[i];
                    free(dx_rms1);
                }
                // Embedding backward
                dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
                embed_backward(gembed, dy, input_tokens, DIM, SEQ);
                steps_batch++;
                if (step % 10 == 0 || step == start_step)
                    printf("step %-4d loss=%.4f\n", step, loss);
                // JSON telemetry to stderr
                double step_ane = t_ane/steps_batch, step_io = t_io/steps_batch;
                double step_cls = t_cls/steps_batch, step_elem = t_elem/steps_batch;
                double step_rms = t_rms/steps_batch, step_cbw = t_cblas_wait/steps_batch;
                fprintf(stderr, "{\"type\":\"step\",\"step\":%d,\"loss\":%.6f,"
                    "\"t_ane\":%.3f,\"t_io\":%.3f,\"t_cls\":%.3f,"
                    "\"t_elem\":%.3f,\"t_rms\":%.3f,\"t_cblas_wait\":%.3f,"
                    "\"compiles\":%d}\n",
                    step, loss, step_ane, step_io, step_cls, step_elem, step_rms, step_cbw, g_compile_count);
            }
            double tms = tb_ms(mach_absolute_time() - tt);
            total_train_ms += tms;
            total_steps_done += steps_batch;
            total_batches++;
            // Ensure all async dW finished
            dispatch_group_wait(dw_grp, DISPATCH_TIME_FOREVER);
            // The background compile reads lw[] to build weight blobs and Adam is about
            // to rewrite lw[]; drain the compile queue first so the two never overlap.
            // The compile still overlaps the whole training batch above; only the
            // remainder (if any) shows up as stall.
            uint64_t t_stall = mach_absolute_time();
            dispatch_sync(compile_q, ^{});
            double stall_ms = tb_ms(mach_absolute_time() - t_stall);
            total_stall_ms += stall_ms;
            // Adam update (scale gradients by 1/steps_batch)
            float gsc = 1.0f / steps_batch;
            adam_t++;
            for (int L=0; L<NLAYERS; L++) {
                LayerGrads *g = &grads[L];
                for(size_t i=0;i<WQ_SZ;i++){g->Wq[i]*=gsc;g->Wk[i]*=gsc;g->Wv[i]*=gsc;g->Wo[i]*=gsc;}
                for(size_t i=0;i<W1_SZ;i++) g->W1[i]*=gsc;
                for(size_t i=0;i<W2_SZ;i++) g->W2[i]*=gsc;
                for(size_t i=0;i<W3_SZ;i++) g->W3[i]*=gsc;
                for(int i=0;i<DIM;i++){g->rms_att[i]*=gsc; g->rms_ffn[i]*=gsc;}
                adam_update(lw[L].Wq, g->Wq, &la[L].Wq, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].Wk, g->Wk, &la[L].Wk, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].Wv, g->Wv, &la[L].Wv, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].Wo, g->Wo, &la[L].Wo, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].W1, g->W1, &la[L].W1, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].W2, g->W2, &la[L].W2, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].W3, g->W3, &la[L].W3, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].rms_att, g->rms_att, &la[L].rms_att, adam_t, lr, adam_b1, adam_b2, adam_eps);
                adam_update(lw[L].rms_ffn, g->rms_ffn, &la[L].rms_ffn, adam_t, lr, adam_b1, adam_b2, adam_eps);
            }
            for(int i=0;i<DIM;i++) grms_final[i]*=gsc;
            adam_update(rms_final, grms_final, &arms_final, adam_t, lr, adam_b1, adam_b2, adam_eps);
            // Scale and update embed
            for(size_t i=0;i<(size_t)VOCAB*DIM;i++) gembed[i]*=gsc;
            adam_update(embed, gembed, &aembed, adam_t, lr, adam_b1, adam_b2, adam_eps);
            // ===== DOUBLE-BUFFER: Swap in finished kernels, then compile the next set =====
            batches_since_swap++;
            // Swap BEFORE scheduling: the next compile frees and rebuilds kern_pending,
            // so a finished-but-unswapped set must become active first. Otherwise the
            // set about to be swapped in would be torn down underneath the eval loop.
            if (atomic_load(&pending_ready)) {
                // Swap: pending becomes active, old active becomes recycle target
                LayerKernels *old_active = kern_active;
                kern_active = kern_pending;
                kern_pending = old_active;
                atomic_store(&pending_ready, false);
                batches_since_swap = 0;
                printf(" [DB] Swapped kernels (stall=%.0fms)\n", stall_ms);
            }
            // Only start bg compile if we have budget
            if (!atomic_load(&bg_compile_running) &&
                g_compile_count + TOTAL_WEIGHT_KERNELS <= DB_MAX_COMPILES) {
                atomic_store(&bg_compile_running, true);
                // Capture pointers (not stack arrays) for background block
                LayerKernels *bg_target = kern_pending;
                LayerWeights *bg_weights = lw; // decays to pointer, safe for block
                dispatch_async(compile_q, ^{
                    bool compiled = compile_into(bg_target, bg_weights);
                    // Publish only a complete set; a failed compile leaves NULL kernels
                    // that must never become active.
                    if (compiled) atomic_store(&pending_ready, true);
                    atomic_store(&bg_compile_running, false);
                });
            }
            double cms = stall_ms; // time spent waiting on the background compile
            printf(" [batch %d: compile_stall=%.0fms train=%.1fms (%.1fms/step) compiles=%d bg=%s]\n",
                steps_batch, stall_ms, tms, tms/steps_batch, g_compile_count,
                atomic_load(&bg_compile_running) ? "compiling" : "idle");
            printf(" ane=%.1f io=%.1f cls=%.1f elem=%.1f rms=%.1f cblas_wait=%.1f ms/step\n",
                t_ane/steps_batch, t_io/steps_batch, t_cls/steps_batch, t_elem/steps_batch,
                t_rms/steps_batch, t_cblas_wait/steps_batch);
            // JSON batch telemetry to stderr
            {
                double bf = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
                double bs = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
                double ane_f_batch = (bf*2 + bs) * steps_batch;
                double ane_tflops = ane_f_batch / (tms * 1e9);
                fprintf(stderr, "{\"type\":\"batch\",\"batch\":%d,\"compile_ms\":%.1f,"
                    "\"train_ms\":%.1f,\"ms_per_step\":%.1f}\n",
                    steps_batch, cms, tms, tms/steps_batch);
                fprintf(stderr, "{\"type\":\"perf\",\"ane_tflops\":%.3f,\"ane_util_pct\":%.2f}\n",
                    ane_tflops, 100.0*ane_tflops/15.8);
            }
        }
        // Efficiency report
        double wall = tb_ms(mach_absolute_time() - t_wall_start);
        total_compile_ms += cum_compile; total_train_ms += cum_train;
        wall += cum_wall; total_steps_done += cum_steps; total_batches += cum_batches;
        double fwd_flops = NLAYERS * (4.0*2*DIM*DIM*SEQ + 2.0*2*DIM*HIDDEN*SEQ + 2.0*HIDDEN*DIM*SEQ);
        double sdpa_flops = NLAYERS * 2.0*HEADS*5*SEQ*SEQ*HD;
        double cls_flops = 2.0*VOCAB*DIM*SEQ;
        double total_flops = (fwd_flops*3 + sdpa_flops + cls_flops*3) * total_steps_done;
        double ane_flops = (fwd_flops*2 + sdpa_flops) * total_steps_done;
        printf("\n=== Efficiency Report ===\n");
        printf("Total steps: %d\n", total_steps_done);
        printf("Wall time: %.0f ms (%.1f s)\n", wall, wall/1000);
        printf("Compile time: %.0f ms (%.1f%%)\n", total_compile_ms, 100*total_compile_ms/wall);
        printf("Compile stall: %.0f ms (this process)\n", total_stall_ms);
        printf("Train time: %.0f ms (%.1f%%)\n", total_train_ms, 100*total_train_ms/wall);
        printf("Avg train: %.1f ms/step\n", total_train_ms/total_steps_done);
        printf("ANE TFLOPS: %.2f sustained\n", ane_flops / (total_train_ms * 1e9));
        printf("Total TFLOPS: %.2f (ANE+CPU)\n", total_flops / (total_train_ms * 1e9));
        printf("ANE utilization: %.1f%% of 15.8 TFLOPS\n", 100*ane_flops/(total_train_ms*1e9)/15.8);
        // Wait for any in-flight background compile
        dispatch_sync(compile_q, ^{});
        // Cleanup
#undef kern
        for (int L=0; L<NLAYERS; L++) {
            free_layer_kernels(&kern_A[L]);
            free_layer_kernels(&kern_B[L]);
            free_kern(sdpaBwd2[L]);
            layer_weights_free(&lw[L]);
            layer_adam_free(&la[L]);
            layer_acts_free(&acts[L]);
            layer_grads_free(&grads[L]);
        }
        if (synthetic_data) { free(token_data); }
        else { munmap(token_data, data_len); close(data_fd); }
        free(rms_final); free(embed); free(grms_final); free(gembed);
        adam_free(&arms_final); adam_free(&aembed);
        free(dy); free(dffn); free(dh1); free(dh3); free(dx_ffn); free(dx2);
        free(do_out_buf); free(dq); free(dk); free(dv); free(dx_attn);
        free(x_cur); free(x_final); free(logits); free(dlogits);
    }
    return 0;
}