@@ -2331,24 +2331,70 @@ typedef struct {
23312331
23322332void * q8_int_dot_worker (void * arg ) {
23332333 q8_int_task_t * t = (q8_int_task_t * )arg ;
2334- for (int d = t -> start_row ; d < t -> end_row ; d ++ ) {
2335- const block_q8_0 * wblk = (const block_q8_0 * )((const uint8_t * )t -> weight + (size_t )d * t -> row_bytes );
2336- if (d + 1 < t -> end_row ) {
2337- const uint8_t * next = (const uint8_t * )t -> weight + (size_t )(d + 1 ) * t -> row_bytes ;
2338- /* Q8_0 row is n_blocks * 34 bytes — for hidden=3072, ~3.4 KB.
2339- * Prefetch first 4 cache lines; HW prefetcher takes the rest. */
2334+ int d = t -> start_row ;
2335+ /* 2-row inner loop: pair rows share x_qs / x_ds. ILP hides weight load
2336+ * latency. Same trick as the q4k_int worker. */
2337+ for (; d + 1 < t -> end_row ; d += 2 ) {
2338+ const block_q8_0 * wblk0 = (const block_q8_0 * )((const uint8_t * )t -> weight + (size_t )d * t -> row_bytes );
2339+ const block_q8_0 * wblk1 = (const block_q8_0 * )((const uint8_t * )t -> weight + (size_t )(d + 1 ) * t -> row_bytes );
2340+ if (d + 2 < t -> end_row ) {
2341+ const uint8_t * next = (const uint8_t * )t -> weight + (size_t )(d + 2 ) * t -> row_bytes ;
23402342 __builtin_prefetch (next + 0 , 0 , 0 );
23412343 __builtin_prefetch (next + 64 , 0 , 0 );
23422344 __builtin_prefetch (next + 128 , 0 , 0 );
23432345 __builtin_prefetch (next + 192 , 0 , 0 );
23442346 }
2347+ float row_sum0 = 0.0f , row_sum1 = 0.0f ;
2348+ for (int b = 0 ; b < t -> n_blocks ; b ++ ) {
2349+ const float wd0 = fp16_to_fp32 (wblk0 [b ].d );
2350+ const float wd1 = fp16_to_fp32 (wblk1 [b ].d );
2351+ const int8_t * wqs0 = wblk0 [b ].qs ;
2352+ const int8_t * wqs1 = wblk1 [b ].qs ;
2353+ const int8_t * xqs = t -> x_qs + b * 32 ;
2354+ int8x16_t xq0 = vld1q_s8 (xqs + 0 );
2355+ int8x16_t xq1 = vld1q_s8 (xqs + 16 );
2356+ #ifdef __ARM_FEATURE_DOTPROD
2357+ int32x4_t vd0 = vdupq_n_s32 (0 );
2358+ int32x4_t vd1 = vdupq_n_s32 (0 );
2359+ vd0 = vdotq_s32 (vd0 , vld1q_s8 (wqs0 + 0 ), xq0 );
2360+ vd0 = vdotq_s32 (vd0 , vld1q_s8 (wqs0 + 16 ), xq1 );
2361+ vd1 = vdotq_s32 (vd1 , vld1q_s8 (wqs1 + 0 ), xq0 );
2362+ vd1 = vdotq_s32 (vd1 , vld1q_s8 (wqs1 + 16 ), xq1 );
2363+ float xd = t -> x_ds [b ];
2364+ row_sum0 += wd0 * xd * (float )vaddvq_s32 (vd0 );
2365+ row_sum1 += wd1 * xd * (float )vaddvq_s32 (vd1 );
2366+ #else
2367+ int32x4_t vd0a = vdupq_n_s32 (0 ), vd0b = vdupq_n_s32 (0 );
2368+ int32x4_t vd1a = vdupq_n_s32 (0 ), vd1b = vdupq_n_s32 (0 );
2369+ int8x16_t vw0a = vld1q_s8 (wqs0 );
2370+ int8x16_t vw0b = vld1q_s8 (wqs0 + 16 );
2371+ int8x16_t vw1a = vld1q_s8 (wqs1 );
2372+ int8x16_t vw1b = vld1q_s8 (wqs1 + 16 );
2373+ vd0a = vpadalq_s16 (vd0a , vmull_s8 (vget_low_s8 (vw0a ), vget_low_s8 (xq0 )));
2374+ vd0a = vpadalq_s16 (vd0a , vmull_s8 (vget_high_s8 (vw0a ), vget_high_s8 (xq0 )));
2375+ vd0b = vpadalq_s16 (vd0b , vmull_s8 (vget_low_s8 (vw0b ), vget_low_s8 (xq1 )));
2376+ vd0b = vpadalq_s16 (vd0b , vmull_s8 (vget_high_s8 (vw0b ), vget_high_s8 (xq1 )));
2377+ vd1a = vpadalq_s16 (vd1a , vmull_s8 (vget_low_s8 (vw1a ), vget_low_s8 (xq0 )));
2378+ vd1a = vpadalq_s16 (vd1a , vmull_s8 (vget_high_s8 (vw1a ), vget_high_s8 (xq0 )));
2379+ vd1b = vpadalq_s16 (vd1b , vmull_s8 (vget_low_s8 (vw1b ), vget_low_s8 (xq1 )));
2380+ vd1b = vpadalq_s16 (vd1b , vmull_s8 (vget_high_s8 (vw1b ), vget_high_s8 (xq1 )));
2381+ float xd = t -> x_ds [b ];
2382+ row_sum0 += wd0 * xd * (float )vaddvq_s32 (vaddq_s32 (vd0a , vd0b ));
2383+ row_sum1 += wd1 * xd * (float )vaddvq_s32 (vaddq_s32 (vd1a , vd1b ));
2384+ #endif
2385+ }
2386+ t -> out [d ] = row_sum0 ;
2387+ t -> out [d + 1 ] = row_sum1 ;
2388+ }
2389+ /* Tail: single row */
2390+ for (; d < t -> end_row ; d ++ ) {
2391+ const block_q8_0 * wblk = (const block_q8_0 * )((const uint8_t * )t -> weight + (size_t )d * t -> row_bytes );
23452392 float row_sum = 0.0f ;
23462393 for (int b = 0 ; b < t -> n_blocks ; b ++ ) {
23472394 const float wd = fp16_to_fp32 (wblk [b ].d );
23482395 const int8_t * wqs = wblk [b ].qs ;
23492396 const int8_t * xqs = t -> x_qs + b * 32 ;
23502397#ifdef __ARM_FEATURE_DOTPROD
2351- /* ARMv8.2 dotprod: 16 int8 MACs per instruction (M1+ have this). */
23522398 int32x4_t vd = vdupq_n_s32 (0 );
23532399 vd = vdotq_s32 (vd , vld1q_s8 (wqs + 0 ), vld1q_s8 (xqs + 0 ));
23542400 vd = vdotq_s32 (vd , vld1q_s8 (wqs + 16 ), vld1q_s8 (xqs + 16 ));
0 commit comments