diff --git a/training/train_double_buffer.m b/training/train_double_buffer.m index a74eabe..66cc359 100644 --- a/training/train_double_buffer.m +++ b/training/train_double_buffer.m @@ -282,15 +282,30 @@ int main(int argc, char *argv[]) { printf("ANE FLOPs/step: %.0fM (fwd+bwd_dx+sdpa_bwd) | CPU: dW+cls (cblas)\n\n", ane_f/1e6); } - // mmap token data + // mmap token data (or generate synthetic if not available) + uint16_t *token_data = NULL; + size_t n_tokens = 0; + size_t data_len = 0; + bool synthetic_data = false; int data_fd = open(DATA_PATH, O_RDONLY); - if (data_fd < 0) { printf("Cannot open %s\n", DATA_PATH); return 1; } - struct stat st; fstat(data_fd, &st); - size_t data_len = st.st_size; - uint16_t *token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); - if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } - size_t n_tokens = data_len / 2; - printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); + if (data_fd >= 0) { + struct stat st; fstat(data_fd, &st); + data_len = st.st_size; + token_data = (uint16_t*)mmap(NULL, data_len, PROT_READ, MAP_PRIVATE, data_fd, 0); + if (token_data == MAP_FAILED) { printf("mmap failed\n"); return 1; } + n_tokens = data_len / 2; + printf("Token data: %zu tokens (%.1f MB)\n", n_tokens, data_len/1e6); + } else { + // Synthetic data for double-buffer benchmark + synthetic_data = true; + n_tokens = 100000; + data_len = n_tokens * 2; + token_data = (uint16_t*)malloc(data_len); + srand48(123); + for (size_t i = 0; i < n_tokens; i++) + token_data[i] = (uint16_t)(drand48() * (VOCAB - 1)); + printf("[DB] Using synthetic data: %zu tokens (benchmark mode)\n", n_tokens); + } // Gradient buffers shared across layers (reused each step) float *dy = (float*)malloc(SEQ*DIM*4); // gradient flowing backward @@ -748,8 +763,8 @@ int main(int argc, char *argv[]) { layer_acts_free(&acts[L]); layer_grads_free(&grads[L]); } - munmap(token_data, data_len); - close(data_fd); + if (synthetic_data) { free(token_data); } + else { munmap(token_data, data_len); close(data_fd); } free(rms_final); free(embed); free(grms_final); free(gembed); adam_free(&arms_final); adam_free(&aembed); free(dy); free(dffn); free(dh1); free(dh3); free(dx_ffn); free(dx2);