// Mirror of https://github.com/maderix/ANE.git
// Metal shader source — 922 lines, 34 KiB
#include <metal_stdlib>
|
|
using namespace metal;
|
|
|
|
// ── Q4_0 block format ────────────────────────────────────────────────
// One block encodes 32 weights in 20 bytes:
//   bytes 0-1   F16 scale
//   bytes 2-3   F16 zero (added after scaling: w = q * scale + zero)
//   bytes 4-19  16 packed uint8, two 4-bit values each
//               (low nibble = even index, high nibble = odd index)
#define Q4_BLOCK_SIZE 32
#define Q4_BLOCK_BYTES 20

// The kernels below index blocks with both constants; fail the build if
// they ever drift apart (4 header bytes + one nibble per weight).
static_assert(Q4_BLOCK_BYTES == 4 + Q4_BLOCK_SIZE / 2,
              "Q4_BLOCK_BYTES must be 4 header bytes + Q4_BLOCK_SIZE / 2 packed bytes");
|
|
|
|
// ── Q4 matrix-vector multiply (legacy, 1 thread per row) ────────────
// Fallback path. Thread `gid` walks every Q4 block of weight row `gid`,
// dequantizes each nibble as q * scale + zero, and accumulates the dot
// product with x into y[gid].
// Only in_dim / Q4_BLOCK_SIZE whole blocks are processed; a trailing
// partial block (in_dim not a multiple of 32) is ignored.
kernel void sgemv_q4(
    device const uint8_t *W       [[buffer(0)]],
    device const float   *x       [[buffer(1)]],
    device float         *y       [[buffer(2)]],
    constant uint        &in_dim  [[buffer(3)]],
    constant uint        &out_dim [[buffer(4)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < out_dim) {
        const uint blocks_per_row = in_dim / Q4_BLOCK_SIZE;
        const uint bytes_per_row  = blocks_per_row * Q4_BLOCK_BYTES;
        // 64-bit row offset so large weight matrices do not wrap.
        device const uint8_t *row_ptr = W + uint64_t(gid) * bytes_per_row;

        float acc = 0.0f;
        for (uint blk = 0; blk < blocks_per_row; ++blk) {
            device const uint8_t *blk_ptr = row_ptr + blk * Q4_BLOCK_BYTES;

            // Block header: F16 scale, then F16 zero.
            const float scale = float(*reinterpret_cast<device const half*>(blk_ptr));
            const float zero  = float(*reinterpret_cast<device const half*>(blk_ptr + 2));

            device const uint8_t *nibbles = blk_ptr + 4;
            const uint x_base = blk * Q4_BLOCK_SIZE;

            for (uint p = 0; p < 16; ++p) {
                const uint8_t bits = nibbles[p];
                // Low nibble is the even element, high nibble the odd one.
                const float w_even = float(bits & 0xF) * scale + zero;
                const float w_odd  = float(bits >> 4) * scale + zero;
                acc += w_even * x[x_base + p * 2];
                acc += w_odd * x[x_base + p * 2 + 1];
            }
        }
        y[gid] = acc;
    }
}
|
|
|
|
// ── Q4 SIMD-optimized matrix-vector multiply ─────────────────────────
// MLX-style cooperative SIMD kernel: 2 SIMD groups per threadgroup,
// each SIMD group handles ROWS_PER_SIMD output rows cooperatively.
// The 32 lanes of a SIMD group split the K (input) dimension by Q4 block
// (lane i takes blocks i, i+32, i+64, ...), then reduce via simd_sum().
// No threadgroup memory needed.
//
// Threadgroup layout: 64 threads = 2 SIMD groups of 32
// Grid: (ceil(out_dim / ROWS_PER_TG), 1, 1) threadgroups
//
// A dequantized weight is q * scale + zero, so per block
//   dot(w, x) = scale * sum(q_k * x_k) + zero * sum(x_k).
// The two terms are accumulated separately (accum / zero_accum).
//
// Optional bias: when use_bias != 0,
//   y[r] = dot(W[r], x) + bias[r]
// bias is not read when use_bias == 0.
#define ROWS_PER_SIMD 4
#define SIMD_GROUPS 2
#define ROWS_PER_TG (ROWS_PER_SIMD * SIMD_GROUPS)

kernel void sgemv_q4_fast(
    device const uint8_t *W        [[buffer(0)]],
    device const float   *x        [[buffer(1)]],
    device float         *y        [[buffer(2)]],
    constant uint        &in_dim   [[buffer(3)]],
    constant uint        &out_dim  [[buffer(4)]],
    device const float   *bias     [[buffer(5)]],
    constant uint        &use_bias [[buffer(6)]],
    uint tgid     [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]])
{
    // First output row owned by this SIMD group. The early exit is uniform
    // across the SIMD group, so the simd_sum() calls below stay convergent.
    uint base_row = tgid * ROWS_PER_TG + simd_gid * ROWS_PER_SIMD;
    if (base_row >= out_dim) return;

    uint n_blocks = in_dim / Q4_BLOCK_SIZE;
    uint row_bytes = n_blocks * Q4_BLOCK_BYTES;

    // The last SIMD group may own fewer than ROWS_PER_SIMD rows.
    uint rows_this = min((uint)ROWS_PER_SIMD, out_dim - base_row);

    // Aggregate initialization zero-fills every element, so these stay
    // correct if ROWS_PER_SIMD is changed.
    float accum[ROWS_PER_SIMD] = {0.0f};
    float zero_accum[ROWS_PER_SIMD] = {0.0f};

    // Each of 32 SIMD lanes processes a stripe of blocks.
    // Lane i processes blocks i, i+32, i+64, ...
    for (uint b = simd_lid; b < n_blocks; b += 32) {
        uint k_base = b * Q4_BLOCK_SIZE;

        // Load the input segment for this block (32 floats) and its sum.
        // The sum depends only on x, so it is computed once per block
        // instead of once per row.
        float xv[Q4_BLOCK_SIZE];
        float xsum = 0.0f;
        for (uint j = 0; j < 16; j++) {
            xv[j * 2]     = x[k_base + j * 2];
            xv[j * 2 + 1] = x[k_base + j * 2 + 1];
            xsum += xv[j * 2] + xv[j * 2 + 1];
        }

        for (uint r = 0; r < rows_this; r++) {
            device const uint8_t *block =
                W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;

            // Block header: F16 scale, then F16 zero.
            float scale = float(*reinterpret_cast<device const half*>(block));
            float zero  = float(*reinterpret_cast<device const half*>(block + 2));

            device const uint8_t *packed = block + 4;

            // Dot product against the raw 4-bit codes; scale applied once.
            float dot = 0.0f;
            for (uint j = 0; j < 16; j++) {
                uint8_t byte = packed[j];
                float w0 = float(byte & 0xF);
                float w1 = float(byte >> 4);
                dot += w0 * xv[j * 2] + w1 * xv[j * 2 + 1];
            }
            accum[r] += dot * scale;
            zero_accum[r] += xsum * zero;
        }
    }

    // SIMD reduction across 32 lanes; lane 0 writes the result.
    for (uint r = 0; r < rows_this; r++) {
        float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]);
        if (simd_lid == 0) {
            if (use_bias != 0) {
                result += bias[base_row + r];
            }
            y[base_row + r] = result;
        }
    }
}
|
|
|
|
// ── Fused Gate+Up+SiLU: reads x once, computes silu(Wg*x) * (Wu*x) ───
// Combines two Q4 matvecs + silu_mul into one kernel.
// W_gate and W_up have the same dimensions [out_dim, in_dim] and are
// addressed with the same row/block byte offset.
// Output: out[r] = silu(dot(W_gate[r], x)) * dot(W_up[r], x)
// Work split matches the SIMD matvec: each SIMD group owns
// FUSED_ROWS_PER_SIMD rows, its 32 lanes stride over the Q4 blocks, and
// simd_sum() combines the per-lane partials.
#define FUSED_ROWS_PER_SIMD 2
#define FUSED_SIMD_GROUPS 2
#define FUSED_ROWS_PER_TG (FUSED_ROWS_PER_SIMD * FUSED_SIMD_GROUPS)

kernel void sgemv_q4_fused_ffn(
    device const uint8_t *W_gate  [[buffer(0)]],
    device const uint8_t *W_up    [[buffer(1)]],
    device const float   *x       [[buffer(2)]],
    device float         *out     [[buffer(3)]],
    constant uint        &in_dim  [[buffer(4)]],
    constant uint        &out_dim [[buffer(5)]],
    uint tgid     [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]])
{
    // Uniform per SIMD group, so the simd_sum() calls below stay convergent.
    uint base_row = tgid * FUSED_ROWS_PER_TG + simd_gid * FUSED_ROWS_PER_SIMD;
    if (base_row >= out_dim) return;

    uint n_blocks = in_dim / Q4_BLOCK_SIZE;
    uint row_bytes = n_blocks * Q4_BLOCK_BYTES;

    uint rows_this = min((uint)FUSED_ROWS_PER_SIMD, out_dim - base_row);

    // Scaled-code and zero-point partial sums, kept separately per matrix.
    float gate_acc[FUSED_ROWS_PER_SIMD]  = {0.0f};
    float gate_zacc[FUSED_ROWS_PER_SIMD] = {0.0f};
    float up_acc[FUSED_ROWS_PER_SIMD]    = {0.0f};
    float up_zacc[FUSED_ROWS_PER_SIMD]   = {0.0f};

    for (uint b = simd_lid; b < n_blocks; b += 32) {
        uint k_base = b * Q4_BLOCK_SIZE;

        // Stage the 32 inputs for this block and their sum. The sum is
        // shared by the gate and up zero-point terms and by every row, so
        // it is computed once here.
        float xv[Q4_BLOCK_SIZE];
        float xsum = 0.0f;
        for (uint j = 0; j < 16; j++) {
            xv[j * 2]     = x[k_base + j * 2];
            xv[j * 2 + 1] = x[k_base + j * 2 + 1];
            xsum += xv[j * 2] + xv[j * 2 + 1];
        }

        for (uint r = 0; r < rows_this; r++) {
            uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;

            // Gate weight block
            device const uint8_t *g_block = W_gate + row_off;
            float g_scale = float(*reinterpret_cast<device const half*>(g_block));
            float g_zero  = float(*reinterpret_cast<device const half*>(g_block + 2));
            device const uint8_t *g_packed = g_block + 4;

            // Up weight block
            device const uint8_t *u_block = W_up + row_off;
            float u_scale = float(*reinterpret_cast<device const half*>(u_block));
            float u_zero  = float(*reinterpret_cast<device const half*>(u_block + 2));
            device const uint8_t *u_packed = u_block + 4;

            float g_dot = 0.0f;
            float u_dot = 0.0f;
            for (uint j = 0; j < 16; j++) {
                float x0 = xv[j * 2];
                float x1 = xv[j * 2 + 1];

                uint8_t gb = g_packed[j];
                g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1;

                uint8_t ub = u_packed[j];
                u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1;
            }
            gate_acc[r]  += g_dot * g_scale;
            gate_zacc[r] += xsum * g_zero;
            up_acc[r]    += u_dot * u_scale;
            up_zacc[r]   += xsum * u_zero;
        }
    }

    for (uint r = 0; r < rows_this; r++) {
        float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]);
        float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]);
        if (simd_lid == 0) {
            // silu(g) = g / (1 + exp(-g))
            float s = g / (1.0f + exp(-g));
            out[base_row + r] = s * u;
        }
    }
}
|
|
|
|
// ── Q4 batched matrix-matrix multiply (SGEMM) for prefill ────────────
// Y[t, r] = sum_k(dequant(W[r, k]) * X[t, k]) for t in [0, n_tokens), r in [0, out_dim)
// Grid: (ceil(out_dim / GEMM_TILE_M), n_tokens, 1)
// Each threadgroup: 2 SIMD groups, each handles GEMM_TILE_M/2 output rows for one token.
// X is read as n_tokens rows of in_dim floats; Y is written as n_tokens
// rows of out_dim floats. When use_bias != 0, bias[r] is added to Y[t, r];
// bias is not read otherwise.
#define GEMM_TILE_M 8
#define GEMM_SIMD_GROUPS 2
#define GEMM_ROWS_PER_SIMD (GEMM_TILE_M / GEMM_SIMD_GROUPS)

kernel void sgemm_q4(
    device const uint8_t *W        [[buffer(0)]],
    device const float   *X        [[buffer(1)]],
    device float         *Y        [[buffer(2)]],
    constant uint        &in_dim   [[buffer(3)]],
    constant uint        &out_dim  [[buffer(4)]],
    device const float   *bias     [[buffer(5)]],
    constant uint        &use_bias [[buffer(6)]],
    constant uint        &n_tokens [[buffer(7)]],
    uint2 tgid    [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]])
{
    // tgid.x selects the row tile, tgid.y the token. Both exits are uniform
    // per SIMD group, so the simd_sum() calls below stay convergent.
    uint base_row = tgid.x * GEMM_TILE_M + simd_gid * GEMM_ROWS_PER_SIMD;
    uint t = tgid.y;
    if (base_row >= out_dim || t >= n_tokens) return;

    uint n_blocks = in_dim / Q4_BLOCK_SIZE;
    uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
    uint rows_this = min((uint)GEMM_ROWS_PER_SIMD, out_dim - base_row);

    // Input row for this token.
    device const float *xt = X + uint64_t(t) * in_dim;

    // Aggregate initialization zero-fills every element, independent of
    // GEMM_ROWS_PER_SIMD.
    float accum[GEMM_ROWS_PER_SIMD] = {0.0f};
    float zero_accum[GEMM_ROWS_PER_SIMD] = {0.0f};

    for (uint b = simd_lid; b < n_blocks; b += 32) {
        uint k_base = b * Q4_BLOCK_SIZE;

        // Stage the 32 inputs of this block and their sum; the sum feeds
        // the zero-point term of every row, so compute it once per block.
        float xv[Q4_BLOCK_SIZE];
        float xsum = 0.0f;
        for (uint j = 0; j < 16; j++) {
            xv[j * 2]     = xt[k_base + j * 2];
            xv[j * 2 + 1] = xt[k_base + j * 2 + 1];
            xsum += xv[j * 2] + xv[j * 2 + 1];
        }

        for (uint r = 0; r < rows_this; r++) {
            device const uint8_t *block =
                W + uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;

            float scale = float(*reinterpret_cast<device const half*>(block));
            float zero  = float(*reinterpret_cast<device const half*>(block + 2));
            device const uint8_t *packed = block + 4;

            float dot = 0.0f;
            for (uint j = 0; j < 16; j++) {
                uint8_t byte = packed[j];
                dot += float(byte & 0xF) * xv[j * 2] + float(byte >> 4) * xv[j * 2 + 1];
            }
            accum[r] += dot * scale;
            zero_accum[r] += xsum * zero;
        }
    }

    for (uint r = 0; r < rows_this; r++) {
        float result = simd_sum(accum[r]) + simd_sum(zero_accum[r]);
        if (simd_lid == 0) {
            if (use_bias != 0)
                result += bias[base_row + r];
            Y[uint64_t(t) * out_dim + base_row + r] = result;
        }
    }
}
|
|
|
|
// ── Q4 batched fused Gate+Up+SiLU (SGEMM variant) ───────────────────
// out[t, r] = silu(Wg[r] . X[t]) * (Wu[r] . X[t]) for all t and r
// Grid: (ceil(out_dim / GEMM_FFN_TILE_M), n_tokens, 1)
// Each SIMD group owns GEMM_FFN_ROWS_PER_SIMD rows for one token; its 32
// lanes stride over the Q4 blocks and simd_sum() combines the partials.
#define GEMM_FFN_TILE_M 4
#define GEMM_FFN_SIMD_GROUPS 2
#define GEMM_FFN_ROWS_PER_SIMD (GEMM_FFN_TILE_M / GEMM_FFN_SIMD_GROUPS)

kernel void sgemm_q4_fused_ffn(
    device const uint8_t *W_gate   [[buffer(0)]],
    device const uint8_t *W_up     [[buffer(1)]],
    device const float   *X        [[buffer(2)]],
    device float         *out      [[buffer(3)]],
    constant uint        &in_dim   [[buffer(4)]],
    constant uint        &out_dim  [[buffer(5)]],
    constant uint        &n_tokens [[buffer(6)]],
    uint2 tgid    [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]])
{
    // tgid.x selects the row tile, tgid.y the token. Both exits are uniform
    // per SIMD group, so the simd_sum() calls below stay convergent.
    uint base_row = tgid.x * GEMM_FFN_TILE_M + simd_gid * GEMM_FFN_ROWS_PER_SIMD;
    uint t = tgid.y;
    if (base_row >= out_dim || t >= n_tokens) return;

    uint n_blocks = in_dim / Q4_BLOCK_SIZE;
    uint row_bytes = n_blocks * Q4_BLOCK_BYTES;
    uint rows_this = min((uint)GEMM_FFN_ROWS_PER_SIMD, out_dim - base_row);

    // Input row for this token.
    device const float *xt = X + uint64_t(t) * in_dim;

    float gate_acc[GEMM_FFN_ROWS_PER_SIMD]  = {0.0f};
    float gate_zacc[GEMM_FFN_ROWS_PER_SIMD] = {0.0f};
    float up_acc[GEMM_FFN_ROWS_PER_SIMD]    = {0.0f};
    float up_zacc[GEMM_FFN_ROWS_PER_SIMD]   = {0.0f};

    for (uint b = simd_lid; b < n_blocks; b += 32) {
        uint k_base = b * Q4_BLOCK_SIZE;

        // Stage the 32 inputs of this block and their sum. The sum is
        // shared by the gate and up zero-point terms and by every row, so
        // it is computed once here.
        float xv[Q4_BLOCK_SIZE];
        float xsum = 0.0f;
        for (uint j = 0; j < 16; j++) {
            xv[j * 2]     = xt[k_base + j * 2];
            xv[j * 2 + 1] = xt[k_base + j * 2 + 1];
            xsum += xv[j * 2] + xv[j * 2 + 1];
        }

        for (uint r = 0; r < rows_this; r++) {
            uint64_t row_off = uint64_t(base_row + r) * row_bytes + uint64_t(b) * Q4_BLOCK_BYTES;

            // Gate weight block
            device const uint8_t *g_block = W_gate + row_off;
            float g_scale = float(*reinterpret_cast<device const half*>(g_block));
            float g_zero  = float(*reinterpret_cast<device const half*>(g_block + 2));
            device const uint8_t *g_packed = g_block + 4;

            // Up weight block
            device const uint8_t *u_block = W_up + row_off;
            float u_scale = float(*reinterpret_cast<device const half*>(u_block));
            float u_zero  = float(*reinterpret_cast<device const half*>(u_block + 2));
            device const uint8_t *u_packed = u_block + 4;

            float g_dot = 0.0f;
            float u_dot = 0.0f;
            for (uint j = 0; j < 16; j++) {
                float x0 = xv[j * 2];
                float x1 = xv[j * 2 + 1];

                uint8_t gb = g_packed[j];
                g_dot += float(gb & 0xF) * x0 + float(gb >> 4) * x1;

                uint8_t ub = u_packed[j];
                u_dot += float(ub & 0xF) * x0 + float(ub >> 4) * x1;
            }
            gate_acc[r]  += g_dot * g_scale;
            gate_zacc[r] += xsum * g_zero;
            up_acc[r]    += u_dot * u_scale;
            up_zacc[r]   += xsum * u_zero;
        }
    }

    for (uint r = 0; r < rows_this; r++) {
        float g = simd_sum(gate_acc[r]) + simd_sum(gate_zacc[r]);
        float u = simd_sum(up_acc[r]) + simd_sum(up_zacc[r]);
        if (simd_lid == 0) {
            // silu(g) = g / (1 + exp(-g))
            float s = g / (1.0f + exp(-g));
            out[uint64_t(t) * out_dim + base_row + r] = s * u;
        }
    }
}
|
|
|
|
// ── Batched RMSNorm (N tokens) ──────────────────────────────────────
// x[t*dim .. (t+1)*dim-1] → out[t*dim .. (t+1)*dim-1]
// out[t, i] = x[t, i] * w[i] * rsqrt(mean(x[t]^2) + eps)
// Grid: (n_tokens, 1, 1) threadgroups, each normalizes one token.
// Threadgroup scratch holds one partial per thread (at most 1024 threads).
// Any threadgroup size up to 1024 is handled, power of two or not.
kernel void rms_norm_batched(
    device const float *x        [[buffer(0)]],
    device const float *w        [[buffer(1)]],
    device float       *out      [[buffer(2)]],
    constant uint      &dim      [[buffer(3)]],
    constant float     &eps      [[buffer(4)]],
    constant uint      &n_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]],
    uint tpg  [[threads_per_threadgroup]])
{
    // Uniform for the whole threadgroup, so no thread skips a barrier.
    if (tgid >= n_tokens) return;

    device const float *xi = x + uint64_t(tgid) * dim;
    device float *oi = out + uint64_t(tgid) * dim;

    threadgroup float partial[1024];

    // Pass 1: per-thread sum of squares over a strided slice of the row.
    float local_sum = 0.0f;
    for (uint i = tid; i < dim; i += tpg)
        local_sum += xi[i] * xi[i];
    partial[tid] = local_sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction. Start from the largest power of two below tpg and
    // bound the partner index, so a non-power-of-two threadgroup size does
    // not drop partials. For power-of-two sizes the first stride is tpg / 2.
    uint stride = 1;
    while (stride < tpg) stride <<= 1;
    for (stride >>= 1; stride > 0; stride >>= 1) {
        if (tid < stride && tid + stride < tpg)
            partial[tid] += partial[tid + stride];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    float rms_inv = rsqrt(partial[0] / float(dim) + eps);

    // Pass 2: scale each element by the inverse RMS and the weight.
    for (uint i = tid; i < dim; i += tpg)
        oi[i] = xi[i] * rms_inv * w[i];
}
|
|
|
|
// ── Batched embedding lookup (N tokens) ─────────────────────────────
// Grid: (dim, n_tokens, 1). Each thread copies one element:
//   out[t, i] = embed[token_ids[t], i]
// Only gid.x is bounds-checked; gid.y is used as the token index as-is,
// so the grid's y extent is presumably exactly n_tokens — confirm at the
// dispatch site.
kernel void embed_lookup_batched(
    device const float *embed     [[buffer(0)]],
    device float       *out       [[buffer(1)]],
    device const uint  *token_ids [[buffer(2)]],
    constant uint      &dim       [[buffer(3)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint col = gid.x;
    const uint tok = gid.y;
    if (col < dim) {
        // 64-bit row offsets: vocab * dim can exceed 32 bits.
        device const float *src_row = embed + uint64_t(token_ids[tok]) * dim;
        device float *dst_row = out + uint64_t(tok) * dim;
        dst_row[col] = src_row[col];
    }
}
|
|
|
|
// ── Batched RoPE (N tokens) ─────────────────────────────────────────
// Rotates Q[t] and K[t] in place for each token t at position base_pos+t.
// gid.x indexes a (head, pair) slot across Q heads then KV heads;
// gid.y is the token. Element i of a head is paired with element
// i + head_dim/2. The cos/sin tables are indexed as
// [position * (head_dim/2) + i].
// gid.y is not bounds-checked here — presumably the grid's y extent equals
// the token count; confirm at the dispatch site.
kernel void rope_apply_batched(
    device float       *q          [[buffer(0)]],
    device float       *k          [[buffer(1)]],
    device const float *cos_tbl    [[buffer(2)]],
    device const float *sin_tbl    [[buffer(3)]],
    constant uint      &n_q_heads  [[buffer(4)]],
    constant uint      &n_kv_heads [[buffer(5)]],
    constant uint      &head_dim   [[buffer(6)]],
    constant uint      &base_pos   [[buffer(7)]],
    constant uint      &q_stride   [[buffer(8)]],
    constant uint      &k_stride   [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint slot = gid.x;
    const uint tok = gid.y;

    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = (n_q_heads + n_kv_heads) * pairs_per_head;
    if (slot >= pairs_per_token) return;

    const uint head = slot / pairs_per_head;
    const uint lane = slot % pairs_per_head;

    // Heads [0, n_q_heads) live in q, the remainder in k.
    device float *v = (head < n_q_heads)
        ? q + uint64_t(tok) * q_stride + head * head_dim
        : k + uint64_t(tok) * k_stride + (head - n_q_heads) * head_dim;

    const uint tbl_row = (base_pos + tok) * pairs_per_head;
    const float lo = v[lane];
    const float hi = v[lane + pairs_per_head];
    const float cs = cos_tbl[tbl_row + lane];
    const float sn = sin_tbl[tbl_row + lane];

    // 2-D rotation of the (lo, hi) pair.
    v[lane] = lo * cs - hi * sn;
    v[lane + pairs_per_head] = hi * cs + lo * sn;
}
|
|
|
|
// ── Batched vec_add: out[i] = a[i] + b[i] for total_n elements ──────
// One thread per element; threads at or beyond total_n do nothing.
kernel void vec_add_batched(
    device const float *a       [[buffer(0)]],
    device const float *b       [[buffer(1)]],
    device float       *out     [[buffer(2)]],
    constant uint      &total_n [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < total_n) {
        const float lhs = a[gid];
        const float rhs = b[gid];
        out[gid] = lhs + rhs;
    }
}
|
|
|
|
// ── F16 matrix-vector multiply ───────────────────────────────────────
// y[r] = sum_k float(W[r, k]) * x[k], one thread per output row.
// W is stored as out_dim rows of in_dim halfs; accumulation is in float.
kernel void sgemv_f16(
    device const half  *W       [[buffer(0)]],
    device const float *x       [[buffer(1)]],
    device float       *y       [[buffer(2)]],
    constant uint      &in_dim  [[buffer(3)]],
    constant uint      &out_dim [[buffer(4)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid >= out_dim) return;

    device const half *w_row = W + uint64_t(gid) * in_dim;
    float acc = 0.0f;
    uint k = 0;

    // Main body: eight multiply-adds per trip, accumulated in index order.
    while (k + 7 < in_dim) {
        for (uint u = 0; u < 8; ++u)
            acc += float(w_row[k + u]) * x[k + u];
        k += 8;
    }

    // Tail: the remaining (in_dim % 8) elements.
    while (k < in_dim) {
        acc += float(w_row[k]) * x[k];
        ++k;
    }

    y[gid] = acc;
}
|
|
|
|
// ── F32 matrix-vector multiply ───────────────────────────────────────
// y[r] = sum_k W[r, k] * x[k], one thread per output row.
// W is stored as out_dim rows of in_dim floats.
kernel void sgemv_f32(
    device const float *W       [[buffer(0)]],
    device const float *x       [[buffer(1)]],
    device float       *y       [[buffer(2)]],
    constant uint      &in_dim  [[buffer(3)]],
    constant uint      &out_dim [[buffer(4)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid >= out_dim) return;

    device const float *w_row = W + uint64_t(gid) * in_dim;
    float acc = 0.0f;
    uint k = 0;

    // Main body: eight multiply-adds per trip, accumulated in index order.
    while (k + 7 < in_dim) {
        for (uint u = 0; u < 8; ++u)
            acc += w_row[k + u] * x[k + u];
        k += 8;
    }

    // Tail: the remaining (in_dim % 8) elements.
    while (k < in_dim) {
        acc += w_row[k] * x[k];
        ++k;
    }

    y[gid] = acc;
}
|
|
|
|
// ── RMS Normalization ────────────────────────────────────────────────
// out[i] = x[i] * w[i] / sqrt(mean(x^2) + eps)
// Two-pass: first compute sum of squares (reduction), then normalize.
// Single threadgroup processes the entire vector.
// Threadgroup scratch holds one partial per thread (at most 1024 threads).
// Any threadgroup size up to 1024 is handled, power of two or not.
kernel void rms_norm(
    device const float *x   [[buffer(0)]],
    device const float *w   [[buffer(1)]],
    device float       *out [[buffer(2)]],
    constant uint      &dim [[buffer(3)]],
    constant float     &eps [[buffer(4)]],
    uint tid [[thread_index_in_threadgroup]],
    uint tpg [[threads_per_threadgroup]])
{
    threadgroup float partial[1024];

    // Pass 1: per-thread sum of squares over a strided slice of x.
    float local_sum = 0.0f;
    for (uint i = tid; i < dim; i += tpg)
        local_sum += x[i] * x[i];
    partial[tid] = local_sum;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction. Start from the largest power of two below tpg and
    // bound the partner index, so a non-power-of-two threadgroup size does
    // not drop partials. For power-of-two sizes the first stride is tpg / 2.
    uint stride = 1;
    while (stride < tpg) stride <<= 1;
    for (stride >>= 1; stride > 0; stride >>= 1) {
        if (tid < stride && tid + stride < tpg)
            partial[tid] += partial[tid + stride];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    float rms_inv = rsqrt(partial[0] / float(dim) + eps);

    // Pass 2: scale each element by the inverse RMS and the weight.
    for (uint i = tid; i < dim; i += tpg)
        out[i] = x[i] * rms_inv * w[i];
}
|
|
|
|
// ── RoPE (Rotary Position Embedding) ─────────────────────────────────
// Applies RoPE to Q and K vectors in-place, for a single position.
// cos_v and sin_v are separate buffers, each indexed by the pair index
// i in [0, head_dim/2). Element i of a head is rotated together with
// element i + head_dim/2.
// One thread per (head, pair) slot, Q heads first, then KV heads.
kernel void rope_apply(
    device float       *q          [[buffer(0)]],
    device float       *k          [[buffer(1)]],
    device const float *cos_v      [[buffer(2)]],
    device const float *sin_v      [[buffer(3)]],
    constant uint      &n_q_heads  [[buffer(4)]],
    constant uint      &n_kv_heads [[buffer(5)]],
    constant uint      &head_dim   [[buffer(6)]],
    uint gid [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pair_count = (n_q_heads + n_kv_heads) * pairs_per_head;
    if (gid >= pair_count) return;

    const uint head = gid / pairs_per_head;
    const uint lane = gid % pairs_per_head;

    // Heads [0, n_q_heads) live in q, the remainder in k.
    device float *v = (head < n_q_heads)
        ? q + head * head_dim
        : k + (head - n_q_heads) * head_dim;

    const float lo = v[lane];
    const float hi = v[lane + pairs_per_head];
    const float cs = cos_v[lane];
    const float sn = sin_v[lane];

    // 2-D rotation of the (lo, hi) pair.
    v[lane] = lo * cs - hi * sn;
    v[lane + pairs_per_head] = hi * cs + lo * sn;
}
|
|
|
|
// ── SiLU activation + element-wise multiply ──────────────────────────
// gate[i] = silu(gate[i]) * up[i], written back into gate
// silu(x) = x / (1 + exp(-x))
// One thread per element; threads at or beyond n do nothing.
kernel void silu_mul(
    device float       *gate [[buffer(0)]],
    device const float *up   [[buffer(1)]],
    constant uint      &n    [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        const float g = gate[gid];
        const float activated = g / (1.0f + exp(-g));
        gate[gid] = activated * up[gid];
    }
}
|
|
|
|
// ── Vector add (residual connection) ─────────────────────────────────
// out[i] = a[i] + b[i] for i in [0, n)
// One thread per element; threads at or beyond n do nothing.
kernel void vec_add(
    device const float *a   [[buffer(0)]],
    device const float *b   [[buffer(1)]],
    device float       *out [[buffer(2)]],
    constant uint      &n   [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        const float lhs = a[gid];
        const float rhs = b[gid];
        out[gid] = lhs + rhs;
    }
}
|
|
|
|
// ── Bias add ─────────────────────────────────────────────────────────
// x[i] += bias[i] for i in [0, n), updating x in place
// One thread per element; threads at or beyond n do nothing.
kernel void bias_add(
    device float       *x    [[buffer(0)]],
    device const float *bias [[buffer(1)]],
    constant uint      &n    [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        x[gid] += bias[gid];
    }
}
|
|
|
|
// ── Embedding lookup ─────────────────────────────────────────────────
// out[i] = embed[token_id * dim + i] for i in [0, dim)
// One thread per element; threads at or beyond dim do nothing.
kernel void embed_lookup(
    device const float *embed    [[buffer(0)]],
    device float       *out      [[buffer(1)]],
    constant uint      &token_id [[buffer(2)]],
    constant uint      &dim      [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < dim) {
        // 64-bit row offset: token_id * dim can exceed 32 bits.
        device const float *src_row = embed + uint64_t(token_id) * dim;
        out[gid] = src_row[gid];
    }
}
|
|
|
|
// ── Attention score: Q @ K^T for one head (legacy) ──────────────────
// att[t] = scale * dot(qh, K[t]) for t in [0, seq_len)
// K[t] starts at kv_cache_k + t * kv_dim + kv_head_offset and spans
// head_dim floats. One thread per cached position.
kernel void attn_score(
    device const float *qh             [[buffer(0)]],
    device const float *kv_cache_k     [[buffer(1)]],
    device float       *att            [[buffer(2)]],
    constant uint      &head_dim       [[buffer(3)]],
    constant uint      &kv_dim         [[buffer(4)]],
    constant uint      &kv_head_offset [[buffer(5)]],
    constant float     &scale          [[buffer(6)]],
    constant uint      &seq_len        [[buffer(7)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid >= seq_len) return;

    device const float *key = kv_cache_k + uint64_t(gid) * kv_dim + kv_head_offset;

    float acc = 0.0f;
    for (uint d = 0; d < head_dim; ++d) {
        acc += qh[d] * key[d];
    }
    att[gid] = acc * scale;
}
|
|
|
|
// ── Batched attention score: all Q heads in one dispatch ─────────────
// Grid: (seq_len, n_q_heads, 1). Each thread computes one score for one head:
//   att[h * max_seq + t] = scale * dot(Q[h], K[t, h / gqa_factor])
// GQA: maps Q head h to KV head h/gqa_factor.
kernel void attn_score_batched(
    device const float *q          [[buffer(0)]],
    device const float *kv_cache_k [[buffer(1)]],
    device float       *att        [[buffer(2)]],
    constant uint      &head_dim   [[buffer(3)]],
    constant uint      &kv_dim     [[buffer(4)]],
    constant uint      &n_q_heads  [[buffer(5)]],
    constant uint      &gqa_factor [[buffer(6)]],
    constant float     &scale      [[buffer(7)]],
    constant uint      &seq_len    [[buffer(8)]],
    constant uint      &max_seq    [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint pos = gid.x;
    const uint head = gid.y;
    if (pos >= seq_len || head >= n_q_heads) return;

    const uint kv_head = head / gqa_factor;
    device const float *query = q + head * head_dim;
    device const float *key = kv_cache_k + uint64_t(pos) * kv_dim + kv_head * head_dim;

    float acc = 0.0f;
    for (uint d = 0; d < head_dim; ++d) {
        acc += query[d] * key[d];
    }

    // Scores for one head occupy a row of max_seq floats.
    att[head * max_seq + pos] = acc * scale;
}
|
|
|
|
// ── Batched softmax: all heads in one dispatch ───────────────────────
// One threadgroup per head. tid reduces over seq_len dimension.
// Normalizes att[h * max_seq + 0 .. seq_len-1] in place using the
// max-subtracted form exp(a - max) / sum(exp(a - max)).
// Threadgroup scratch holds one value per thread (at most 1024 threads).
// Any threadgroup size up to 1024 is handled, power of two or not.
kernel void softmax_batched(
    device float  *att       [[buffer(0)]],
    constant uint &seq_len   [[buffer(1)]],
    constant uint &max_seq   [[buffer(2)]],
    constant uint &n_q_heads [[buffer(3)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]],
    uint tpg  [[threads_per_threadgroup]])
{
    // Uniform for the whole threadgroup, so no thread skips a barrier.
    uint h = tgid;
    if (h >= n_q_heads) return;

    device float *head_att = att + h * max_seq;
    threadgroup float shared[1024];

    // First reduction stride: largest power of two below tpg. Combined with
    // the partner bound below, non-power-of-two threadgroup sizes do not
    // drop elements. For power-of-two sizes this equals tpg / 2.
    uint first_stride = 1;
    while (first_stride < tpg) first_stride <<= 1;
    first_stride >>= 1;

    // Phase 1: row maximum.
    float local_max = -1e30f;
    for (uint i = tid; i < seq_len; i += tpg)
        local_max = max(local_max, head_att[i]);
    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = first_stride; s > 0; s >>= 1) {
        if (tid < s && tid + s < tpg)
            shared[tid] = max(shared[tid], shared[tid + s]);
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    float max_val = shared[0];

    // Every thread must have read shared[0] before thread 0 overwrites it
    // with its partial sum below.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: exponentiate in place and sum.
    float local_sum = 0.0f;
    for (uint i = tid; i < seq_len; i += tpg) {
        float e = exp(head_att[i] - max_val);
        head_att[i] = e;
        local_sum += e;
    }
    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = first_stride; s > 0; s >>= 1) {
        if (tid < s && tid + s < tpg)
            shared[tid] += shared[tid + s];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    float inv_sum = 1.0f / shared[0];

    // Phase 3: normalize. Each thread rescales only the elements it wrote.
    for (uint i = tid; i < seq_len; i += tpg)
        head_att[i] *= inv_sum;
}
|
|
|
|
// ── Batched attention weighted sum: all heads in one dispatch ────────
// Grid: (head_dim, n_q_heads, 1). Each thread computes one output dim for one head:
//   out[h * head_dim + d] = sum_t att[h * max_seq + t] * V[t, h / gqa_factor, d]
// V rows are kv_dim floats apart; the KV head is selected by h / gqa_factor.
kernel void attn_wsum_batched(
    device const float *att        [[buffer(0)]],
    device const float *kv_cache_v [[buffer(1)]],
    device float       *out        [[buffer(2)]],
    constant uint      &head_dim   [[buffer(3)]],
    constant uint      &kv_dim     [[buffer(4)]],
    constant uint      &n_q_heads  [[buffer(5)]],
    constant uint      &gqa_factor [[buffer(6)]],
    constant uint      &seq_len    [[buffer(7)]],
    constant uint      &max_seq    [[buffer(8)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint col = gid.x;
    const uint head = gid.y;
    if (col >= head_dim || head >= n_q_heads) return;

    const uint kv_head = head / gqa_factor;
    device const float *weights = att + head * max_seq;

    float acc = 0.0f;
    for (uint pos = 0; pos < seq_len; ++pos) {
        const float value = kv_cache_v[uint64_t(pos) * kv_dim + kv_head * head_dim + col];
        acc += weights[pos] * value;
    }
    out[head * head_dim + col] = acc;
}
|
|
|
|
// ── Softmax (legacy, single head) ───────────────────────────────────
// Normalizes att[0 .. seq_len-1] in place using the max-subtracted form
// exp(a - max) / sum(exp(a - max)). Uses thread_index_in_threadgroup
// only, so it is written for a single-threadgroup dispatch.
// Threadgroup scratch holds one value per thread (at most 1024 threads).
// Any threadgroup size up to 1024 is handled, power of two or not.
kernel void softmax_inplace(
    device float  *att     [[buffer(0)]],
    constant uint &seq_len [[buffer(1)]],
    uint tid [[thread_index_in_threadgroup]],
    uint tpg [[threads_per_threadgroup]])
{
    threadgroup float shared[1024];

    // First reduction stride: largest power of two below tpg. Combined with
    // the partner bound below, non-power-of-two threadgroup sizes do not
    // drop elements. For power-of-two sizes this equals tpg / 2.
    uint first_stride = 1;
    while (first_stride < tpg) first_stride <<= 1;
    first_stride >>= 1;

    // Phase 1: maximum.
    float local_max = -1e30f;
    for (uint i = tid; i < seq_len; i += tpg)
        local_max = max(local_max, att[i]);
    shared[tid] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = first_stride; s > 0; s >>= 1) {
        if (tid < s && tid + s < tpg)
            shared[tid] = max(shared[tid], shared[tid + s]);
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    float max_val = shared[0];

    // Every thread must have read shared[0] before thread 0 overwrites it
    // with its partial sum below.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 2: exponentiate in place and sum.
    float local_sum = 0.0f;
    for (uint i = tid; i < seq_len; i += tpg) {
        float e = exp(att[i] - max_val);
        att[i] = e;
        local_sum += e;
    }
    shared[tid] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint s = first_stride; s > 0; s >>= 1) {
        if (tid < s && tid + s < tpg)
            shared[tid] += shared[tid + s];
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    float inv_sum = 1.0f / shared[0];

    // Phase 3: normalize. Each thread rescales only the elements it wrote.
    for (uint i = tid; i < seq_len; i += tpg)
        att[i] *= inv_sum;
}
|
|
|
|
// ── Attention weighted sum (legacy, single head) ─────────────────────
// out[d] = sum_t att[t] * V[t, d] for d in [0, head_dim)
// V[t, d] is read from kv_cache_v[t * kv_dim + kv_head_offset + d].
// One thread per output dimension.
kernel void attn_weighted_sum(
    device const float *att            [[buffer(0)]],
    device const float *kv_cache_v     [[buffer(1)]],
    device float       *out            [[buffer(2)]],
    constant uint      &head_dim       [[buffer(3)]],
    constant uint      &kv_dim         [[buffer(4)]],
    constant uint      &kv_head_offset [[buffer(5)]],
    constant uint      &seq_len        [[buffer(6)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < head_dim) {
        float acc = 0.0f;
        for (uint pos = 0; pos < seq_len; ++pos) {
            const float value = kv_cache_v[uint64_t(pos) * kv_dim + kv_head_offset + gid];
            acc += att[pos] * value;
        }
        out[gid] = acc;
    }
}
|
|
|
|
// ── Argmax ───────────────────────────────────────────────────────────
// Finds argmax of logits[0..n-1], writes to result[0].
// Single threadgroup.
// Ties resolve to the lowest index (first occurrence). Values are compared
// against an initial -1e30f, so if no logit exceeds that the result is 0.
// Threadgroup scratch holds one candidate per thread (at most 1024
// threads). Any threadgroup size up to 1024 is handled, power of two or not.
kernel void argmax_kernel(
    device const float *logits [[buffer(0)]],
    device int         *result [[buffer(1)]],
    constant uint      &n      [[buffer(2)]],
    uint tid [[thread_index_in_threadgroup]],
    uint tpg [[threads_per_threadgroup]])
{
    threadgroup float shared_val[1024];
    threadgroup int shared_idx[1024];

    // Per-thread scan over a strided slice. Strict '>' with ascending i
    // keeps the first occurrence within the slice.
    float local_max = -1e30f;
    int local_idx = 0;
    for (uint i = tid; i < n; i += tpg) {
        float v = logits[i];
        if (v > local_max) {
            local_max = v;
            local_idx = int(i);
        }
    }
    shared_val[tid] = local_max;
    shared_idx[tid] = local_idx;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction. Start from the largest power of two below tpg and
    // bound the partner index, so a non-power-of-two threadgroup size does
    // not drop candidates. For power-of-two sizes the first stride is tpg / 2.
    uint stride = 1;
    while (stride < tpg) stride <<= 1;
    for (stride >>= 1; stride > 0; stride >>= 1) {
        if (tid < stride && tid + stride < tpg) {
            float other_val = shared_val[tid + stride];
            int other_idx = shared_idx[tid + stride];
            float mine_val = shared_val[tid];
            int mine_idx = shared_idx[tid];
            // Prefer the larger value; on an exact tie prefer the lower
            // index so the result does not depend on the thread layout.
            if (other_val > mine_val ||
                (other_val == mine_val && other_idx < mine_idx)) {
                shared_val[tid] = other_val;
                shared_idx[tid] = other_idx;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) result[0] = shared_idx[0];
}
|
|
|
|
// ── Copy kernel ──────────────────────────────────────────────────────
// dst[i] = src[i] for i in [0, n)
// One thread per element; threads at or beyond n do nothing.
kernel void vec_copy(
    device const float *src [[buffer(0)]],
    device float       *dst [[buffer(1)]],
    constant uint      &n   [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        const float value = src[gid];
        dst[gid] = value;
    }
}
|
|
|
|
// ── Zero-fill ────────────────────────────────────────────────────────
// dst[i] = 0 for i in [0, n)
// One thread per element; threads at or beyond n do nothing.
kernel void vec_zero(
    device float  *dst [[buffer(0)]],
    constant uint &n   [[buffer(1)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        dst[gid] = 0.0f;
    }
}
|