@@ -63,48 +63,44 @@ __global__ void gemv_q8_kernel(
6363 int k_base = bi * Q8_BLOCK_SIZE;
6464 const int8_t * qvals = (const int8_t *)(blk + 4 );
6565
66- /* Vectorized load: read 32 int8 values as two int4 (16 bytes each).
67- * int4 is a CUDA vector type: {int x, y, z, w} = 16 bytes. */
68- const int4 * qv4 = (const int4 *)qvals;
69- int4 q_lo = __ldg (&qv4[0 ]); /* qvals[0..15] */
70- int4 q_hi = __ldg (&qv4[1 ]); /* qvals[16..31] */
71-
72- /* Unpack int4 into individual int8 values and dot with shared mem.
73- * Each int4 component (int x,y,z,w) holds 4 int8 values. */
74- const int8_t * q_lo_bytes = (const int8_t *)&q_lo;
75- const int8_t * q_hi_bytes = (const int8_t *)&q_hi;
76-
77- /* Process first 16 values using float4 loads from shared memory. */
78- float4 sx0 = ((float4 *)&sx[k_base])[0 ]; /* sx[k_base+0..3] */
79- float4 sx1 = ((float4 *)&sx[k_base])[1 ]; /* sx[k_base+4..7] */
80- float4 sx2 = ((float4 *)&sx[k_base])[2 ]; /* sx[k_base+8..11] */
81- float4 sx3 = ((float4 *)&sx[k_base])[3 ]; /* sx[k_base+12..15] */
66+ /* Read 32 int8 quantized values using per-byte loads.
67+ * Avoid int4 (16-byte) vectorized loads because the Q8 block
68+ * layout (4-byte scale + 32-byte data = 36 bytes) means qvals
69+ * is only 4-byte aligned, not 16-byte aligned. On ARM64/Grace
70+ * Hopper, misaligned int4 loads cause fatal errors. */
71+
72+ /* Process first 16 values using float4 loads from shared memory
73+ * (shared memory is always 16-byte aligned). */
74+ float4 sx0 = ((float4 *)&sx[k_base])[0 ];
75+ float4 sx1 = ((float4 *)&sx[k_base])[1 ];
76+ float4 sx2 = ((float4 *)&sx[k_base])[2 ];
77+ float4 sx3 = ((float4 *)&sx[k_base])[3 ];
8278
8379 acc += scale * (
84- (float )q_lo_bytes [0 ] * sx0.x + (float )q_lo_bytes [1 ] * sx0.y +
85- (float )q_lo_bytes [2 ] * sx0.z + (float )q_lo_bytes [3 ] * sx0.w +
86- (float )q_lo_bytes [4 ] * sx1.x + (float )q_lo_bytes [5 ] * sx1.y +
87- (float )q_lo_bytes [6 ] * sx1.z + (float )q_lo_bytes [7 ] * sx1.w +
88- (float )q_lo_bytes [8 ] * sx2.x + (float )q_lo_bytes [9 ] * sx2.y +
89- (float )q_lo_bytes [10 ] * sx2.z + (float )q_lo_bytes [11 ] * sx2.w +
90- (float )q_lo_bytes [12 ] * sx3.x + (float )q_lo_bytes [13 ] * sx3.y +
91- (float )q_lo_bytes [14 ] * sx3.z + (float )q_lo_bytes [15 ] * sx3.w );
80+ (float )qvals [0 ] * sx0.x + (float )qvals [1 ] * sx0.y +
81+ (float )qvals [2 ] * sx0.z + (float )qvals [3 ] * sx0.w +
82+ (float )qvals [4 ] * sx1.x + (float )qvals [5 ] * sx1.y +
83+ (float )qvals [6 ] * sx1.z + (float )qvals [7 ] * sx1.w +
84+ (float )qvals [8 ] * sx2.x + (float )qvals [9 ] * sx2.y +
85+ (float )qvals [10 ] * sx2.z + (float )qvals [11 ] * sx2.w +
86+ (float )qvals [12 ] * sx3.x + (float )qvals [13 ] * sx3.y +
87+ (float )qvals [14 ] * sx3.z + (float )qvals [15 ] * sx3.w );
9288
9389 /* Process second 16 values. */
94- float4 sx4 = ((float4 *)&sx[k_base + 16 ])[0 ]; /* sx[k_base+16..19] */
95- float4 sx5 = ((float4 *)&sx[k_base + 16 ])[1 ]; /* sx[k_base+20..23] */
96- float4 sx6 = ((float4 *)&sx[k_base + 16 ])[2 ]; /* sx[k_base+24..27] */
97- float4 sx7 = ((float4 *)&sx[k_base + 16 ])[3 ]; /* sx[k_base+28..31] */
90+ float4 sx4 = ((float4 *)&sx[k_base + 16 ])[0 ];
91+ float4 sx5 = ((float4 *)&sx[k_base + 16 ])[1 ];
92+ float4 sx6 = ((float4 *)&sx[k_base + 16 ])[2 ];
93+ float4 sx7 = ((float4 *)&sx[k_base + 16 ])[3 ];
9894
9995 acc += scale * (
100- (float )q_hi_bytes[ 0 ] * sx4.x + (float )q_hi_bytes[ 1 ] * sx4.y +
101- (float )q_hi_bytes[ 2 ] * sx4.z + (float )q_hi_bytes[ 3 ] * sx4.w +
102- (float )q_hi_bytes[ 4 ] * sx5.x + (float )q_hi_bytes[ 5 ] * sx5.y +
103- (float )q_hi_bytes[ 6 ] * sx5.z + (float )q_hi_bytes[ 7 ] * sx5.w +
104- (float )q_hi_bytes[ 8 ] * sx6.x + (float )q_hi_bytes[ 9 ] * sx6.y +
105- (float )q_hi_bytes[ 10 ] * sx6.z + (float )q_hi_bytes[ 11 ] * sx6.w +
106- (float )q_hi_bytes[ 12 ] * sx7.x + (float )q_hi_bytes[ 13 ] * sx7.y +
107- (float )q_hi_bytes[ 14 ] * sx7.z + (float )q_hi_bytes[ 15 ] * sx7.w );
96+ (float )qvals[ 16 ] * sx4.x + (float )qvals[ 17 ] * sx4.y +
97+ (float )qvals[ 18 ] * sx4.z + (float )qvals[ 19 ] * sx4.w +
98+ (float )qvals[ 20 ] * sx5.x + (float )qvals[ 21 ] * sx5.y +
99+ (float )qvals[ 22 ] * sx5.z + (float )qvals[ 23 ] * sx5.w +
100+ (float )qvals[ 24 ] * sx6.x + (float )qvals[ 25 ] * sx6.y +
101+ (float )qvals[ 26 ] * sx6.z + (float )qvals[ 27 ] * sx6.w +
102+ (float )qvals[ 28 ] * sx7.x + (float )qvals[ 29 ] * sx7.y +
103+ (float )qvals[ 30 ] * sx7.z + (float )qvals[ 31 ] * sx7.w );
108104 }
109105
110106 /* Warp shuffle reduction. */
0 commit comments