Skip to content

Commit 14676a3

Browse files
authored
[Cherry-Pick][OP] cherry-pick #7073 support deepgemm for sm103 (#7081)
1 parent bd48640 commit 14676a3

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

fastdeploy/model_executor/layers/quantization/block_wise_fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, weight_block_size: list = [-1, -1], is_checkpoint_bf16: bool
6767
self.quant_round_type = 1
6868
self.use_deep_gemm = bool(envs.FD_USE_DEEP_GEMM)
6969
self.is_checkpoint_bf16 = is_checkpoint_bf16
70-
self.deepgemm_scale_ue8m0 = True if get_sm_version() == 100 else False
70+
self.deepgemm_scale_ue8m0 = True if get_sm_version() >= 100 else False
7171

7272
def name(self) -> str:
7373
return "block_wise_fp8"

fastdeploy/model_executor/layers/quantization/fp8_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def load_deep_gemm():
6060
"""
6161

6262
if current_platform.is_cuda():
63-
if get_sm_version() == 100:
63+
if get_sm_version() >= 100:
6464
# SM100 should use PFCC DeepGemm
6565
paddle.compat.enable_torch_proxy(scope={"deep_gemm"})
6666
try:
@@ -167,7 +167,7 @@ def fused_stack_transpose_quant(expert_weight_list, use_ue8m0=False):
167167
# Blackwell (SM100) GPUs require pow2_scale quantization.
168168
# Guard with is_cuda() so non-CUDA environments do not call into
169169
# paddle.device.cuda.* and cause a crash.
170-
use_pow2_scale = current_platform.is_cuda() and get_sm_version() == 100
170+
use_pow2_scale = current_platform.is_cuda() and get_sm_version() >= 100
171171

172172
w, scale = paddlefleet_ops.fuse_stack_transpose_fp8_quant(
173173
expert_weight_list,

0 commit comments

Comments
 (0)