// Mirror of https://github.com/maderix/ANE.git (74 lines, 2.3 KiB, C)
// data_validation.h — Shared token-data validation helpers
|
|
#pragma once
|
|
#include <stdbool.h>
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
|
|
// Outcome codes produced by token_data_validate().
typedef enum {
    TOKEN_DATA_VALID = 0,          // length check passed and no out-of-range token was found
    TOKEN_DATA_ERR_TOO_SHORT = 1,  // fewer than seq + 1 tokens (or seq was negative)
    TOKEN_DATA_ERR_OOB_TOKEN = 2   // some token id is >= vocab
} TokenDataValidationCode;
|
|
|
|
// Optional detail record filled in by token_data_validate().
// Every field is reset to 0 at the start of validation; only the fields that
// belong to the reported error code are then overwritten.
typedef struct {
    size_t required_tokens;  // TOO_SHORT: minimum token count that was needed (seq + 1)
    size_t bad_index;        // OOB_TOKEN: index of the first offending token
    uint16_t bad_token;      // OOB_TOKEN: value of that offending token
} TokenDataValidationError;
|
|
|
|
// Token files are arrays of 16-bit ids; this converts a byte length to a token count.
//   n_bytes     - size of the token data in bytes
//   n_tokens    - optional out: number of whole tokens (n_bytes / sizeof(uint16_t))
//   extra_bytes - optional out: leftover bytes that do not make up a whole token
// Returns true when n_bytes is an exact multiple of the token size, false when
// the length is misaligned. Non-NULL outputs are written in both cases.
static inline bool token_data_bytes_to_token_count(size_t n_bytes, size_t *n_tokens, size_t *extra_bytes) {
    const size_t token_size = sizeof(uint16_t);
    const size_t leftover = n_bytes % token_size;

    if (n_tokens) {
        *n_tokens = n_bytes / token_size;
    }
    if (extra_bytes) {
        *extra_bytes = leftover;
    }
    return leftover == 0;
}
|
|
|
|
// Checks that at least seq + 1 tokens are available.
//   n_tokens        - number of tokens present
//   seq             - sequence length; a negative value is rejected
//   required_tokens - optional out: the minimum count needed (seq + 1), or 0
//                     when seq is negative
// Returns true when n_tokens >= seq + 1, false otherwise (including seq < 0).
static inline bool token_data_has_min_tokens(size_t n_tokens, int seq, size_t *required_tokens) {
    if (seq < 0) {
        // Always give the out-parameter a defined value so callers that pass an
        // uninitialised variable never read garbage on this rejection path.
        if (required_tokens) {
            *required_tokens = 0;
        }
        return false;
    }

    // seq is non-negative here, so widening to size_t before the +1 cannot wrap.
    const size_t needed = (size_t)seq + 1;
    if (required_tokens) {
        *required_tokens = needed;
    }
    return n_tokens >= needed;
}
|
|
|
|
// Scans token ids for the first one that is >= vocab.
//   token_data - token ids to scan; NULL yields false
//   n_tokens   - number of entries in token_data; 0 yields false
//   vocab      - exclusive upper bound for valid ids; a value <= 0 yields false
//   bad_index  - optional out: index of the first out-of-range token
//   bad_token  - optional out: value of that token
// Returns true when an out-of-range token was found (outputs written), false
// otherwise (outputs left untouched).
static inline bool token_data_find_oob_token(const uint16_t *token_data, size_t n_tokens, int vocab,
                                             size_t *bad_index, uint16_t *bad_token) {
    if (!token_data || n_tokens == 0 || vocab <= 0) {
        return false;
    }
    // A 16-bit id can never reach a bound above UINT16_MAX, so skip the scan.
    if (vocab > UINT16_MAX) {
        return false;
    }

    for (size_t i = 0; i < n_tokens; i++) {
        const uint16_t tok = token_data[i];
        if ((int)tok < vocab) {
            continue;
        }
        if (bad_index) {
            *bad_index = i;
        }
        if (bad_token) {
            *bad_token = tok;
        }
        return true;
    }
    return false;
}
|
|
|
|
static inline TokenDataValidationCode token_data_validate(const uint16_t *token_data, size_t n_tokens,
|
|
int seq, int vocab,
|
|
TokenDataValidationError *err) {
|
|
if (err) {
|
|
err->required_tokens = 0;
|
|
err->bad_index = 0;
|
|
err->bad_token = 0;
|
|
}
|
|
|
|
size_t required = 0;
|
|
if (!token_data_has_min_tokens(n_tokens, seq, &required)) {
|
|
if (err) err->required_tokens = required;
|
|
return TOKEN_DATA_ERR_TOO_SHORT;
|
|
}
|
|
|
|
size_t bad_index = 0;
|
|
uint16_t bad_token = 0;
|
|
if (token_data_find_oob_token(token_data, n_tokens, vocab, &bad_index, &bad_token)) {
|
|
if (err) {
|
|
err->bad_index = bad_index;
|
|
err->bad_token = bad_token;
|
|
}
|
|
return TOKEN_DATA_ERR_OOB_TOKEN;
|
|
}
|
|
|
|
return TOKEN_DATA_VALID;
|
|
}
|