[Common] Optimize fused router forward/backward kernels#3012
[Common] Optimize fused router forward/backward kernels#3012harryzhou2000 wants to merge 13 commits into
Conversation
2009f16 to
14a302c
Compare
Replace multi-loop preprocess (separate clear/load/score/save/bias loops) with single fused loops per score function in all 4 kernel paths (topk forward, topk backward, aux_loss forward, aux_loss backward). Replace multi-pass backward (array-based helpers + comp_buf shmem) with a two-pass approach using scalar helpers: Pass 1: reduction — warp-level sums via warp_allreduce_sum() Pass 2: element-wise — scalar gradient computation → write to global Add scalar helpers to utils.h: sigmoid_scalar, sqrtsoftplus_scalar, sigmoid_bwd_scalar, sqrtsoftplus_bwd_scalar, normalize_bwd_scalar, softmax_bwd_scalar. Remove dead array helpers from utils.h: apply_sigmoid_on_float, apply_sigmoid_bwd_on_float, apply_sqrtsoftplus_on_float, apply_sqrtsoftplus_bwd_on_float, apply_softmax_bwd_on_float, masked_warp_reduce_on_shmem. Backward shmem reduced by E×W×sizeof(float) per kernel (comp_buf eliminated). Net -226 lines across 3 files. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Add async_loader.h with:
- RawAsyncLoader<T>: cp.async on sm_80+, int4 fallback on sm_70,
stores data in original type (no conversion during copy)
- compute_persistent_grid(): occupancy-based grid sizing
- choose_num_buffers(): shmem-aware 1-vs-2 buffer decision
- vec_fill_global(), vec_store_global(): vectorized output helpers
Forward kernels (topk + aux_loss):
- Logits loaded via RawAsyncLoader with double-buffered prefetch
- Persistent grid replaces 1-shot grid launch
- DataType→CompType conversion during compute, not during load
- vec_fill_global for clearing probs/routing_map
Backward kernels (topk + aux_loss):
- All inputs loaded via RawAsyncLoader (topk: 3 loaders for
grad/act/mask; aux_loss: 2 loaders for grad/act)
- Always double-buffered (kBwdNumBuffers=2, kAuxBwdNumBuffers=2)
- Persistent grid with occupancy-based sizing
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace counts[16] + total_counts[16] (32 registers) with 4 packed u32 registers using 8-bit fields (4 counters per register). Eliminates massive register spill to local memory on large kernels (81% of L1 traffic on E=2304, K=36). Add kMaxExpertsRadixTopk constant (8160 = 255 * 32) and runtime checks in both forward launchers to guard against 8-bit overflow. All current MoE configurations (max E=2304) are well within this limit. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
…dispatch Replace runtime score_function parameter in all 4 kernel __global__ functions with template int ScoreFunc (0=sigmoid, 1=softmax, 2=sqrtsoftplus). All score_function branches now use if constexpr, eliminating dead-code register pressure and branch overhead. Forward launchers dispatch on TopkFunc × ScoreFunc = 6 instantiations per DataType. Backward launchers dispatch on ScoreFunc = 3 instantiations per DataType. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Fix broken topk < 0 threshold (radix was always selected, naive unreachable). Replace with configurable NVTE_RADIX_TOPK_THRESHOLD env var (default 0, i.e. always use radix). Set to 16 to restore the old naive-for-small-K behavior. Uses the standard TE pattern: static local + getenv (read once, cached for process lifetime). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When choose_num_buffers() returns 1 (shmem too tight for double buffering, e.g. E=1024 with group_topk scratch), buf_[0] and buf_[1] alias the same memory. The prefetch via start_load(next_buf()) then overwrites the current buffer while compute is still reading it. Fix: guard the prefetch on num_buffers > 1. When single-buffered, load the current round's data at the top of each iteration instead. The first round's load_current is still issued before the loop. Backward kernels are unaffected (always kBwdNumBuffers=2). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Code review fixes: - C1: choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor (per-SM budget) instead of cudaDevAttrMaxSharedMemoryPerBlockOptin (per-block max). These coincide on Hopper/Blackwell but differ on Ampere. - H3: Remove dead fallback branch in choose_num_buffers() — since total_double >= total_single always, blocks_single >= blocks_double, so the old ternary always returned 1 anyway. - H4/M8: Add host-side NVTE_CHECK in all 4 launchers: - num_experts > 0 - topk in [1, num_experts] - (int64_t)num_tokens * num_experts <= INT_MAX (kernel uses int offsets) - M9: Assert topk % group_topk == 0 when group_topk > 0. - H6: Add device-side assert(data_size <= kMaxExpertsRadixTopk) in radix_topk_and_mask() — zero cost in release (NDEBUG), catches 8-bit histogram overflow in debug builds. - L1: Fix stale comments claiming default threshold is 16 (it is 0). - L4: Fix typo 'hanlded' -> 'handled'. - L8: Remove unused topk parameter from aux loss backward kernel. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Move the duplicated static function from both .cu files into utils.h as an inline function. Each TU gets its own static local (read-once per TU), which is safe since environment variables are immutable during process lifetime. Documented this in a NOTE comment. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace runtime function-pointer dispatch with compile-time if constexpr. Eliminates indirect call overhead in the reduction loop and warp shuffle butterfly, allowing the compiler to emit straight-line arithmetic. Removes the now-unused max<T>() and sum<T>() helper functions. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), use a lightweight forward kernel that avoids the async loader and persistent grid overhead. The simple kernel loads logits directly from global memory to shmem and uses Naive iterative-argmax topk — matching the baseline structure that was faster for small K due to lower launch/scheduling overhead. The optimized path (async loader + persistent grid + radix topk) remains the default for topk >= 8 where the compute savings dominate. Both topk and aux_loss forward kernels get the simple variant. Backward kernels are unchanged (always use the optimized path). Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
14a302c to
a805f38
Compare
Use 0.0f instead of 0 to avoid ambiguity between __nv_bfloat16(float) and __nv_bfloat16(double) constructors on older CUDA toolkits. Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Greptile SummaryThis PR significantly optimizes the fused router CUDA kernels by introducing persistent grids with async double-buffered prefetch (
Confidence Score: 3/5Both forward launchers unconditionally check the larger async-path shared_memory_size before branching on topk threshold, incorrectly rejecting valid large-E/small-topk configurations that the simple kernel would handle within device limits. The premature shmem check is a real defect on the changed launch path — not a theoretical risk — that silently blocks usable configurations for any user with E > ~6000 and topk below the radix threshold. The backward kernels also bypass the adaptive choose_num_buffers logic that was specifically added for the forward paths. fused_topk_with_score_function.cu and fused_score_for_moe_aux_loss.cu — the launcher functions where shmem check order needs correcting and backward buffer counts need adaptive selection. Important Files Changed
|
| T *buf_[2]; | ||
| int phase_; | ||
| bool double_buf_; | ||
|
|
||
| // Raw copy: global → shmem, no type conversion. | ||
| // Uses 16-byte vectorised copies (cp.async on sm_80+, int4 on older archs) | ||
| // when both pointers are 16-byte aligned, with a scalar tail / fallback. | ||
| __device__ void raw_load(const T *__restrict__ src, T *__restrict__ dst, int count, int lane_id) { | ||
| constexpr int kBytesPerCopy = 16; | ||
| constexpr int kEltsPerCopy = kBytesPerCopy / sizeof(T); | ||
|
|
||
| bool src_aligned = (reinterpret_cast<uint64_t>(src) % kBytesPerCopy == 0); | ||
| bool dst_aligned = (reinterpret_cast<uint64_t>(dst) % kBytesPerCopy == 0); | ||
| int aligned_count = (count / kEltsPerCopy) * kEltsPerCopy; | ||
|
|
||
| if (src_aligned && dst_aligned && aligned_count > 0) { | ||
| int vec_count = aligned_count / kEltsPerCopy; | ||
| for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) { | ||
| cp_async_16B(dst + vi * kEltsPerCopy, src + vi * kEltsPerCopy); | ||
| } | ||
| for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) { | ||
| dst[i] = src[i]; | ||
| } | ||
| cp_async_commit(); | ||
| } else { | ||
| for (int i = lane_id; i < count; i += kThreadsPerWarp) { | ||
| dst[i] = src[i]; | ||
| } | ||
| cp_async_commit(); // No-op on sm_70; matches wait() expectation on sm_80+. | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace fused_router | ||
| } // namespace transformer_engine | ||
|
|
||
| #endif // TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_ |
There was a problem hiding this comment.
cp_async_commit() in scalar/unaligned fallback creates empty pipeline groups
The else branch copies elements synchronously then still calls cp_async_commit(). On sm_80+, this commits an empty group that wait() → __pipeline_wait_prior(0) must drain unnecessarily. On sm_70 it is a no-op; on sm_80+ it adds avoidable pipeline overhead.
Summary
Optimizes the fused router CUDA kernels introduced in #2821 (
fused_topk_with_score_functionandfused_score_for_moe_aux_loss). Achieves significant bandwidth improvements for large expert counts and topk values while preserving identical performance for smaller configurations (e.g., E=256, topk=4).Key results (B300, float32, 8192 tokens):
Changes
Forward kernels
RawAsyncLoader<T>usescp.async(sm_80+) for non-blocking global→shmem loads. Occupancy-aware grid sizing (compute_persistent_grid) keeps all SMs saturated across multiple rounds.ScoreFunctemplate parameter withif constexprremoves runtime branches from the hot loop.topk < NVTE_RADIX_TOPK_THRESHOLD(default 8), dispatches to a lightweight kernel matching the original structure — no async loader, no persistent grid — avoiding scheduling overhead that dominates at small K.Backward kernels
warp_allreduce_sum. Pass 2 computes per-element gradients using scalar helpers. Eliminates thecomp_bufshared memory buffer (savesE × warps × 4bytes per block).RawAsyncLoaderwith always-on double buffering.Infrastructure
async_loader.h:RawAsyncLoader<T>,compute_persistent_grid(),choose_num_buffers(), vectorized global store/fill helpers.NVTE_RADIX_TOPK_THRESHOLDenv var (default 8): configurable naive↔radix crossover.warp_reduce_on_shmem<T, ReduceFuncType>eliminates function-pointer overhead.Hardening
num_tokens * num_experts <= INT_MAX,topk ∈ [1, E],topk % group_topk == 0assert(data_size <= kMaxExpertsRadixTopk)in radix pathcudaDevAttrMaxSharedMemoryPerMultiprocessorfor buffer-count decisionCompatibility
test_fused_router.pytests pass, 117 skipped (fp8/multi-node).NVTE_RADIX_TOPK_THRESHOLD=0to force radix everywhere, or=16to use naive for topk<16.Performance (B300 SXM6, sm_103, float32, 8192 tokens)
Effective bandwidth (GB/s) is computed as the minimum bytes that must be transferred to/from global memory for one kernel invocation, divided by the measured wall time. For example, the topk forward kernel reads logits (
T×E×dtype) and writes probs (T×E×dtype), routing_map (T×E×1), and intermediate_output (T×E×4). This metric captures how well the kernel utilizes memory bandwidth — higher is better, with the device peak around 8 TB/s on B300. Config format isnum_experts/topk.Full benchmark table (softmax)
Full benchmark table (sigmoid)