Skip to content

Commit f38b174

Browse files
Fix noaux_tc cuda Error 700 in CUDAGraph and Add wfp8apf8 moe quant method (#4115)
* Improve per_token_quant_fp8 performance
* Support MoE wfp8apf8
* Check GLM test
* Fix noaux_tc op in CUDAGraph; support noaux_tc return the correct result
* Check
* Check inf and overwrite score in noaux_tc

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
1 parent 6b47773 commit f38b174

17 files changed

Lines changed: 921 additions & 122 deletions

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ std::vector<paddle::Tensor> NoauxTc(
564564
int n_group,
565565
int topk_group,
566566
int topk,
567+
bool renormalize,
567568
float routed_scaling_factor);
568569

569570
#ifdef ENABLE_FP8

custom_ops/gpu_ops/helper.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,34 @@ inline int GetGPUComputeCapability(int id) {
151151

152152
#endif
153153

154+
#ifndef FP8_E4M3_MAX
// Largest finite value representable in FP8 E4M3 (= 448). Float literal —
// a bare 448.0 is a double and silently promotes surrounding float math
// in device code.
#define FP8_E4M3_MAX 448.0f
#endif

#ifndef DISPATCH_FLOAT_FP6_DTYPE
// Dispatch `__VA_ARGS__` over the supported floating-point Paddle dtypes,
// binding the matching C++ type to the identifier named by `c_type`:
//   FLOAT32  -> float
//   BFLOAT16 -> phi::dtype::bfloat16
//   FLOAT16  -> phi::dtype::float16
// Any other dtype raises via PD_THROW.
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...)                    \
  switch (pd_dtype) {                                                      \
    case phi::DataType::FLOAT32: {                                         \
      using c_type = float;                                                \
      __VA_ARGS__                                                          \
      break;                                                               \
    }                                                                      \
    case phi::DataType::BFLOAT16: {                                        \
      using c_type = phi::dtype::bfloat16;                                 \
      __VA_ARGS__                                                          \
      break;                                                               \
    }                                                                      \
    case phi::DataType::FLOAT16: {                                         \
      using c_type = phi::dtype::float16;                                  \
      __VA_ARGS__                                                          \
      break;                                                               \
    }                                                                      \
    default: {                                                             \
      PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
    }                                                                      \
  }
#endif
181+
154182
inline constexpr uint32_t next_pow_2(uint32_t const num) {
155183
if (num <= 1)
156184
return num;
@@ -563,3 +591,28 @@ inline int GetSMVersion() {
563591
return sm_version;
564592

565593
}
594+
595+
// Warp-wide max via a butterfly (XOR) shuffle ladder: after log2(32) = 5
// steps every lane holds the maximum of all 32 lane values. Assumes the
// full warp participates (0xffffffff mask).
__device__ __forceinline__ float warpReduceMax(float value) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, offset));
  }
  return value;
}
603+
604+
// Block-wide max reduction built from per-warp shuffle reductions.
// Must be reached by all threads of the block (contains __syncthreads()).
// The returned value is guaranteed to be the block maximum only on warp 0
// (in particular on threadIdx.x == 0); other warps may return a partial
// maximum.
// Fixes vs. previous version:
//  * ceil-div warp count so a trailing partial warp's maximum is not
//    dropped when blockDim.x is not a multiple of WARP_SIZE;
//  * filler lanes read warpLevelMaxs[0] instead of 0, so the result is
//    also correct when every input is negative (0 is not the identity
//    element of fmaxf).
// NOTE(review): the shared buffer is not re-synchronized after the final
// warp reduce, so two back-to-back calls in one kernel would race on
// warpLevelMaxs — confirm callers invoke this at most once per barrier.
__device__ __forceinline__ float blockReduceMax(float value) {
  static __shared__ float warpLevelMaxs[WARP_SIZE];
  const int laneId = threadIdx.x % WARP_SIZE;
  const int warpId = threadIdx.x / WARP_SIZE;
  // Ceil-div: count the trailing partial warp as a full contributor.
  const int numWarps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;

  value = warpReduceMax(value);

  if (laneId == 0) warpLevelMaxs[warpId] = value;
  __syncthreads();

  // Lanes without a warp result reuse warpLevelMaxs[0] — a value <= the
  // true maximum — so it cannot distort the reduction.
  value = (laneId < numWarps) ? warpLevelMaxs[laneId] : warpLevelMaxs[0];
  if (warpId == 0) value = warpReduceMax(value);

  return value;
}

custom_ops/gpu_ops/noaux_tc.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
2626
int n_group,
2727
int topk_group,
2828
int topk,
29+
bool renormalize,
2930
float routed_scaling_factor) {
3031
auto input_shape = scores_with_bias.shape();
3132
PD_CHECK(input_shape.size() == 2);
@@ -48,6 +49,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
4849
n_group,
4950
topk_group,
5051
topk,
52+
renormalize,
5153
routed_scaling_factor,
5254
stream);
5355

@@ -76,6 +78,7 @@ PD_BUILD_STATIC_OP(noaux_tc)
7678
.Attrs({"n_group: int",
7779
"topk_group: int",
7880
"topk:int",
81+
"renormalize: bool",
7982
"routed_scaling_factor: float"})
8083
.SetKernelFn(PD_KERNEL(NoauxTc))
8184
.SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))

0 commit comments

Comments
 (0)