11# Copyright (c) OpenMMLab. All rights reserved.
22
# Kernel translation units, grouped by the GPU generation named in each file.
# The SM70/SM75/SM80 lists are added to the main `gemm2` target; the SM90 list
# is only built when GEMM2_ARCH_90_ENABLED is set.

# sm70 kernels (884 variants)
set(GEMM2_KERNELS_SM70
    kernel/sm70_884_4.cu
    kernel/sm70_884_8.cu
    kernel/sm70_884_16.cu
)

# sm75 kernels (16816 variants)
set(GEMM2_KERNELS_SM75
    kernel/sm75_16816_4.cu
    kernel/sm75_16816_8.cu
    kernel/sm75_16816_16.cu
)

# sm80 kernels (16816 variants)
set(GEMM2_KERNELS_SM80
    kernel/sm80_16816_4.cu
    kernel/sm80_16816_8.cu
    kernel/sm80_16816_16.cu
)

# sm90 kernels (16816 variants plus the 64n32 kernel)
set(GEMM2_KERNELS_SM90
    kernel/sm90_16816_4.cu
    kernel/sm90_16816_8.cu
    kernel/sm90_16816_16.cu
    kernel/sm90_64n32_8.cu
)
24+
# Decide whether the SM90 kernel set is built and for which architectures.
#   GEMM2_ARCH_90_ENABLED - TRUE when the SM90 kernels should be compiled.
#   _sm90_archs           - architecture list used for those kernels.
# Only CMAKE_CUDA_ARCHITECTURES entries that start with "90" or "100" are
# recognised by the filters below.
# NOTE(review): keyword values such as "all", "all-major" or "native" match
# neither regex, so they leave the SM90 kernels disabled - confirm intended.
set(_sm90_archs "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _sm90_archs INCLUDE REGEX "^90")

if(_sm90_archs)
  # At least one sm_90 entry was requested explicitly: build for exactly those.
  set(GEMM2_ARCH_90_ENABLED TRUE)
else()
  set(GEMM2_ARCH_90_ENABLED FALSE)

  # No sm_90 entry was requested. If an entry starting with "100" is present,
  # still compile the SM90 CUTLASS kernels (for plain "90") so the fat binary
  # can run MoE models on H100 (CUTLASS fused path).
  set(_sm100_archs "${CMAKE_CUDA_ARCHITECTURES}")
  list(FILTER _sm100_archs INCLUDE REGEX "^100")
  if(_sm100_archs)
    set(GEMM2_ARCH_90_ENABLED TRUE)
    set(_sm90_archs "90")
    message(STATUS "GEMM: auto-enabling SM90 CUTLASS kernels for H100 backward compatibility")
  endif()
endif()
41+
342add_library (gemm2
443 gemm.cu
544 kernel.cu
@@ -10,34 +49,30 @@ add_library(gemm2
1049 cast.cu
1150 unpack.cu
1251 context.cu
13- tma.cu
1452 tuner/cache_utils.cu
1553 tuner/measurer.cu
1654 tuner/sampler.cu
1755 tuner/stopping_criterion.cc
1856 tuner/params.cc
19- kernel/sm90_16816_4.cu
20- kernel/sm90_16816_8.cu
21- kernel/sm90_16816_16.cu
22- kernel/sm80_16816_4.cu
23- kernel/sm80_16816_8.cu
24- kernel/sm80_16816_16.cu
25- kernel/sm75_16816_4.cu
26- kernel/sm75_16816_8.cu
27- kernel/sm75_16816_16.cu
28- kernel/sm70_884_4.cu
29- kernel/sm70_884_8.cu
30- kernel/sm70_884_16.cu
31- kernel/sm90_64n32_8.cu
57+ ${GEMM2_KERNELS_SM70}
58+ ${GEMM2_KERNELS_SM75}
59+ ${GEMM2_KERNELS_SM80}
3260 cublas.cu
3361 moe_utils_v2.cu
3462 test /test_utils.cu
3563)
3664
# Link-time dependencies of gemm2; all PRIVATE, so none are propagated to
# targets that link gemm2.
target_link_libraries(gemm2
  PRIVATE
    parser
    nvidia::cutlass::cutlass
    CUDA::cuda_driver
)
3866
39-
40- target_compile_definitions (gemm2 PRIVATE -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED )
# cublasGemmGroupedBatchedEx (CUDA 12.5+): grouped batched GEMM for MoE on SM100.
# ENABLE_CUBLAS_GROUPED=1 is defined for gemm2 only when both hold:
#   * CMAKE_CUDA_ARCHITECTURES has an entry starting with "100", and
#   * the CUDA compiler version is at least 12.5.
# _has_sm100 records the outcome of that combined check.
set(_has_sm100 FALSE)
set(_archs_100 "${CMAKE_CUDA_ARCHITECTURES}")
list(FILTER _archs_100 INCLUDE REGEX "^100")

if(_archs_100)
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.5")
    set(_has_sm100 TRUE)
    target_compile_definitions(gemm2 PRIVATE ENABLE_CUBLAS_GROUPED=1)
    message(STATUS "GEMM: ENABLE_CUBLAS_GROUPED=1 (cublasGemmGroupedBatchedEx for MoE on SM100)")
  endif()
endif()
4176
4277target_compile_options (gemm2 PRIVATE
4378 $<$<COMPILE_LANGUAGE :CUDA >:
@@ -48,7 +83,29 @@ target_compile_options(gemm2 PRIVATE
# Build gemm2 as position-independent code and enable CMake's
# CUDA_RESOLVE_DEVICE_SYMBOLS property on the target itself.
set_target_properties(gemm2 PROPERTIES
  POSITION_INDEPENDENT_CODE ON
  CUDA_RESOLVE_DEVICE_SYMBOLS ON
)
5085
if(GEMM2_ARCH_90_ENABLED)
  # The SM90 kernels and tma.cu are built as a separate STATIC library that is
  # compiled only for the architectures in _sm90_archs.
  # tma.cu (defines make_2d_tma_desc) lives here because its only callers are
  # in kernel_impl_sm90.h; keeping the definition and its references in the
  # same archive avoids the single-pass static-link ordering problem between
  # two archives.
  add_library(gemm2_sm90 STATIC
    ${GEMM2_KERNELS_SM90}
    tma.cu
  )

  set_target_properties(gemm2_sm90 PROPERTIES
    CUDA_ARCHITECTURES "${_sm90_archs}"
    POSITION_INDEPENDENT_CODE ON
    CUDA_RESOLVE_DEVICE_SYMBOLS ON
  )

  # Definitions are passed by bare name (no "-D" prefix), matching the other
  # target_compile_definitions calls in this file. CMake strips a leading -D,
  # so the resulting compile line is unchanged.
  target_compile_definitions(gemm2_sm90 PRIVATE CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)

  # These options are applied to CUDA sources only.
  target_compile_options(gemm2_sm90 PRIVATE
    $<$<COMPILE_LANGUAGE:CUDA>:
      -Xptxas=-v
      --generate-line-info
      --threads 16>
  )

  target_link_libraries(gemm2_sm90 PRIVATE parser nvidia::cutlass::cutlass CUDA::cuda_driver)

  # Pull the SM90 archive into gemm2 and define GEMM2_ARCH_90_ENABLED for
  # gemm2's own sources.
  target_link_libraries(gemm2 PRIVATE gemm2_sm90)
  target_compile_definitions(gemm2 PRIVATE GEMM2_ARCH_90_ENABLED)
endif()
52109
53110if (BUILD_TEST)
54111 add_executable (test_gemm_v2
0 commit comments