Skip to content

Commit 1313605

Browse files
committed
fix(cuda): remove misaligned int4 vectorized loads from gemv_q8_kernel
The gemv_q8_kernel cast the quantized-weight pointer qvals (const int8_t*, at offset 4 inside each 36-byte Q8 block: 4-byte scale + 32 bytes of int8 data) to int4* for 16-byte vectorized loads. Since qvals is only guaranteed 4-byte alignment, these loads violate CUDA's natural-alignment requirement and the kernel faults with "misaligned address" (the error surfaces at the next synchronizing call, e.g. a cudaMemcpy); observed on ARM64/Grace Hopper with pool allocations. Replace the int4 vectorized loads with plain per-element int8 loads. The float4 reads from shared memory are unaffected (the shared-memory buffer is 16-byte aligned). Performance impact: expected to be minimal for this memory-bandwidth-bound kernel, though unmeasured. Fixes: Gemma3 inference "misaligned address" on DGX Spark GB10. Root cause confirmed via compute-sanitizer --tool memcheck.
1 parent 34aba3b commit 1313605

1 file changed

Lines changed: 32 additions & 36 deletions

File tree

internal/cuda/kernels/gemm_q8.cu

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -63,48 +63,44 @@ __global__ void gemv_q8_kernel(
6363
int k_base = bi * Q8_BLOCK_SIZE;
6464
const int8_t* qvals = (const int8_t*)(blk + 4);
6565

66-
/* Vectorized load: read 32 int8 values as two int4 (16 bytes each).
67-
* int4 is a CUDA vector type: {int x, y, z, w} = 16 bytes. */
68-
const int4* qv4 = (const int4*)qvals;
69-
int4 q_lo = __ldg(&qv4[0]); /* qvals[0..15] */
70-
int4 q_hi = __ldg(&qv4[1]); /* qvals[16..31] */
71-
72-
/* Unpack int4 into individual int8 values and dot with shared mem.
73-
* Each int4 component (int x,y,z,w) holds 4 int8 values. */
74-
const int8_t* q_lo_bytes = (const int8_t*)&q_lo;
75-
const int8_t* q_hi_bytes = (const int8_t*)&q_hi;
76-
77-
/* Process first 16 values using float4 loads from shared memory. */
78-
float4 sx0 = ((float4*)&sx[k_base])[0]; /* sx[k_base+0..3] */
79-
float4 sx1 = ((float4*)&sx[k_base])[1]; /* sx[k_base+4..7] */
80-
float4 sx2 = ((float4*)&sx[k_base])[2]; /* sx[k_base+8..11] */
81-
float4 sx3 = ((float4*)&sx[k_base])[3]; /* sx[k_base+12..15] */
66+
/* Read 32 int8 quantized values using per-byte loads.
67+
* Avoid int4 (16-byte) vectorized loads because the Q8 block
68+
* layout (4-byte scale + 32-byte data = 36 bytes) means qvals
69+
 * is only 4-byte aligned, not 16-byte aligned. CUDA requires
70+
 * naturally aligned accesses, so the int4 loads fault. */
71+
72+
/* Process first 16 values using float4 loads from shared memory
73+
* (shared memory is always 16-byte aligned). */
74+
float4 sx0 = ((float4*)&sx[k_base])[0];
75+
float4 sx1 = ((float4*)&sx[k_base])[1];
76+
float4 sx2 = ((float4*)&sx[k_base])[2];
77+
float4 sx3 = ((float4*)&sx[k_base])[3];
8278

8379
acc += scale * (
84-
(float)q_lo_bytes[0] * sx0.x + (float)q_lo_bytes[1] * sx0.y +
85-
(float)q_lo_bytes[2] * sx0.z + (float)q_lo_bytes[3] * sx0.w +
86-
(float)q_lo_bytes[4] * sx1.x + (float)q_lo_bytes[5] * sx1.y +
87-
(float)q_lo_bytes[6] * sx1.z + (float)q_lo_bytes[7] * sx1.w +
88-
(float)q_lo_bytes[8] * sx2.x + (float)q_lo_bytes[9] * sx2.y +
89-
(float)q_lo_bytes[10] * sx2.z + (float)q_lo_bytes[11] * sx2.w +
90-
(float)q_lo_bytes[12] * sx3.x + (float)q_lo_bytes[13] * sx3.y +
91-
(float)q_lo_bytes[14] * sx3.z + (float)q_lo_bytes[15] * sx3.w);
80+
(float)qvals[0] * sx0.x + (float)qvals[1] * sx0.y +
81+
(float)qvals[2] * sx0.z + (float)qvals[3] * sx0.w +
82+
(float)qvals[4] * sx1.x + (float)qvals[5] * sx1.y +
83+
(float)qvals[6] * sx1.z + (float)qvals[7] * sx1.w +
84+
(float)qvals[8] * sx2.x + (float)qvals[9] * sx2.y +
85+
(float)qvals[10] * sx2.z + (float)qvals[11] * sx2.w +
86+
(float)qvals[12] * sx3.x + (float)qvals[13] * sx3.y +
87+
(float)qvals[14] * sx3.z + (float)qvals[15] * sx3.w);
9288

9389
/* Process second 16 values. */
94-
float4 sx4 = ((float4*)&sx[k_base + 16])[0]; /* sx[k_base+16..19] */
95-
float4 sx5 = ((float4*)&sx[k_base + 16])[1]; /* sx[k_base+20..23] */
96-
float4 sx6 = ((float4*)&sx[k_base + 16])[2]; /* sx[k_base+24..27] */
97-
float4 sx7 = ((float4*)&sx[k_base + 16])[3]; /* sx[k_base+28..31] */
90+
float4 sx4 = ((float4*)&sx[k_base + 16])[0];
91+
float4 sx5 = ((float4*)&sx[k_base + 16])[1];
92+
float4 sx6 = ((float4*)&sx[k_base + 16])[2];
93+
float4 sx7 = ((float4*)&sx[k_base + 16])[3];
9894

9995
acc += scale * (
100-
(float)q_hi_bytes[0] * sx4.x + (float)q_hi_bytes[1] * sx4.y +
101-
(float)q_hi_bytes[2] * sx4.z + (float)q_hi_bytes[3] * sx4.w +
102-
(float)q_hi_bytes[4] * sx5.x + (float)q_hi_bytes[5] * sx5.y +
103-
(float)q_hi_bytes[6] * sx5.z + (float)q_hi_bytes[7] * sx5.w +
104-
(float)q_hi_bytes[8] * sx6.x + (float)q_hi_bytes[9] * sx6.y +
105-
(float)q_hi_bytes[10] * sx6.z + (float)q_hi_bytes[11] * sx6.w +
106-
(float)q_hi_bytes[12] * sx7.x + (float)q_hi_bytes[13] * sx7.y +
107-
(float)q_hi_bytes[14] * sx7.z + (float)q_hi_bytes[15] * sx7.w);
96+
(float)qvals[16] * sx4.x + (float)qvals[17] * sx4.y +
97+
(float)qvals[18] * sx4.z + (float)qvals[19] * sx4.w +
98+
(float)qvals[20] * sx5.x + (float)qvals[21] * sx5.y +
99+
(float)qvals[22] * sx5.z + (float)qvals[23] * sx5.w +
100+
(float)qvals[24] * sx6.x + (float)qvals[25] * sx6.y +
101+
(float)qvals[26] * sx6.z + (float)qvals[27] * sx6.w +
102+
(float)qvals[28] * sx7.x + (float)qvals[29] * sx7.y +
103+
(float)qvals[30] * sx7.z + (float)qvals[31] * sx7.w);
108104
}
109105

110106
/* Warp shuffle reduction. */

0 commit comments

Comments
 (0)