Commit 34aba3b

fix(cuda): remove float4 alignment requirement from gemv_q8_kernel
The gemv_q8_kernel cast the activation pointer x (float*) to float4* for 16-byte vectorized loads into shared memory. When x is not 16-byte aligned (common on ARM64/Grace Hopper with pool allocations), the float4 load faults with a "misaligned address" error, which the runtime reports at a later call such as cudaMemcpy.

Replace the float4 global loads with per-element __ldg loads. Shared memory float4 accesses are unaffected (the shared memory base is always 16-byte aligned).

Performance impact: minimal. The global-to-shared load is a one-time cost per block, not in the inner loop.

Fixes: Gemma3 inference "misaligned address" on DGX Spark GB10.
Root cause confirmed via compute-sanitizer --tool memcheck.
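
To illustrate the alignment condition the message describes, here is a minimal host-side C++ sketch (it should also compile inside a .cu file). The names `is_aligned_16` and `FloatPool` are hypothetical and are not part of the repository; the pool stands in for any allocator that hands out sub-buffers at arbitrary float offsets inside a 16-byte-aligned slab, and the sketch assumes a 4-byte float.

```cpp
#include <cstddef>
#include <cstdint>

// Returns true when p sits on a 16-byte boundary, which is what a
// float4 load through a reinterpreted float* requires.
inline bool is_aligned_16(const void* p) {
    return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
}

// Hypothetical pool: one 16-byte-aligned slab that is carved up at
// arbitrary float offsets (assumes sizeof(float) == 4).
struct FloatPool {
    alignas(16) float slab[64];

    // Pointer to the sub-buffer that starts float_offset floats into the slab.
    const float* at(std::size_t float_offset) const {
        return slab + float_offset;
    }
};
```

Under these assumptions, only offsets that are multiples of 4 floats yield a pointer that is safe to cast to float4*; a sub-buffer starting at offset 1, 2, or 3 is 4-byte aligned but not 16-byte aligned, which matches the failure mode the commit fixes.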
1 parent 0e9fb78 commit 34aba3b

1 file changed

Lines changed: 7 additions & 9 deletions

internal/cuda/kernels/gemm_q8.cu

@@ -30,16 +30,14 @@ __global__ void gemv_q8_kernel(
 {
     extern __shared__ float sx[];
 
-    /* Cooperatively load x[0..K-1] into shared memory using float4 loads. */
+    /* Cooperatively load x[0..K-1] into shared memory.
+     * Use per-element loads instead of float4 to avoid misaligned access
+     * when the activation pointer x is not 16-byte aligned (common on
+     * ARM64/Grace Hopper when x comes from pool allocations with
+     * non-aligned offsets). Shared memory loads later in the kernel are
+     * always aligned since shared memory base is 16-byte aligned. */
     int threads_per_block = blockDim.x;
-    int k4 = K / 4;
-    const float4* x4 = (const float4*)x;
-    float4* sx4 = (float4*)sx;
-    for (int i = threadIdx.x; i < k4; i += threads_per_block) {
-        sx4[i] = __ldg(&x4[i]);
-    }
-    /* Handle remainder if K is not a multiple of 4. */
-    for (int i = k4 * 4 + threadIdx.x; i < K; i += threads_per_block) {
+    for (int i = threadIdx.x; i < K; i += threads_per_block) {
         sx[i] = __ldg(&x[i]);
     }
     __syncthreads();
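
The replacement loop is a block-stride pattern: thread t copies elements t, t + blockDim.x, t + 2 * blockDim.x, and so on, so the block covers x[0..K-1] with no separate remainder pass. The following host-side C++ sketch emulates that indexing on the CPU to show each element is written exactly once. `simulate_block_stride_load` is a hypothetical helper, not code from the repository, and a plain read stands in for `__ldg`.

```cpp
#include <cstddef>
#include <vector>

// Host-side emulation of the per-element cooperative load. Each simulated
// thread tid copies elements tid, tid + T, tid + 2T, ... while the index
// stays below K, where T is threads_per_block.
// Returns how many times each destination index was written.
std::vector<int> simulate_block_stride_load(const float* x, float* sx,
                                            int K, int threads_per_block) {
    std::vector<int> writes(static_cast<std::size_t>(K), 0);
    for (int tid = 0; tid < threads_per_block; ++tid) {
        for (int i = tid; i < K; i += threads_per_block) {
            sx[i] = x[i];  // plain read standing in for __ldg(&x[i])
            ++writes[i];
        }
    }
    return writes;
}
```

Because the loop only dereferences float elements, it needs 4-byte alignment of x rather than 16-byte alignment, and it handles any K without a tail case.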
