Skip to content

Commit a3f28c3

Browse files
author
da.huo
committed
fix cmakelists and llamalinear
1 parent c2f3e04 commit a3f28c3

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

src/turbomind/kernels/gemm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ set(GEMM2_KERNELS_SM80
     kernel/sm80_16816_16.cu
 )
 set(GEMM2_KERNELS_SM90
+    tma.cu
     kernel/sm90_16816_4.cu
     kernel/sm90_16816_8.cu
     kernel/sm90_16816_16.cu
@@ -49,7 +50,6 @@ add_library(gemm2
     cast.cu
     unpack.cu
     context.cu
-    tma.cu
     tuner/cache_utils.cu
     tuner/measurer.cu
     tuner/sampler.cu

src/turbomind/models/llama/LlamaLinear.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,13 @@ struct LlamaLinear::Impl {
     Tensor A_e = {{m, k}, A.dtype(), kDEVICE};
     invokeMoeDispatch(A_e, A, indices.data(), e, st);
     sync_check_cuda_error();
-    Tensor U_e;
-    invokeMoeDispatchScales(U_e, U, indices.data(), e, st);
-    sync_check_cuda_error();
+    if (U) {
+        Tensor U_e;
+        invokeMoeDispatchScales(U_e, U, indices.data(), e, st);
+        sync_check_cuda_error();
+        U = U_e;
+    }
     A = A_e;
-    U = U_e;
     indices = {}; // indices already applied
 }
 
0 commit comments

Comments
 (0)