Skip to content

[Common] Optimize fused router forward/backward kernels#3012

Open
harryzhou2000 wants to merge 16 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p3R
Open

[Common] Optimize fused router forward/backward kernels#3012
harryzhou2000 wants to merge 16 commits into
NVIDIA:mainfrom
harryzhou2000:hhanyu/router_fix_p3R

Conversation

@harryzhou2000
Copy link
Copy Markdown
Member

@harryzhou2000 harryzhou2000 commented May 19, 2026

Summary

Optimizes the fused router CUDA kernels introduced in #2821 (fused_topk_with_score_function and fused_score_for_moe_aux_loss). Achieves significant bandwidth improvements for large expert counts and topk values while preserving identical performance for smaller configurations (e.g., E=256, topk=4).

Key results (B300, float32, 8192 tokens):

  • Forward (E=2304, K=36, softmax): 673 → 964 GB/s (+43%)
  • Backward (E=2304, K=36, softmax): 543 → 2766 GB/s (+410%)
  • Forward (E=512, K=4): no regression (±0.3%)

Changes

Forward kernels

  • Persistent grid with async double-buffered prefetch: RawAsyncLoader<T> uses cp.async (sm_80+) for non-blocking global→shmem loads. Occupancy-aware grid sizing (compute_persistent_grid) keeps all SMs saturated across multiple rounds.
  • Packed 8-bit radix histogram: Reduces radix topk register usage from 32 to 4 registers by packing 16 bucket counts into 4×u32 with 8-bit fields. Eliminates local memory spill at large E.
  • Compile-time score function dispatch: ScoreFunc template parameter with if constexpr removes runtime branches from the hot loop.
  • Simple kernel path for small topk: When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), dispatches to a lightweight kernel matching the original structure — no async loader, no persistent grid — avoiding scheduling overhead that dominates at small K.

Backward kernels

  • Two-pass fused design: Pass 1 accumulates warp-level sums via register reduction + warp_allreduce_sum. Pass 2 computes per-element gradients using scalar helpers. Eliminates the comp_buf shared memory buffer (saves E × warps × 4 bytes per block).
  • Double-buffered async loading: All backward inputs (grad, activation, mask) loaded through RawAsyncLoader with always-on double buffering.

Infrastructure

  • async_loader.h: RawAsyncLoader<T>, compute_persistent_grid(), choose_num_buffers(), vectorized global store/fill helpers.
  • NVTE_RADIX_TOPK_THRESHOLD env var (default 8): configurable naive↔radix crossover.
  • Templated warp_reduce_on_shmem<T, ReduceFuncType> eliminates function-pointer overhead.

Hardening

  • Host-side: num_tokens * num_experts <= INT_MAX, topk ∈ [1, E], topk % group_topk == 0
  • Device-side: assert(data_size <= kMaxExpertsRadixTopk) in radix path
  • Correct cudaDevAttrMaxSharedMemoryPerMultiprocessor for buffer-count decision
  • Fix: single-buffer prefetch clobber when shmem is too tight for double buffering

Compatibility

  • No regression for small configs: The simple forward kernel path is an exact replica of the original kernel structure, ensuring E=256/topk=4 (common in standard MoE) performs identically.
  • All existing tests pass: 891/891 test_fused_router.py tests pass, 117 skipped (fp8/multi-node).
  • No API changes: Same Python/C++ interface, same output semantics.
  • Tunable: Set NVTE_RADIX_TOPK_THRESHOLD=0 to force radix everywhere, or =16 to use naive for topk<16.

Performance (B300 SXM6, sm_103, float32, 8192 tokens)

Effective bandwidth (GB/s) is computed as the minimum bytes that must be transferred to/from global memory for one kernel invocation, divided by the measured wall time. For example, the topk forward kernel reads logits (T×E×dtype) and writes probs (T×E×dtype), routing_map (T×E×1), and intermediate_output (T×E×4). This metric captures how well the kernel utilizes memory bandwidth — higher is better, with the device peak around 8 TB/s on B300. Config format is num_experts/topk.

Full benchmark table (softmax)
kernel pass config before after
topk fprop 512/4 1779 1784 (+0.3%)
topk fprop 512/8 798 904 (+13%)
topk fprop 512/22 514 924 (+80%)
topk fprop 512/36 499 908 (+82%)
topk fprop 2304/4 1803 1802 (0%)
topk fprop 2304/8 660 993 (+51%)
topk fprop 2304/22 602 972 (+61%)
topk fprop 2304/36 673 964 (+43%)
topk bprop 512/22 3391 5362 (+58%)
topk bprop 2304/36 543 2766 (+410%)
aux_loss fprop 512/22 519 896 (+73%)
aux_loss fprop 2304/36 645 891 (+38%)
aux_loss bprop 512/22 5289 6155 (+16%)
aux_loss bprop 2304/36 2272 4201 (+85%)
Full benchmark table (sigmoid)
kernel pass config before after
topk fprop 512/4 1728 1736 (+0.5%)
topk fprop 512/22 470 891 (+90%)
topk fprop 2304/36 639 798 (+25%)
topk bprop 512/22 3169 4398 (+39%)
topk bprop 2304/36 533 2274 (+327%)
aux_loss fprop 512/22 475 912 (+92%)
aux_loss fprop 2304/36 598 867 (+45%)
aux_loss bprop 2304/36 1965 2757 (+40%)

@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch 2 times, most recently from 14a302c to a805f38 Compare May 19, 2026 10:22
@harryzhou2000 harryzhou2000 marked this pull request as ready for review May 20, 2026 08:29
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 20, 2026

Greptile Summary

This PR rewrites the fused router CUDA kernels (fused_topk_with_score_function and fused_score_for_moe_aux_loss) with persistent grids, double-buffered cp.async loads, a packed 8-bit radix histogram, and a two-pass fused backward that eliminates the comp_buf shared-memory buffer, yielding up to 5× bandwidth improvement for large expert configurations.

  • Forward: dispatches to a lightweight "simple" kernel for topk < NVTE_RADIX_TOPK_THRESHOLD (preserving original performance for small-K configs) and to an async-loader + radix-topk kernel above the threshold; shared-memory capacity checks are correctly split per path, addressing the previously flagged P1s.
  • Backward (topk + aux_loss): two-pass design — pass 1 accumulates warp-level sums via register reduction + warp_allreduce_sum, pass 2 computes per-element gradients using scalar helpers; choose_num_buffers is correctly called (fixing the hardcoded kBwdNumBuffers = 2 noted in prior review).
  • Infrastructure (async_loader.h, utils.h): new RawAsyncLoader<T>, compute_persistent_grid, choose_num_buffers, and compile-time-dispatched warp_reduce_on_shmem<T, type>; host-side bounds checks guard against INT_MAX overflow and out-of-range topk.

Confidence Score: 5/5

Safe to merge. The core gradient math is correct for all three score functions and both forward and backward kernels. Previously flagged shared-memory check issues are resolved, and backward kernels now correctly call choose_num_buffers.

The two-pass backward derivation is mathematically sound across all ScoreFunc × use_pre_softmax combinations. The packed 8-bit radix histogram correctly bounds overflow via kMaxExpertsRadixTopk. The simple kernel path ensures no regression for small-topk configurations. Remaining comments are style and performance suggestions that do not affect correctness.

utils.h (pragma unroll + warp_allreduce_sum + break interaction) and async_loader.h (uncached CUDA API calls) are the only points worth a follow-up look.

Important Files Changed

Filename Overview
transformer_engine/common/fused_router/async_loader.h New header introducing RawAsyncLoader (cp.async double-buffering), compute_persistent_grid, choose_num_buffers, and vec_store/fill helpers. Logic is sound; minor concern around multiple uncached CUDA API calls per launch.
transformer_engine/common/fused_router/utils.h Refactors warp_reduce_on_shmem to a compile-time template, removes function-pointer dispatch, adds scalar score-function helpers, and replaces the 32-register radix histogram with a packed 4-register version; radix scan unroll with break deserves attention.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu Adds simple kernel path for small topk and replaces the old all-in-one kernel with an async-loader + persistent-grid + radix variant; two-pass backward eliminates comp_buf shmem. Shmem checks are correctly split per path. Previously flagged P1s appear addressed.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Parallel restructuring of aux_loss forward/backward kernels using the same simple/optimized split. Two-pass backward now uses register reduction and correctly separates pass1 (warp sums) from pass2 (element-wise grad write).
transformer_engine/common/fused_router/fused_moe_aux_loss.cu Minimal change: updates warp_reduce_on_shmem call site to use the new compile-time template signature. No logic change.

Reviews (3): Last reviewed commit: "[Common] Fall back to naive topk beyond ..." | Re-trigger Greptile

Comment thread transformer_engine/common/fused_router/async_loader.h
@tdophung tdophung self-assigned this May 20, 2026
Replace multi-loop preprocess (separate clear/load/score/save/bias loops)
with single fused loops per score function in all 4 kernel paths (topk
forward, topk backward, aux_loss forward, aux_loss backward).

Replace multi-pass backward (array-based helpers + comp_buf shmem) with
a two-pass approach using scalar helpers:
  Pass 1: reduction — warp-level sums via warp_allreduce_sum()
  Pass 2: element-wise — scalar gradient computation → write to global

Add scalar helpers to utils.h: sigmoid_scalar, sqrtsoftplus_scalar,
sigmoid_bwd_scalar, sqrtsoftplus_bwd_scalar, normalize_bwd_scalar,
softmax_bwd_scalar.

Remove dead array helpers from utils.h: apply_sigmoid_on_float,
apply_sigmoid_bwd_on_float, apply_sqrtsoftplus_on_float,
apply_sqrtsoftplus_bwd_on_float, apply_softmax_bwd_on_float,
masked_warp_reduce_on_shmem.

Backward shmem reduced by E×W×sizeof(float) per kernel (comp_buf
eliminated).  Net -226 lines across 3 files.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Add async_loader.h with:
  - RawAsyncLoader<T>: cp.async on sm_80+, int4 fallback on sm_70,
    stores data in original type (no conversion during copy)
  - compute_persistent_grid(): occupancy-based grid sizing
  - choose_num_buffers(): shmem-aware 1-vs-2 buffer decision
  - vec_fill_global(), vec_store_global(): vectorized output helpers

Forward kernels (topk + aux_loss):
  - Logits loaded via RawAsyncLoader with double-buffered prefetch
  - Persistent grid replaces 1-shot grid launch
  - DataType→CompType conversion during compute, not during load
  - vec_fill_global for clearing probs/routing_map

Backward kernels (topk + aux_loss):
  - All inputs loaded via RawAsyncLoader (topk: 3 loaders for
    grad/act/mask; aux_loss: 2 loaders for grad/act)
  - Always double-buffered (kBwdNumBuffers=2, kAuxBwdNumBuffers=2)
  - Persistent grid with occupancy-based sizing

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace counts[16] + total_counts[16] (32 registers) with 4 packed u32
registers using 8-bit fields (4 counters per register).  Eliminates
massive register spill to local memory on large kernels (81% of L1
traffic on E=2304, K=36).

Add kMaxExpertsRadixTopk constant (8160 = 255 * 32) and runtime checks
in both forward launchers to guard against 8-bit overflow.  All current
MoE configurations (max E=2304) are well within this limit.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
…dispatch

Replace runtime score_function parameter in all 4 kernel __global__
functions with template int ScoreFunc (0=sigmoid, 1=softmax,
2=sqrtsoftplus).  All score_function branches now use if constexpr,
eliminating dead-code register pressure and branch overhead.

Forward launchers dispatch on TopkFunc × ScoreFunc = 6 instantiations
per DataType.  Backward launchers dispatch on ScoreFunc = 3
instantiations per DataType.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Fix broken topk < 0 threshold (radix was always selected, naive
unreachable).  Replace with configurable NVTE_RADIX_TOPK_THRESHOLD
env var (default 0, i.e. always use radix).  Set to 16 to restore
the old naive-for-small-K behavior.

Uses the standard TE pattern: static local + getenv (read once,
cached for process lifetime).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When choose_num_buffers() returns 1 (shmem too tight for double
buffering, e.g. E=1024 with group_topk scratch), buf_[0] and buf_[1]
alias the same memory.  The prefetch via start_load(next_buf()) then
overwrites the current buffer while compute is still reading it.

Fix: guard the prefetch on num_buffers > 1.  When single-buffered,
load the current round's data at the top of each iteration instead.
The first round's load_current is still issued before the loop.

Backward kernels are unaffected (always kBwdNumBuffers=2).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Code review fixes:

- C1: choose_num_buffers() now queries cudaDevAttrMaxSharedMemoryPerMultiprocessor
  (per-SM budget) instead of cudaDevAttrMaxSharedMemoryPerBlockOptin (per-block
  max).  These coincide on Hopper/Blackwell but differ on Ampere.

- H3: Remove dead fallback branch in choose_num_buffers() — since
  total_double >= total_single always, blocks_single >= blocks_double,
  so the old ternary always returned 1 anyway.

- H4/M8: Add host-side NVTE_CHECK in all 4 launchers:
  - num_experts > 0
  - topk in [1, num_experts]
  - (int64_t)num_tokens * num_experts <= INT_MAX (kernel uses int offsets)

- M9: Assert topk % group_topk == 0 when group_topk > 0.

- H6: Add device-side assert(data_size <= kMaxExpertsRadixTopk) in
  radix_topk_and_mask() — zero cost in release (NDEBUG), catches
  8-bit histogram overflow in debug builds.

- L1: Fix stale comments claiming default threshold is 16 (it is 0).
- L4: Fix typo 'hanlded' -> 'handled'.
- L8: Remove unused topk parameter from aux loss backward kernel.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Move the duplicated static function from both .cu files into utils.h
as an inline function.  Each TU gets its own static local (read-once
per TU), which is safe since environment variables are immutable
during process lifetime.  Documented this in a NOTE comment.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Replace runtime function-pointer dispatch with compile-time if constexpr.
Eliminates indirect call overhead in the reduction loop and warp shuffle
butterfly, allowing the compiler to emit straight-line arithmetic.

Removes the now-unused max<T>() and sum<T>() helper functions.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
When topk < NVTE_RADIX_TOPK_THRESHOLD (default 8), use a lightweight
forward kernel that avoids the async loader and persistent grid overhead.
The simple kernel loads logits directly from global memory to shmem and
uses Naive iterative-argmax topk — matching the baseline structure that
was faster for small K due to lower launch/scheduling overhead.

The optimized path (async loader + persistent grid + radix topk) remains
the default for topk >= 8 where the compute savings dominate.

Both topk and aux_loss forward kernels get the simple variant.
Backward kernels are unchanged (always use the optimized path).

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Use 0.0f instead of 0 to avoid ambiguity between __nv_bfloat16(float)
and __nv_bfloat16(double) constructors on older CUDA toolkits.

Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
@harryzhou2000 harryzhou2000 force-pushed the hhanyu/router_fix_p3R branch from 9a7cb7e to 3bab7cb Compare May 21, 2026 03:03
Comment thread transformer_engine/common/fused_router/fused_topk_with_score_function.cu Outdated
Comment thread transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu Outdated
Signed-off-by: Harry Zhou <hhanyu@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants