Skip to content

[Common] Optimize fused router forward/backward kernels#3012

Open
harryzhou2000 wants to merge 13 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p3R
Open

[Common] Optimize fused router forward/backward kernels#3012
harryzhou2000 wants to merge 13 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p3R

Conversation

@harryzhou2000
Copy link
Copy Markdown
Member

@harryzhou2000 harryzhou2000 commented May 19, 2026

Summary

Optimizes the fused router CUDA kernels introduced in #2821 (fused_topk_with_score_function and fused_score_for_moe_aux_loss). Achieves significant bandwidth improvements for large expert counts and topk values while preserving identical performance for smaller configurations (e.g., E=256, topk=4).

Key results (B300, float32, 8192 tokens):

  • Forward (E=2304, K=36, softmax): 673 → 964 GB/s (+43%)
  • Backward (E=2304, K=36, softmax): 543 → 2766 GB/s (+410%)
  • Forward (E=512, K=4): no regression (±0.3%)

Changes

Forward kernels

  • Persistent grid with async double-buffered prefetch: RawAsyncLoader<T> uses cp.async (sm_80+) for non-blocking global→shmem loads. Occupancy-aware grid sizing (compute_persistent_grid) keeps all SMs saturated across multiple rounds.
  • Packed 8-bit radix histogram: Reduces radix topk register usage from 32 to 4 registers by packing 16 bucket counts into 4×u32 with 8-bit fields. Eliminates local memory spill at large E.
  • Compile-time score function dispatch: ScoreFunc template parameter with if constexpr removes runtime branches from the hot loop.
  • Simple kernel path for small topk: When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), dispatches to a lightweight kernel matching the original structure — no async loader, no persistent grid — avoiding scheduling overhead that dominates at small K.

Backward kernels

  • Two-pass fused design: Pass 1 accumulates warp-level sums via register reduction + warp_allreduce_sum. Pass 2 computes per-element gradients using scalar helpers. Eliminates the comp_buf shared memory buffer (saves E × warps × 4 bytes per block).
  • Double-buffered async loading: All backward inputs (grad, activation, mask) loaded through RawAsyncLoader with always-on double buffering.

Infrastructure

  • async_loader.h: RawAsyncLoader<T>, compute_persistent_grid(), choose_num_buffers(), vectorized global store/fill helpers.
  • NVTE_RADIX_TOPK_THRESHOLD env var (default 8): configurable naive↔radix crossover.
  • Templated warp_reduce_on_shmem<T, ReduceFuncType> eliminates function-pointer overhead.

Hardening

  • Host-side: num_tokens * num_experts <= INT_MAX, topk ∈ [1, E], topk % group_topk == 0
  • Device-side: assert(data_size <= kMaxExpertsRadixTopk) in radix path
  • Correct cudaDevAttrMaxSharedMemoryPerMultiprocessor for buffer-count decision
  • Fix: single-buffer prefetch clobber when shmem is too tight for double buffering

Compatibility

  • No regression for small configs: The simple forward kernel path is an exact replica of the original kernel structure, ensuring E=256/topk=4 (common in standard MoE) performs identically.
  • All existing tests pass: 891/891 test_fused_router.py tests pass, 117 skipped (fp8/multi-node).
  • No API changes: Same Python/C++ interface, same output semantics.
  • Tunable: Set NVTE_RADIX_TOPK_THRESHOLD=0 to force radix everywhere, or =16 to use naive for topk<16.

Performance (B300 SXM6, sm_103, float32, 8192 tokens)

Effective bandwidth (GB/s) is computed as the minimum bytes that must be transferred to/from global memory for one kernel invocation, divided by the measured wall time. For example, the topk forward kernel reads logits (T×E×dtype) and writes probs (T×E×dtype), routing_map (T×E×1), and intermediate_output (T×E×4). This metric captures how well the kernel utilizes memory bandwidth — higher is better, with the device peak around 8 TB/s on B300. Config format is num_experts/topk.

Full benchmark table (softmax)
kernel pass config before after
topk fprop 512/4 1779 1784 (+0.3%)
topk fprop 512/8 798 904 (+13%)
topk fprop 512/22 514 924 (+80%)
topk fprop 512/36 499 908 (+82%)
topk fprop 2304/4 1803 1802 (0%)
topk fprop 2304/8 660 993 (+51%)
topk fprop 2304/22 602 972 (+61%)
topk fprop 2304/36 673 964 (+43%)
topk bprop 512/22 3391 5362 (+58%)
topk bprop 2304/36 543 2766 (+410%)
aux_loss fprop 512/22 519 896 (+73%)
aux_loss fprop 2304/36 645 891 (+38%)
aux_loss bprop 512/22 5289 6155 (+16%)
aux_loss bprop 2304/36 2272 4201 (+85%)
Full benchmark table (sigmoid)
kernel pass config before after
topk fprop 512/4 1728 1736 (+0.5%)
topk fprop 512/22 470 891 (+90%)
topk fprop 2304/36 639 798 (+25%)
topk bprop 512/22 3169 4398 (+39%)
topk bprop 2304/36 533 2274 (+327%)
aux_loss fprop 512/22 475 912 (+92%)
aux_loss fprop 2304/36 598 867 (+45%)
aux_loss bprop 2304/36 1965 2757 (+40%)

@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch from 2009f16 to 14a302c Compare May 19, 2026 10:12
Replace multi-loop preprocess (separate clear/load/score/save/bias loops)
with single fused loops per score function in all 4 kernel paths (topk
forward, topk backward, aux_loss forward, aux_loss backward).

Replace multi-pass backward (array-based helpers + comp_buf shmem) with
a two-pass approach using scalar helpers:
  Pass 1: reduction — warp-level sums via warp_allreduce_sum()
  Pass 2: element-wise — scalar gradient computation → write to global

Add scalar helpers to utils.h: sigmoid_scalar, sqrtsoftplus_scalar,
sigmoid_bwd_scalar, sqrtsoftplus_bwd_scalar, normalize_bwd_scalar,
softmax_bwd_scalar.

Remove dead array helpers from utils.h: apply_sigmoid_on_float,
apply_sigmoid_bwd_on_float, apply_sqrtsoftplus_on_float,
apply_sqrtsoftplus_bwd_on_float, apply_softmax_bwd_on_float,
masked_warp_reduce_on_shmem.

Backward shmem reduced by E×W×sizeof(float) per kernel (comp_buf
eliminated).  Net -226 lines across 3 files.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Add async_loader.h with:
  - RawAsyncLoader<T>: cp.async on sm_80+, int4 fallback on sm_70,
    stores data in original type (no conversion during copy)
  - compute_persistent_grid(): occupancy-based grid sizing
  - choose_num_buffers(): shmem-aware 1-vs-2 buffer decision
  - vec_fill_global(), vec_store_global(): vectorized output helpers

Forward kernels (topk + aux_loss):
  - Logits loaded via RawAsyncLoader with double-buffered prefetch
  - Persistent grid replaces 1-shot grid launch
  - DataType→CompType conversion during compute, not during load
  - vec_fill_global for clearing probs/routing_map

Backward kernels (topk + aux_loss):
  - All inputs loaded via RawAsyncLoader (topk: 3 loaders for
    grad/act/mask; aux_loss: 2 loaders for grad/act)
  - Always double-buffered (kBwdNumBuffers=2, kAuxBwdNumBuffers=2)
  - Persistent grid with occupancy-based sizing

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace counts[16] + total_counts[16] (32 registers) with 4 packed u32
registers using 8-bit fields (4 counters per register).  Eliminates
massive register spill to local memory on large kernels (81% of L1
traffic on E=2304, K=36).

Add kMaxExpertsRadixTopk constant (8160 = 255 * 32) and runtime checks
in both forward launchers to guard against 8-bit overflow.  All current
MoE configurations (max E=2304) are well within this limit.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
…dispatch

Replace runtime score_function parameter in all 4 kernel __global__
functions with template int ScoreFunc (0=sigmoid, 1=softmax,
2=sqrtsoftplus).  All score_function branches now use if constexpr,
eliminating dead-code register pressure and branch overhead.

Forward launchers dispatch on TopkFunc × ScoreFunc = 6 instantiations
per DataType.  Backward launchers dispatch on ScoreFunc = 3
instantiations per DataType.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Fix broken topk < 0 threshold (radix was always selected, naive
unreachable).  Replace with configurable NVTE_RADIX_TOPK_THRESHOLD
env var (default 0, i.e. always use radix).  Set to 16 to restore
the old naive-for-small-K behavior.

Uses the standard TE pattern: static local + getenv (read once,
cached for process lifetime).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When choose_num_buffers() returns 1 (shmem too tight for double
buffering, e.g. E=1024 with group_topk scratch), buf_[0] and buf_[1]
alias the same memory.  The prefetch via start_load(next_buf()) then
overwrites the current buffer while compute is still reading it.

Fix: guard the prefetch on num_buffers > 1.  When single-buffered,
load the current round's data at the top of each iteration instead.
The first round's load_current is still issued before the loop.

Backward kernels are unaffected (always kBwdNumBuffers=2).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Code review fixes:

- C1: choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor
  (per-SM budget) instead of cudaDevAttrMaxSharedMemoryPerBlockOptin (per-block
  max).  These coincide on Hopper/Blackwell but differ on Ampere.

- H3: Remove dead fallback branch in choose_num_buffers() — since
  total_double >= total_single always, blocks_single >= blocks_double,
  so the old ternary always returned 1 anyway.

- H4/M8: Add host-side NVTE_CHECK in all 4 launchers:
  - num_experts > 0
  - topk in [1, num_experts]
  - (int64_t)num_tokens * num_experts <= INT_MAX (kernel uses int offsets)

- M9: Assert topk % group_topk == 0 when group_topk > 0.

- H6: Add device-side assert(data_size <= kMaxExpertsRadixTopk) in
  radix_topk_and_mask() — zero cost in release (NDEBUG), catches
  8-bit histogram overflow in debug builds.

- L1: Fix stale comments claiming default threshold is 16 (it is 0).
- L4: Fix typo 'hanlded' -> 'handled'.
- L8: Remove unused topk parameter from aux loss backward kernel.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Move the duplicated static function from both .cu files into utils.h
as an inline function.  Each TU gets its own static local (read-once
per TU), which is safe since environment variables are immutable
during process lifetime.  Documented this in a NOTE comment.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace runtime function-pointer dispatch with compile-time if constexpr.
Eliminates indirect call overhead in the reduction loop and warp shuffle
butterfly, allowing the compiler to emit straight-line arithmetic.

Removes the now-unused max<T>() and sum<T>() helper functions.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), use a lightweight
forward kernel that avoids the async loader and persistent grid overhead.
The simple kernel loads logits directly from global memory to shmem and
uses Naive iterative-argmax topk — matching the baseline structure that
was faster for small K due to lower launch/scheduling overhead.

The optimized path (async loader + persistent grid + radix topk) remains
the default for topk >= 8 where the compute savings dominate.

Both topk and aux_loss forward kernels get the simple variant.
Backward kernels are unchanged (always use the optimized path).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch from 14a302c to a805f38 Compare May 19, 2026 10:22
Use 0.0f instead of 0 to avoid ambiguity between __nv_bfloat16(float)
and __nv_bfloat16(double) constructors on older CUDA toolkits.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 marked this pull request as ready for review May 20, 2026 08:29
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 20, 2026

Greptile Summary

This PR significantly optimizes the fused router CUDA kernels by introducing persistent grids with async double-buffered prefetch (cp.async), a packed 8-bit radix histogram to eliminate register spill at large expert counts, compile-time ScoreFunc dispatch, and a two-pass backward design that eliminates the comp_buf shared memory buffer.

  • Forward: splits into a simple kernel (no async overhead) for topk < NVTE_RADIX_TOPK_THRESHOLD and an optimized persistent+async kernel for larger topk, applied to both the topk and aux_loss kernels.
  • Backward: replaces the old three-buffer shmem design with two-pass register reduction and double-buffered async loads of grad, activation, and mask, yielding up to 410% bandwidth improvement at large configs.
  • Infrastructure: new async_loader.h adds RawAsyncLoader<T>, occupancy helpers, and vectorized store/fill; utils.h gains scalar activation helpers and the packed radix histogram optimisation.

Confidence Score: 3/5

Both forward launchers unconditionally check the larger async-path shared_memory_size before branching on topk threshold, incorrectly rejecting valid large-E/small-topk configurations that the simple kernel would handle within device limits.

The premature shmem check is a real defect on the changed launch path — not a theoretical risk — that silently blocks usable configurations for any user with E > ~6000 and topk below the radix threshold. The backward kernels also bypass the adaptive choose_num_buffers logic that was specifically added for the forward paths.

fused_topk_with_score_function.cu and fused_score_for_moe_aux_loss.cu — the launcher functions where shmem check order needs correcting and backward buffer counts need adaptive selection.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/async_loader.h New header with RawAsyncLoader, compute_persistent_grid, choose_num_buffers, and vec helpers. Scalar fallback commits empty cp.async groups unnecessarily; buffer logic and alignment checks are otherwise sound.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Major refactor with simple+optimized kernel split and two-pass backward. Premature unconditional shmem check and hardcoded backward buffer count are the key defects.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Same refactor pattern as topk; shares both the premature shmem check and hardcoded kAuxBwdNumBuffers=2 issues.
transformer_engine/common/fused_router/utils.h Compile-time warp_reduce_on_shmem, scalar activation helpers, packed 8-bit radix histogram. Logic appears correct; per-TU statics for get_radix_topk_threshold are safe given env-var immutability.
transformer_engine/common/fused_router/fused_moe_aux_loss.cu Minimal change: updates warp_reduce_on_shmem call site to new templated signature. Correct.

Comments Outside Diff (3)

  1. transformer_engine/common/fused_router/fused_topk_with_score_function.cu, line 1429-1432 (link)

    P1 Premature shared-memory check rejects valid small-topk configs

    check_shared_memory_capacity_num_experts(shared_memory_size, num_experts) runs unconditionally using shared_memory_size (which includes the async loader overhead sized for the optimized path), but when topk < get_radix_topk_threshold() the simple kernel only allocates other_shmem. On an A100 (max 228 KB per SM), E=8000/topk=4/float32 gives other_shmem ≈ 128 KB (within limit) but shared_memory_size ≈ 256 KB (over limit) — the unconditional check throws, blocking valid configurations the simple-path kernel would handle fine. The same pattern appears in fused_score_for_moe_aux_loss_forward_kernel_launcher. Move the check inside each branch.

  2. transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu, line 599-603 (link)

    P1 Same premature shmem check as in the topk launcher

    check_shared_memory_capacity_num_experts(shared_memory_size, num_experts) runs before the topk < get_radix_topk_threshold() branch and uses the inflated shared_memory_size (includes async-loader buffers). The simple kernel path only needs other_shmem; the mismatch causes a spurious capacity error for large E with small topk.

  3. transformer_engine/common/fused_router/fused_topk_with_score_function.cu, line 1519-1574 (link)

    P2 Backward kernels always use 2 async buffers without adaptive buffer selection

    kBwdNumBuffers = 2 is hardcoded; the choose_num_buffers() helper added in this PR is never consulted for any backward kernel. The forward launchers already call choose_num_buffers() correctly. On devices where double-buffering would drop below kMinBlocksPerSM = 4 resident blocks per SM, single-buffer would give better occupancy. Both topk and aux_loss backward launchers are affected.

Reviews (1): Last reviewed commit: "[Common] Fix bf16 ambiguous constructor ..." | Re-trigger Greptile

Comment on lines +220 to +256
T *buf_[2];
int phase_;
bool double_buf_;

// Raw copy: global → shmem, no type conversion.
// Uses 16-byte vectorised copies (cp.async on sm_80+, int4 on older archs)
// when both pointers are 16-byte aligned, with a scalar tail / fallback.
__device__ void raw_load(const T *__restrict__ src, T *__restrict__ dst, int count, int lane_id) {
constexpr int kBytesPerCopy = 16;
constexpr int kEltsPerCopy = kBytesPerCopy / sizeof(T);

bool src_aligned = (reinterpret_cast<uint64_t>(src) % kBytesPerCopy == 0);
bool dst_aligned = (reinterpret_cast<uint64_t>(dst) % kBytesPerCopy == 0);
int aligned_count = (count / kEltsPerCopy) * kEltsPerCopy;

if (src_aligned && dst_aligned && aligned_count > 0) {
int vec_count = aligned_count / kEltsPerCopy;
for (int vi = lane_id; vi < vec_count; vi += kThreadsPerWarp) {
cp_async_16B(dst + vi * kEltsPerCopy, src + vi * kEltsPerCopy);
}
for (int i = aligned_count + lane_id; i < count; i += kThreadsPerWarp) {
dst[i] = src[i];
}
cp_async_commit();
} else {
for (int i = lane_id; i < count; i += kThreadsPerWarp) {
dst[i] = src[i];
}
cp_async_commit(); // No-op on sm_70; matches wait() expectation on sm_80+.
}
}
};

} // namespace fused_router
} // namespace transformer_engine

#endif // TRANSFORMER_ENGINE_FUSED_ROUTER_ASYNC_LOADER_H_
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 cp_async_commit() in scalar/unaligned fallback creates empty pipeline groups

The else branch copies elements synchronously then still calls cp_async_commit(). On sm_80+, this commits an empty group that wait()__pipeline_wait_prior(0) must drain unnecessarily. On sm_70 it is a no-op; on sm_80+ it adds avoidable pipeline overhead.

@tdophung tdophung self-assigned this May 20, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants