mirror of https://github.com/maderix/ANE.git
Merge 17cda7d940 into 20cd236f61
This commit is contained in:
commit
09028da929
156
training/model.h
156
training/model.h
|
|
@ -82,30 +82,57 @@ static int model_load_weights(Model *m, const char *path) {
|
|||
fprintf(stderr, "ERROR: failed to read config from %s\n", path);
|
||||
fclose(f); return -1;
|
||||
}
|
||||
|
||||
if (m->cfg.n_layers < 1 || m->cfg.n_layers > N_LAYERS) {
|
||||
fprintf(stderr, "ERROR: n_layers (%d) exceeds maximum allowed (%d)\n", m->cfg.n_layers, N_LAYERS);
|
||||
fclose(f); return -1;
|
||||
}
|
||||
|
||||
if (m->cfg.dim < 1 || m->cfg.dim > 8192 ||
|
||||
m->cfg.hidden_dim < 1 || m->cfg.hidden_dim > 32768) {
|
||||
fprintf(stderr, "ERROR: model dimensions out of safe bounds\n");
|
||||
fclose(f); return -1;
|
||||
}
|
||||
|
||||
bool shared = m->cfg.vocab_size > 0;
|
||||
if (m->cfg.vocab_size < 0) m->cfg.vocab_size = -m->cfg.vocab_size;
|
||||
|
||||
if (m->cfg.vocab_size == 0 || m->cfg.vocab_size > 256000) {
|
||||
fprintf(stderr, "ERROR: vocab_size out of safe bounds\n");
|
||||
fclose(f); return -1;
|
||||
}
|
||||
|
||||
printf("Model: dim=%d hidden=%d layers=%d heads=%d vocab=%d seq=%d\n",
|
||||
m->cfg.dim, m->cfg.hidden_dim, m->cfg.n_layers, m->cfg.n_heads,
|
||||
m->cfg.vocab_size, m->cfg.seq_len);
|
||||
|
||||
int d = m->cfg.dim, hd = m->cfg.hidden_dim, nl = m->cfg.n_layers, vs = m->cfg.vocab_size;
|
||||
size_t d = (size_t)m->cfg.dim, hd = (size_t)m->cfg.hidden_dim, nl = (size_t)m->cfg.n_layers, vs = (size_t)m->cfg.vocab_size;
|
||||
|
||||
m->token_embedding = (float*)malloc(vs * d * sizeof(float));
|
||||
if (fread(m->token_embedding, sizeof(float), vs * d, f) != (size_t)(vs * d)) {
|
||||
if (!m->token_embedding) {
|
||||
fprintf(stderr, "ERROR: OOM allocating token_embedding\n");
|
||||
fclose(f); return -1;
|
||||
}
|
||||
if (fread(m->token_embedding, sizeof(float), vs * d, f) != (vs * d)) {
|
||||
fprintf(stderr, "ERROR: short read on token_embedding (file truncated?)\n");
|
||||
fclose(f); return -1;
|
||||
}
|
||||
|
||||
float *rms_att_all = (float*)malloc(nl * d * sizeof(float));
|
||||
float *wq_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *wk_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *wv_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *wo_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *wq_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *wk_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *wv_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *wo_all = (float*)malloc(nl * d * d * sizeof(float));
|
||||
float *rms_ffn_all = (float*)malloc(nl * d * sizeof(float));
|
||||
float *w1_all = (float*)malloc(nl * hd * d * sizeof(float));
|
||||
float *w2_all = (float*)malloc(nl * d * hd * sizeof(float));
|
||||
float *w3_all = (float*)malloc(nl * hd * d * sizeof(float));
|
||||
float *w1_all = (float*)malloc(nl * hd * d * sizeof(float));
|
||||
float *w2_all = (float*)malloc(nl * d * hd * sizeof(float));
|
||||
float *w3_all = (float*)malloc(nl * hd * d * sizeof(float));
|
||||
|
||||
if (!rms_att_all || !wq_all || !wk_all || !wv_all || !wo_all ||
|
||||
!rms_ffn_all || !w1_all || !w2_all || !w3_all) {
|
||||
fprintf(stderr, "ERROR: OOM allocating layer weights\n");
|
||||
fclose(f); return -1;
|
||||
}
|
||||
|
||||
#define FREAD_CHECK(buf, count, file, label) do { \
|
||||
size_t _n = fread(buf, sizeof(float), count, file); \
|
||||
|
|
@ -126,26 +153,28 @@ static int model_load_weights(Model *m, const char *path) {
|
|||
FREAD_CHECK(w2_all, nl * d * hd, f, "w2");
|
||||
FREAD_CHECK(w3_all, nl * hd * d, f, "w3");
|
||||
|
||||
#define SAFE_MALLOC_MEMCPY(dest, src, size) do { \
|
||||
dest = (float*)malloc(size); \
|
||||
if (!(dest)) { \
|
||||
fprintf(stderr, "ERROR: memory allocation failed for size %zu\n", (size_t)(size)); \
|
||||
fclose(f); return -1; \
|
||||
} \
|
||||
memcpy(dest, src, size); \
|
||||
} while(0)
|
||||
|
||||
for (int l = 0; l < nl; l++) {
|
||||
m->rms_att_w[l] = (float*)malloc(d * sizeof(float));
|
||||
memcpy(m->rms_att_w[l], rms_att_all + l*d, d * sizeof(float));
|
||||
m->wq[l] = (float*)malloc(d*d*sizeof(float));
|
||||
memcpy(m->wq[l], wq_all + l*d*d, d*d*sizeof(float));
|
||||
m->wk[l] = (float*)malloc(d*d*sizeof(float));
|
||||
memcpy(m->wk[l], wk_all + l*d*d, d*d*sizeof(float));
|
||||
m->wv[l] = (float*)malloc(d*d*sizeof(float));
|
||||
memcpy(m->wv[l], wv_all + l*d*d, d*d*sizeof(float));
|
||||
m->wo[l] = (float*)malloc(d*d*sizeof(float));
|
||||
memcpy(m->wo[l], wo_all + l*d*d, d*d*sizeof(float));
|
||||
m->rms_ffn_w[l] = (float*)malloc(d * sizeof(float));
|
||||
memcpy(m->rms_ffn_w[l], rms_ffn_all + l*d, d * sizeof(float));
|
||||
m->w1[l] = (float*)malloc(hd*d*sizeof(float));
|
||||
memcpy(m->w1[l], w1_all + l*hd*d, hd*d*sizeof(float));
|
||||
m->w2[l] = (float*)malloc(d*hd*sizeof(float));
|
||||
memcpy(m->w2[l], w2_all + l*d*hd, d*hd*sizeof(float));
|
||||
m->w3[l] = (float*)malloc(hd*d*sizeof(float));
|
||||
memcpy(m->w3[l], w3_all + l*hd*d, hd*d*sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->rms_att_w[l], rms_att_all + l*d, d * sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->wq[l], wq_all + l*d*d, d*d*sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->wk[l], wk_all + l*d*d, d*d*sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->wv[l], wv_all + l*d*d, d*d*sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->wo[l], wo_all + l*d*d, d*d*sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->rms_ffn_w[l], rms_ffn_all + l*d, d * sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->w1[l], w1_all + l*hd*d, hd*d*sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->w2[l], w2_all + l*d*hd, d*hd*sizeof(float));
|
||||
SAFE_MALLOC_MEMCPY(m->w3[l], w3_all + l*hd*d, hd*d*sizeof(float));
|
||||
}
|
||||
|
||||
#undef SAFE_MALLOC_MEMCPY
|
||||
free(rms_att_all); free(wq_all); free(wk_all); free(wv_all); free(wo_all);
|
||||
free(rms_ffn_all); free(w1_all); free(w2_all); free(w3_all);
|
||||
|
||||
|
|
@ -246,40 +275,55 @@ static int model_recompile_kernels(Model *m) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
static void model_alloc_training(Model *m) {
|
||||
int d = m->cfg.dim, hd = m->cfg.hidden_dim, vs = m->cfg.vocab_size, S = m->seq_len;
|
||||
for (int l = 0; l < N_LAYERS; l++) {
|
||||
m->act_x[l] = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_xnorm[l] = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_q[l] = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_k[l] = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_v[l] = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_attn_out[l] = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_ffn_in[l] = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_h1[l] = (float*)calloc(S * hd, sizeof(float));
|
||||
m->act_h3[l] = (float*)calloc(S * hd, sizeof(float));
|
||||
m->act_silu[l] = (float*)calloc(S * hd, sizeof(float));
|
||||
static int model_alloc_training(Model *m) {
|
||||
|
||||
size_t d = (size_t)m->cfg.dim, hd = (size_t)m->cfg.hidden_dim;
|
||||
size_t vs = (size_t)m->cfg.vocab_size, S = (size_t)m->seq_len;
|
||||
|
||||
m->grad_wq[l] = (float*)calloc(d * d, sizeof(float));
|
||||
m->grad_wk[l] = (float*)calloc(d * d, sizeof(float));
|
||||
m->grad_wv[l] = (float*)calloc(d * d, sizeof(float));
|
||||
m->grad_wo[l] = (float*)calloc(d * d, sizeof(float));
|
||||
m->grad_w1[l] = (float*)calloc(hd * d, sizeof(float));
|
||||
m->grad_w2[l] = (float*)calloc(d * hd, sizeof(float));
|
||||
m->grad_w3[l] = (float*)calloc(hd * d, sizeof(float));
|
||||
#define SAFE_CALLOC(dest, count) do { \
|
||||
dest = (float*)calloc(count, sizeof(float)); \
|
||||
if (!(dest)) { \
|
||||
fprintf(stderr, "ERROR: OOM in model_alloc_training for size %zu\n", (size_t)(count)); \
|
||||
return -1; \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
for (int l = 0; l < N_LAYERS; l++) {
|
||||
SAFE_CALLOC(m->act_x[l], S * d);
|
||||
SAFE_CALLOC(m->act_xnorm[l], S * d);
|
||||
SAFE_CALLOC(m->act_q[l], S * d);
|
||||
SAFE_CALLOC(m->act_k[l], S * d);
|
||||
SAFE_CALLOC(m->act_v[l], S * d);
|
||||
SAFE_CALLOC(m->act_attn_out[l], S * d);
|
||||
SAFE_CALLOC(m->act_ffn_in[l], S * d);
|
||||
SAFE_CALLOC(m->act_h1[l], S * hd);
|
||||
SAFE_CALLOC(m->act_h3[l], S * hd);
|
||||
SAFE_CALLOC(m->act_silu[l], S * hd);
|
||||
|
||||
SAFE_CALLOC(m->grad_wq[l], d * d);
|
||||
SAFE_CALLOC(m->grad_wk[l], d * d);
|
||||
SAFE_CALLOC(m->grad_wv[l], d * d);
|
||||
SAFE_CALLOC(m->grad_wo[l], d * d);
|
||||
SAFE_CALLOC(m->grad_w1[l], hd * d);
|
||||
SAFE_CALLOC(m->grad_w2[l], d * hd);
|
||||
SAFE_CALLOC(m->grad_w3[l], hd * d);
|
||||
}
|
||||
m->act_final = (float*)calloc(S * d, sizeof(float));
|
||||
m->act_pre_final = (float*)calloc(S * d, sizeof(float));
|
||||
m->logits = (float*)calloc(S * vs, sizeof(float));
|
||||
m->grad_wcls = (float*)calloc(vs * d, sizeof(float));
|
||||
m->grad_emb = (float*)calloc(vs * d, sizeof(float));
|
||||
SAFE_CALLOC(m->act_final, S * d);
|
||||
SAFE_CALLOC(m->act_pre_final, S * d);
|
||||
SAFE_CALLOC(m->logits, S * vs);
|
||||
SAFE_CALLOC(m->grad_wcls, vs * d);
|
||||
SAFE_CALLOC(m->grad_emb, vs * d);
|
||||
|
||||
m->total_params = 0;
|
||||
for (int l = 0; l < N_LAYERS; l++)
|
||||
m->total_params += 4*(size_t)d*d + 2*(size_t)hd*d + (size_t)d*hd;
|
||||
m->total_params += (size_t)vs * d * 2;
|
||||
m->adam_m = (float*)calloc(m->total_params, sizeof(float));
|
||||
m->adam_v = (float*)calloc(m->total_params, sizeof(float));
|
||||
m->total_params += 4*d*d + 2*hd*d + d*hd;
|
||||
m->total_params += vs * d * 2;
|
||||
SAFE_CALLOC(m->adam_m, m->total_params);
|
||||
SAFE_CALLOC(m->adam_v, m->total_params);
|
||||
m->adam_step = 0;
|
||||
|
||||
#undef SAFE_CALLOC
|
||||
|
||||
printf("Total trainable params: %zu (%.1f M)\n", m->total_params, m->total_params/1e6);
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue