/* Q5_0 fused dequant-GEMV kernel for single-token decode (batch=1).
 *
 * GPU-optimized SEPARATED layout (from Q5_0Storage.RawBytesGPU):
 *   Region 1: [nBlocks *  2 bytes] fp16 scales, padded to 16-byte boundary
 *   Region 2: [nBlocks *  4 bytes] uint32 qh values, padded to 16-byte boundary
 *   Region 3: [nBlocks * 16 bytes] packed nibbles (qs)
 *
 * This layout ensures natural alignment: fp16 at 2-byte, uint32 at 4-byte.
 * Eliminates the byte-wise loads required for the interleaved 22-byte layout
 * on ARM64 Grace Hopper.
 *
 * Dequantization (matching llama.cpp dequantize_row_q5_0):
 *   For j in 0..15:
2525#include < stdint.h>
2626
2727#define Q5_0_BLOCK_SIZE 32
28- #define Q5_0_BLOCK_BYTES 22
2928#define Q5_0_WARPS_PER_BLOCK 4
3029#define Q5_0_WARP_SIZE 32
3130
/* ---------- Fused GEMV kernel (separated GPU layout) ----------
 *
 *   y[row] = sum_k dequant(W_q5_0[row, k]) * x[k]
 *
 * W_q5_0 points to the separated layout base. qhOffset and qsOffset
 * are byte offsets to the qh and qs regions respectively.
 */
4338__global__ void gemv_q5_0_kernel (
4439 const uint8_t * __restrict__ W_q5_0,
4540 const float * __restrict__ x,
4641 float * __restrict__ y,
47- int M, int K)
42+ int M, int K,
43+ int qhOffset, int qsOffset)
4844{
4945 extern __shared__ float sx[];
5046
@@ -62,27 +58,20 @@ __global__ void gemv_q5_0_kernel(
6258 if (row >= M) return ;
6359
6460 int blocks_per_row = K / Q5_0_BLOCK_SIZE;
65- const uint8_t * row_data = W_q5_0 + (size_t )row * blocks_per_row * Q5_0_BLOCK_BYTES;
61+
62+ /* Pointers to the three separated regions for this row. */
63+ const __half* row_scales = (const __half*)(W_q5_0 + row * blocks_per_row * 2 );
64+ const uint32_t * row_qh = (const uint32_t *)(W_q5_0 + qhOffset + row * blocks_per_row * 4 );
65+ const uint8_t * row_qs = W_q5_0 + qsOffset + (size_t )row * blocks_per_row * 16 ;
6666
6767 float acc = 0 .0f ;
6868
6969 /* Each lane handles a strided subset of blocks. */
7070 for (int bi = lane_id; bi < blocks_per_row; bi += Q5_0_WARP_SIZE) {
71- const uint8_t * blk = row_data + bi * Q5_0_BLOCK_BYTES;
72-
73- /* Read fp16 d using byte-wise load (ARM64 alignment safety).
74- * Q5_0 blocks are 22 bytes — not a multiple of 4, so blk may
75- * be misaligned for uint16/uint32 casts after the first block. */
76- uint16_t d_bits = (uint16_t )__ldg (&blk[0 ]) | ((uint16_t )__ldg (&blk[1 ]) << 8 );
77- float d = __half2float (*reinterpret_cast <const __half*>(&d_bits));
78-
79- /* Read qh (32 high bits) using byte-wise load. */
80- uint32_t qh = (uint32_t )__ldg (&blk[2 ])
81- | ((uint32_t )__ldg (&blk[3 ]) << 8 )
82- | ((uint32_t )__ldg (&blk[4 ]) << 16 )
83- | ((uint32_t )__ldg (&blk[5 ]) << 24 );
84-
85- const uint8_t * qs = blk + 6 ;
71+ /* All loads are naturally aligned in the separated layout. */
72+ float d = __half2float (__ldg (&row_scales[bi]));
73+ uint32_t qh = __ldg (&row_qh[bi]);
74+ const uint8_t * qs = row_qs + bi * 16 ;
8675 int k_base = bi * Q5_0_BLOCK_SIZE;
8776
8877 /* Process 16 packed bytes -> 32 dequantized values. */
@@ -119,6 +108,7 @@ __global__ void gemv_q5_0_kernel(
119108extern " C" cudaError_t gemv_q5_0_f32 (
120109 const void * W_q5_0, const float * x, float * y,
121110 int M, int K,
111+ int qhOffset, int qsOffset,
122112 cudaStream_t stream)
123113{
124114 if (K % Q5_0_BLOCK_SIZE != 0 ) {
@@ -130,7 +120,7 @@ extern "C" cudaError_t gemv_q5_0_f32(
130120 int smem = K * sizeof (float );
131121
132122 gemv_q5_0_kernel<<<grid, threads, smem, stream>>> (
133- (const uint8_t *)W_q5_0, x, y, M, K);
123+ (const uint8_t *)W_q5_0, x, y, M, K, qhOffset, qsOffset );
134124
135125 return cudaGetLastError ();
136126}