Skip to content

Commit f1ad427

Browse files
author
da.huo
committed
feat(turbomind): integrate cublasGemmGroupedBatchedEx for Qwen3.5 MoE on Blackwell
Use grouped batched GEMM on SM100; split the SM90 CUTLASS kernels into a separate STATIC library for arch-specific builds; add a copy-path workaround for Blackwell; and adjust the Llama MoE weight layout. Move tma.cu into libgemm2_sm90.a (its only callers are SM90 kernels), fixing the undefined symbol make_2d_tma_desc that arose from single-pass static link ordering between the two archives.
1 parent 65155f2 commit f1ad427

11 files changed

Lines changed: 442 additions & 23 deletions

File tree

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,11 @@ if(ARCH STREQUAL "x86_64")
263263
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.8")
264264
list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) # 5090
265265
endif ()
266+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "12.8")
267+
list(APPEND CMAKE_CUDA_ARCHITECTURES 100a-real) # B200
268+
endif()
266269
if (MSVC)
267-
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real)
270+
list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80-real 90a-real 100a-real)
268271
endif ()
269272
endif ()
270273
elseif(ARCH STREQUAL "aarch64")

src/turbomind/core/copy.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ const auto& GetCopyAPI()
5757
void* fpn{};
5858
TM_CHECK_EQ(cudaGetDriverEntryPoint(symbol, &fpn, cudaEnableDefault, &status), 0);
5959
if (fpn && status == cudaDriverEntryPointSuccess) {
60+
// cuMemcpyBatchAsync crashes on sm_100 (Blackwell); force monostate -> serialized path.
61+
int device = 0;
62+
(void)cudaGetDevice(&device);
63+
int major = 0;
64+
(void)cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
65+
if (major >= 10) {
66+
return {};
67+
}
6068
return (PFN_cuMemcpyBatchAsync_v12080)fpn;
6169
}
6270
else {

src/turbomind/kernels/gemm/CMakeLists.txt

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,44 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22

3+
set(GEMM2_KERNELS_SM70
4+
kernel/sm70_884_4.cu
5+
kernel/sm70_884_8.cu
6+
kernel/sm70_884_16.cu
7+
)
8+
set(GEMM2_KERNELS_SM75
9+
kernel/sm75_16816_4.cu
10+
kernel/sm75_16816_8.cu
11+
kernel/sm75_16816_16.cu
12+
)
13+
set(GEMM2_KERNELS_SM80
14+
kernel/sm80_16816_4.cu
15+
kernel/sm80_16816_8.cu
16+
kernel/sm80_16816_16.cu
17+
)
18+
set(GEMM2_KERNELS_SM90
19+
kernel/sm90_16816_4.cu
20+
kernel/sm90_16816_8.cu
21+
kernel/sm90_16816_16.cu
22+
kernel/sm90_64n32_8.cu
23+
)
24+
25+
set(GEMM2_ARCH_90_ENABLED FALSE)
26+
set(_sm90_archs "${CMAKE_CUDA_ARCHITECTURES}")
27+
list(FILTER _sm90_archs INCLUDE REGEX "^90")
28+
if(_sm90_archs)
29+
set(GEMM2_ARCH_90_ENABLED TRUE)
30+
else()
31+
# When building for SM100+ without explicit SM90, still compile SM90 CUTLASS
32+
# kernels so the fat binary can run MoE models on H100 (CUTLASS fused path).
33+
set(_sm100_archs "${CMAKE_CUDA_ARCHITECTURES}")
34+
list(FILTER _sm100_archs INCLUDE REGEX "^100")
35+
if(_sm100_archs)
36+
set(GEMM2_ARCH_90_ENABLED TRUE)
37+
set(_sm90_archs "90")
38+
message(STATUS "GEMM: auto-enabling SM90 CUTLASS kernels for H100 backward compatibility")
39+
endif()
40+
endif()
41+
342
add_library(gemm2
443
gemm.cu
544
kernel.cu
@@ -10,34 +49,30 @@ add_library(gemm2
1049
cast.cu
1150
unpack.cu
1251
context.cu
13-
tma.cu
1452
tuner/cache_utils.cu
1553
tuner/measurer.cu
1654
tuner/sampler.cu
1755
tuner/stopping_criterion.cc
1856
tuner/params.cc
19-
kernel/sm90_16816_4.cu
20-
kernel/sm90_16816_8.cu
21-
kernel/sm90_16816_16.cu
22-
kernel/sm80_16816_4.cu
23-
kernel/sm80_16816_8.cu
24-
kernel/sm80_16816_16.cu
25-
kernel/sm75_16816_4.cu
26-
kernel/sm75_16816_8.cu
27-
kernel/sm75_16816_16.cu
28-
kernel/sm70_884_4.cu
29-
kernel/sm70_884_8.cu
30-
kernel/sm70_884_16.cu
31-
kernel/sm90_64n32_8.cu
57+
${GEMM2_KERNELS_SM70}
58+
${GEMM2_KERNELS_SM75}
59+
${GEMM2_KERNELS_SM80}
3260
cublas.cu
3361
moe_utils_v2.cu
3462
test/test_utils.cu
3563
)
3664

3765
target_link_libraries(gemm2 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)
3866

39-
40-
target_compile_definitions(gemm2 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
67+
# cublasGemmGroupedBatchedEx (CUDA 12.5+): grouped batched GEMM for MoE on SM100
68+
set(_has_sm100 FALSE)
69+
set(_archs_100 "${CMAKE_CUDA_ARCHITECTURES}")
70+
list(FILTER _archs_100 INCLUDE REGEX "^100")
71+
if(_archs_100 AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.5")
72+
set(_has_sm100 TRUE)
73+
target_compile_definitions(gemm2 PRIVATE ENABLE_CUBLAS_GROUPED=1)
74+
message(STATUS "GEMM: ENABLE_CUBLAS_GROUPED=1 (cublasGemmGroupedBatchedEx for MoE on SM100)")
75+
endif()
4176

4277
target_compile_options(gemm2 PRIVATE
4378
$<$<COMPILE_LANGUAGE:CUDA>:
@@ -48,7 +83,29 @@ target_compile_options(gemm2 PRIVATE
4883
set_property(TARGET gemm2 PROPERTY POSITION_INDEPENDENT_CODE ON)
4984
set_property(TARGET gemm2 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
5085

86+
if(GEMM2_ARCH_90_ENABLED)
87+
# SM90 kernels + tma.cu are built as a separate STATIC library for sm_90 only.
88+
# tma.cu (defines make_2d_tma_desc) is placed here because its only callers are
89+
# in kernel_impl_sm90.h; keeping definition and references in the same archive
90+
# avoids the single-pass static-link ordering problem between two archives.
91+
add_library(gemm2_sm90 STATIC ${GEMM2_KERNELS_SM90} tma.cu)
92+
set_target_properties(gemm2_sm90 PROPERTIES
93+
CUDA_ARCHITECTURES "${_sm90_archs}"
94+
POSITION_INDEPENDENT_CODE ON
95+
CUDA_RESOLVE_DEVICE_SYMBOLS ON
96+
)
97+
target_compile_definitions(gemm2_sm90 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
98+
target_compile_options(gemm2_sm90 PRIVATE
99+
$<$<COMPILE_LANGUAGE:CUDA>:
100+
-Xptxas=-v
101+
--generate-line-info
102+
--threads 16>
103+
)
104+
target_link_libraries(gemm2_sm90 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)
105+
target_link_libraries(gemm2 PRIVATE gemm2_sm90)
51106

107+
target_compile_definitions(gemm2 PRIVATE GEMM2_ARCH_90_ENABLED)
108+
endif()
52109

53110
if (BUILD_TEST)
54111
add_executable(test_gemm_v2

src/turbomind/kernels/gemm/arch.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,20 @@ struct Sm80: Arch<800, 900> {
2626
static constexpr int value = 800;
2727
};
2828

29-
struct Sm90: Arch<900> {
29+
struct Sm90: Arch<900, 1000> {
3030
static constexpr int value = 900;
3131
};
3232

33+
// B200 (Blackwell) SM 100
34+
struct Sm100: Arch<1000, 1200> {
35+
static constexpr int value = 1000;
36+
};
37+
38+
// SM12.x (e.g. sm_120): use same CUTLASS SM90 kernel family as pre-PR Sm90+ range
39+
struct Sm120: Arch<1200, 1300> {
40+
static constexpr int value = 1200;
41+
};
42+
3343
inline bool is_arch_compatible(int karch, int darch)
3444
{
3545
switch (karch) {
@@ -42,7 +52,11 @@ inline bool is_arch_compatible(int karch, int darch)
4252
case 800:
4353
return Sm80::is_compatible(darch);
4454
case 900:
45-
return Sm90::is_compatible(darch);
55+
return Sm90::is_compatible(darch) || Sm120::is_compatible(darch);
56+
case 1000:
57+
return Sm100::is_compatible(darch);
58+
case 1200:
59+
return Sm120::is_compatible(darch);
4660
default:
4761
return false;
4862
}

src/turbomind/kernels/gemm/convert_v3.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ std::array<const LayoutConverter*, 2> GetConverters(DataType data_type,
105105
if (weight_type == kHalf || weight_type == kBfloat16) {
106106
constexpr Cvt<uint16_t, uint16_t> W;
107107
if (grouped) {
108+
// SM10.x only: CublasGroupedKernel (cublasGemmGroupedBatchedEx) expects standard (K,N)
109+
if (sm >= 100 && sm < 120)
110+
return {};
108111
// clang-format off
109112
if (sm >= 80) return {W(sm8_, kRow, s16816h | B | _1), {}};
110113
if (sm == 75) return {W(sm75, kRow, s16816h | B | _1), {}};

0 commit comments

Comments (0)