From 496e3ab019d66107cd57fe00da437ecb44d51a8a Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 21 Apr 2026 11:57:08 -0700 Subject: [PATCH 01/18] initial impl Signed-off-by: tdophung --- tests/jax/test_moe_block.py | 292 ++++++++ transformer_engine/jax/flax/__init__.py | 2 + transformer_engine/jax/flax/moe.py | 890 +++++++++++++++++++++++ transformer_engine/jax/mt_permutation.py | 356 +++++++++ 4 files changed, 1540 insertions(+) create mode 100644 tests/jax/test_moe_block.py create mode 100644 transformer_engine/jax/flax/moe.py create mode 100644 transformer_engine/jax/mt_permutation.py diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py new file mode 100644 index 0000000000..458d674c7d --- /dev/null +++ b/tests/jax/test_moe_block.py @@ -0,0 +1,292 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Basic tests for ``transformer_engine.jax.flax.MoEBlock``. + +These tests exercise the MoEBlock on a single device (no expert parallelism) +and verify: + +* Forward pass runs end-to-end and produces the expected output shape. +* Backward pass yields finite, non-trivial parameter gradients. +* The two permutation backends (``"pure_jax"`` and ``"triton"``) produce + numerically equivalent outputs and gradients when given the same routing + decisions. +* Auxiliary load-balancing loss is returned when ``aux_loss_coeff > 0``. +* DeepSeek-style grouped top-k (``num_groups`` / ``group_topk``) runs. +* ``align_size > 0`` produces numerically-equivalent outputs to ``align_size = 0`` + for the pure-JAX backend (padding must not change the result). +""" + +import sys +from typing import Tuple + +import jax +import jax.numpy as jnp +import pytest + + +# The MoEBlock pulls in both the fused-router CUDA kernel and the Triton +# permutation kernels, so it can only run in the environment where those are +# available. We gate the test on the ``triton`` marker (the Triton permutation +# backend is stricter than the CUDA router). See ``conftest.py``. + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MoEBlock = MoEBlock + yield + + +# ----------------------------------------------------------------------------- +# Configurations +# ----------------------------------------------------------------------------- +# +# Keep shapes small so the tests are cheap but still exercise every code path. + +DTYPE = jnp.bfloat16 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs( + key: jax.Array, batch_size: int = BATCH_SIZE, sequence_length: int = SEQUENCE_LENGTH +) -> jax.Array: + return jax.random.normal( + key, (batch_size, sequence_length, HIDDEN_SIZE), dtype=DTYPE + ) + + +def _init_and_apply( + block, + inputs: jax.Array, + init_key: jax.Array, +) -> Tuple[dict, jax.Array, jax.Array]: + variables = block.init(init_key, inputs) + output, aux_loss = block.apply(variables, inputs) + return variables, output, aux_loss + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.triton +class TestMoEBlockSingleDevice: + """Single-device smoke tests for :class:`MoEBlock`.""" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_forward_shape_and_finite(self, permutation_backend): + key = jax.random.PRNGKey(0) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape, ( + f"Unexpected output shape {output.shape} for backend {permutation_backend}" + ) + assert output.dtype == inputs.dtype + assert jnp.all(jnp.isfinite(output)), "Output contains NaN/Inf" + assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_backward_grad(self, permutation_backend): + key = jax.random.PRNGKey(1) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + def loss_fn(variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2) + + grads = jax.grad(loss_fn)(variables, inputs) + # All trainable kernels should receive a non-trivial gradient. + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g = grads["params"][name] + assert jnp.all(jnp.isfinite(g)), f"{name} gradient has NaN/Inf" + assert jnp.any(g != 0.0), f"{name} gradient is identically zero" + + def test_pure_jax_triton_equivalence(self): + """Both permutation backends must produce the same forward + grads + under identical routing decisions. + + Since the two backends share the same routing path (TE's fused + top-k), fixing the gate kernel gives both the same routing decisions + and the remainder of the network is identical modulo the permutation + implementation, whose semantics are equivalent. + """ + key = jax.random.PRNGKey(2) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + dtype=DTYPE, + ) + pure_block = MoEBlock(permutation_backend="pure_jax", **base_kwargs) + triton_block = MoEBlock(permutation_backend="triton", **base_kwargs) + inputs = _make_inputs(data_key) + + # Share a single parameter tree so routing decisions and expert + # weights are identical for both backends. + variables = pure_block.init(init_key, inputs) + + def loss_fn(block, variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2), output + + (loss_pj, out_pj), grads_pj = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(pure_block, variables, inputs) + (loss_tr, out_tr), grads_tr = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(triton_block, variables, inputs) + + # BF16 tolerances: outputs come out of the grouped-GEMM + weighted + # sum so they accumulate error; we use ~2 ULPs worth of slack. + atol_out, rtol_out = 5e-2, 5e-2 + assert jnp.allclose(out_pj, out_tr, atol=atol_out, rtol=rtol_out), ( + f"Forward outputs differ across backends: max diff" + f" {jnp.max(jnp.abs(out_pj - out_tr))}" + ) + assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_pj = grads_pj["params"][name] + g_tr = grads_tr["params"][name] + assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), ( + f"Gradient for {name} differs across backends: max diff" + f" {jnp.max(jnp.abs(g_pj - g_tr))}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_aux_loss_returned(self, permutation_backend): + key = jax.random.PRNGKey(3) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert aux_loss is not None, "aux_loss should be returned when coeff > 0" + assert aux_loss.shape == (), "aux_loss should be a scalar" + assert jnp.isfinite(aux_loss) + # With uniform-ish routing the loss should be small-positive, not huge. + assert jnp.abs(aux_loss) < 1e2 + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_group_topk_deepseek(self, permutation_backend): + """Exercise DeepSeek-style grouped top-k routing.""" + key = jax.random.PRNGKey(4) + init_key, data_key = jax.random.split(key) + + # num_groups must divide num_experts. + num_groups = 4 + group_topk = 2 + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, _aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert jnp.all(jnp.isfinite(output)) + + def test_align_size_equivalence_pure_jax(self): + """For the pure-JAX backend, ``align_size > 0`` must not change the + numerical output of the forward pass: padding tokens contribute zero + to every expert GEMM output (their input rows are zeros) and are + stripped before the weighted sum. + """ + key = jax.random.PRNGKey(5) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend="pure_jax", + dtype=DTYPE, + ) + block_no_pad = MoEBlock(align_size=0, **base_kwargs) + block_pad = MoEBlock(align_size=16, **base_kwargs) + inputs = _make_inputs(data_key) + variables = block_no_pad.init(init_key, inputs) + + out_no_pad, _ = block_no_pad.apply(variables, inputs) + out_pad, _ = block_pad.apply(variables, inputs) + assert jnp.allclose(out_no_pad, out_pad, atol=5e-2, rtol=5e-2), ( + "align_size > 0 must not change pure_jax forward output; max diff" + f" {jnp.max(jnp.abs(out_no_pad - out_pad))}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_jit_and_determinism(self, permutation_backend): + """The block must be JIT-compilable and produce a deterministic + forward pass across repeat calls with the same params.""" + key = jax.random.PRNGKey(6) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + @jax.jit + def forward(variables, inputs): + return block.apply(variables, inputs)[0] + + out_a = forward(variables, inputs) + out_b = forward(variables, inputs) + assert jnp.array_equal(out_a, out_b), "JITted forward is non-deterministic" diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 92a968f061..0cd7835bcf 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -9,6 +9,7 @@ make_dot_general_cls, make_grouped_dense_cls, ) +from .moe import MoEBlock from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -18,6 +19,7 @@ "LayerNorm", "LayerNormDenseGeneral", "LayerNormMLP", + "MoEBlock", "wrap_function_in_te_state_module", "make_dot_general_cls", "make_grouped_dense_cls", diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py new file mode 100644 index 0000000000..ddbe687771 --- /dev/null +++ b/transformer_engine/jax/flax/moe.py @@ -0,0 +1,890 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Flax Linen MoEBlock for TransformerEngine JAX. + +This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer +that wires together TE's fused router, a selectable token-dispatch backend +(pure-JAX MaxText-style or Triton), TE's ``grouped_dense``, and optional +ring-of-experts Expert Parallelism. + +See ``plans/te_jax_moeblock_926b7994.plan.md`` for the full design rationale +and the mapping to Maxtext's ``RoutedMoE``. +""" + +from typing import Any, Callable, NewType, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from flax import linen as nn +from jax.sharding import PartitionSpec as P + +from ..dense import grouped_dense +from ..mt_permutation import mt_token_combine, mt_token_dispatch +from ..permutation import token_combine, token_dispatch +from ..quantize import noop_quantizer_set +from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function +from ..sharding import with_sharding_constraint_by_logical_axes +from .module import TransformerEngineBase + +PRNGKey = Any +Shape = Tuple[int, ...] +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) +Initializer = Callable[[PRNGKey, Shape, DType], Array] + + +__all__ = ["MoEBlock"] + + +# ============================================================================= +# Helpers +# ============================================================================= + + +_ACTIVATIONS = { + "silu": jax.nn.silu, + "swish": jax.nn.silu, + "gelu": jax.nn.gelu, + "relu": jax.nn.relu, + "identity": lambda x: x, + "linear": lambda x: x, +} + + +def _get_activation_fn(name: str) -> Callable: + key = name.lower() + if key not in _ACTIVATIONS: + raise ValueError( + f"Unsupported activation_type={name!r}; supported: {sorted(_ACTIVATIONS)}" + ) + return _ACTIVATIONS[key] + + +def _extract_topk_from_routing_map( + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + topk: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Convert TE's ``(sparse_probs, routing_map)`` to ``(selected_experts, weights)``. + + ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` + with exactly ``topk`` ``True`` positions per row. ``sparse_probs`` is the + same-shape float tensor whose non-zero entries are the routing weights. + + The per-token top-k expert IDs are recovered as the last ``topk`` indices + of ``argsort(routing_map)`` (``False < True``), and the corresponding + weights are gathered from ``sparse_probs`` along the expert axis. + + The within-row expert ordering does not have to match the router's + top-k ordering: :func:`mt_token_dispatch` and :func:`mt_token_combine` + only require that ``selected_experts`` and ``weights`` are consistent with + each other. + """ + # Cast to int32 so argsort has a well-defined ordering. (Ascending argsort + # on 0/1 puts the ``True`` positions last; we then slice the last ``topk``.) + selected_experts = jnp.argsort(routing_map.astype(jnp.int32), axis=-1)[:, -topk:] + weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + return selected_experts, weights + + +# ============================================================================= +# MoEBlock +# ============================================================================= + + +class MoEBlock(TransformerEngineBase): + """Mixture-of-Experts Flax Linen block. + + Encapsulates the full MoE forward pass: gate projection, fused top-k + routing, optional auxiliary load-balancing loss, token dispatch, per-expert + two-layer FFN via grouped GEMMs, activation, token combine, and optional + ring-of-experts expert parallelism. + + The permutation step is pluggable: the default ``permutation_backend="pure_jax"`` + uses the MaxText-style argsort-based dispatch/combine in + :mod:`transformer_engine.jax.mt_permutation`, which empirically outperforms + the Triton kernels on several E2E workloads. ``permutation_backend="triton"`` + uses TE's ``token_dispatch`` / ``token_combine`` kernels. + + Parameters + ---------- + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k value (number of experts each token is routed to). + intermediate_size : int + Per-expert FFN hidden dim. + + activation_type : str + FFN activation applied to the gate projection. Paired with the up + projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. Supported: + ``"silu"``/``"swish"`` (default), ``"gelu"``, ``"relu"``, + ``"identity"``/``"linear"``. + + score_function : str or ScoreFunction + ``"softmax"`` (default) or ``"sigmoid"`` for :func:`fused_topk_with_score_function`. + use_pre_softmax : bool + Apply softmax before top-k when ``score_function="softmax"``. + num_groups : int + Number of routing groups for grouped top-k (DeepSeek). ``<=0`` disables. + group_topk : int + Top-k at the group level. ``<=0`` disables. + scaling_factor : float + Scaling factor applied to output probs. + use_expert_bias : bool + If ``True``, registers a learnable ``expert_bias`` parameter of shape + ``[num_experts]`` and passes it to the fused router. Only valid with + ``score_function="sigmoid"`` (DeepSeek V3 loss-free load balancing). + aux_loss_coeff : float + If ``> 0``, compute and return the MoE auxiliary load-balancing loss + scalar via :func:`fused_moe_aux_loss`. ``0`` disables. + + gate_kernel_axes : tuple[str, ...] + Logical partitioning axes for the gate kernel of shape + ``[hidden, num_experts]``. + wi_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wi_0`` and ``wi_1`` kernels of + shape ``[num_experts, hidden, intermediate]``. Default: + ``("exp", "embed", "mlp")``. + wo_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wo`` kernel of shape + ``[num_experts, intermediate, hidden]``. Default: + ``("exp", "mlp", "embed")``. + input_axes : tuple[str, ...] + Logical axes used to constrain the input activation sharding at the + block boundary. ``()`` (default) means no constraint. + + expert_parallelism_axis : Optional[str] + Mesh axis along which experts are split. When set, the forward pass + is wrapped in :func:`jax.experimental.shard_map.shard_map` that + implements the ring-of-experts EP strategy: ``all_gather`` on inputs + and gate logits, local routing + dispatch + FFN + combine, then + ``psum_scatter`` on the output. When ``None`` (default), no + ``shard_map`` wrapper is used; each primitive's ``custom_partitioning`` + rule handles DP/FSDP/TP automatically. + tensor_parallelism_axis : Optional[str] + Mesh axis for tensor parallelism on the FFN intermediate dim. When + set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed + along this axis (inside the ``shard_map`` when EP is enabled, else at + the end of the forward pass). + + permutation_backend : str + ``"pure_jax"`` (default; faster on many E2E workloads) or ``"triton"``. + align_size : int + Alignment for per-expert group sizes after padding. ``0`` disables + padding (faster for the unquantized path). ``>0`` is required for + quantized TE grouped GEMM whose recipe-specific alignment must divide + ``align_size``. Passed through to both permutation backends. + use_custom_sort_vjp : bool + Only used when ``permutation_backend="pure_jax"``. If ``True``, uses + a custom VJP for the argsort-based gather (faster in most cases). + + dtype : jnp.dtype + Compute and parameter dtype. + kernel_init : Initializer + Initializer for all kernels. Defaults to ``variance_scaling(1.0, + 'fan_in', 'truncated_normal')``. + use_bias : bool + If ``True``, registers per-expert FFN biases ``wi_0_bias``, + ``wi_1_bias``, ``wo_bias``. + """ + + # Architecture + num_experts: int = 8 + num_experts_per_tok: int = 2 + intermediate_size: int = 2048 + activation_type: str = "silu" + + # Routing + score_function: Union[str, ScoreFunction] = "softmax" + use_pre_softmax: bool = False + num_groups: int = -1 + group_topk: int = -1 + scaling_factor: float = 1.0 + use_expert_bias: bool = False + aux_loss_coeff: float = 0.0 + + # Sharding + gate_kernel_axes: Tuple[Optional[str], ...] = () + wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp") + wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed") + input_axes: Tuple[Optional[str], ...] = () + + # Parallelism + expert_parallelism_axis: Optional[str] = None + tensor_parallelism_axis: Optional[str] = None + # ``jax.sharding.Mesh`` to use when ``expert_parallelism_axis`` is set. + # Required for the ``shard_map`` wrapper; ignored otherwise. + mesh: Optional[Any] = None + + # Permutation + permutation_backend: str = "pure_jax" + align_size: int = 0 + use_custom_sort_vjp: bool = True + + # Dtypes / init / misc + dtype: DType = jnp.float32 + kernel_init: Optional[Initializer] = None + bias_init: Initializer = nn.initializers.zeros + expert_bias_init: Initializer = nn.initializers.zeros + use_bias: bool = False + + def __post_init__(self): + if self.kernel_init is None: + object.__setattr__( + self, + "kernel_init", + nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ), + ) + if self.permutation_backend not in ("pure_jax", "triton"): + raise ValueError( + "permutation_backend must be 'pure_jax' or 'triton'," + f" got {self.permutation_backend!r}" + ) + if self.use_expert_bias: + # ``fused_topk_with_score_function`` only accepts ``expert_bias`` + # under the sigmoid score function. Raise early to surface the + # misconfiguration instead of failing deep inside the kernel. + score_func = ( + self.score_function.name.lower() + if isinstance(self.score_function, ScoreFunction) + else str(self.score_function).lower() + ) + if score_func != "sigmoid": + raise ValueError( + "use_expert_bias=True requires score_function='sigmoid';" + f" got {self.score_function!r}." + ) + super().__post_init__() + + # ------------------------------------------------------------------ + # Parameter registration + # ------------------------------------------------------------------ + + def _make_params(self, hidden_size: int): + """Register module parameters and return them as a dict.""" + gate_kernel = self.param( + "gate_kernel", + nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), + (hidden_size, self.num_experts), + self.dtype, + ) + wi_0 = self.param( + "wi_0", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wi_1 = self.param( + "wi_1", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wo = self.param( + "wo", + nn.with_logical_partitioning(self.kernel_init, self.wo_kernel_axes), + (self.num_experts, self.intermediate_size, hidden_size), + self.dtype, + ) + params = { + "gate_kernel": gate_kernel, + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + if self.use_bias: + params["wi_0_bias"] = self.param( + "wi_0_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + params["wi_1_bias"] = self.param( + "wi_1_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + params["wo_bias"] = self.param( + "wo_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "embed")), + (self.num_experts, hidden_size), + self.dtype, + ) + if self.use_expert_bias: + params["expert_bias"] = self.param( + "expert_bias", + nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), + (self.num_experts,), + self.dtype, + ) + return params + + # ------------------------------------------------------------------ + # Entry point + # ------------------------------------------------------------------ + + @nn.compact + def __call__( + self, + inputs: Array, + deterministic: bool = True, + ) -> Tuple[Array, Optional[Array]]: + """Run the MoE forward pass. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[batch, sequence, hidden]``. + deterministic : bool + Reserved for future dropout-based routing; currently unused. + + Returns + ------- + output : jnp.ndarray + Output tensor of shape ``[batch, sequence, hidden]``. + aux_loss : Optional[jnp.ndarray] + Scalar auxiliary load-balancing loss when ``aux_loss_coeff > 0``, + else ``None``. + """ + del deterministic # unused for now + + assert inputs.ndim == 3, ( + f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" + ) + inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) + + batch_size, sequence_length, hidden_size = inputs.shape + params = self._make_params(hidden_size) + + # Gate projection runs OUTSIDE the EP shard_map (mirroring Maxtext), + # so that each EP shard projects its own local slice of tokens and we + # later all-gather only the logits, not the full inputs. + gate_logits = self._gate(inputs, params["gate_kernel"]) + + if self.expert_parallelism_axis is not None: + return self._forward_ring_ep(inputs, gate_logits, params) + return self._forward_single_shard(inputs, gate_logits, params) + + # ------------------------------------------------------------------ + # Gate + # ------------------------------------------------------------------ + + def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: + """Linear gate projection ``inputs @ gate_kernel``. + + Kept as a plain matmul (not ``DenseGeneral``) so it integrates cleanly + with the EP shard_map below: the gate matmul runs in the outer + (pre-shard_map) scope and its output is all-gathered along the EP axis + inside the shard_map. + """ + # Cast kernel to input dtype outside FP8 scope (gate is typically BF16/FP32). + kernel = gate_kernel.astype(inputs.dtype) + return jnp.einsum("bsh,he->bse", inputs, kernel) + + # ------------------------------------------------------------------ + # Single-shard (no EP) forward + # ------------------------------------------------------------------ + + def _forward_single_shard( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + batch_size, sequence_length, hidden_size = inputs.shape + + inputs_2d = inputs.reshape(-1, hidden_size) + logits_2d = gate_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map, aux_loss = self._route( + logits_2d, params.get("expert_bias") + ) + + expert_outputs, combine_state = self._dispatch_and_expert_ffn( + inputs_2d, + sparse_probs, + routing_map, + params, + num_experts_local=self.num_experts, + roll_to_expert_id=None, + local_tokens_per_expert_count=self.num_experts, + ) + + output = self._combine( + expert_outputs, + combine_state, + batch_size=batch_size, + sequence_length=sequence_length, + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + return output, aux_loss + + # ------------------------------------------------------------------ + # Ring-of-Experts EP forward + # ------------------------------------------------------------------ + + def _forward_ring_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Wrap the dispatch / FFN / combine pipeline in a ring-of-experts + ``shard_map``. + + Inside the shard_map each EP shard: + 1. ``all_gather`` s the inputs and logits along the EP axis so it + sees every token globally. + 2. Routes with ``roll_to_expert_id = num_experts_per_shard * shard_id`` + so its local experts are in slots ``[0, num_experts_per_shard)``. + 3. Dispatches tokens, slicing ``group_sizes`` to the first + ``num_experts_per_shard`` entries (the rest correspond to remote + experts and should be zero after the roll/mask). + 4. Runs the per-expert FFN on its local expert slice of + ``wi_0`` / ``wi_1`` / ``wo``. + 5. Combines at the expanded-batch shape ``[B * num_ep, S, H]`` then + ``psum_scatter`` s along the EP axis to return the local slice. + """ + from jax.experimental.shard_map import shard_map + + ep_axis = self.expert_parallelism_axis + if self.mesh is None: + raise ValueError( + "MoEBlock.expert_parallelism_axis is set; `mesh` must also be" + " provided so the ring-of-experts shard_map can be built." + ) + mesh = self.mesh + num_ep = mesh.shape[ep_axis] + assert self.num_experts % num_ep == 0, ( + f"num_experts={self.num_experts} must be divisible by EP size={num_ep}" + ) + num_experts_per_shard = self.num_experts // num_ep + + # in_specs / out_specs use PartitionSpec over the EP axis for inputs/ + # outputs (leading batch dim is split across EP) and ``P("exp", ...)`` + # for the expert weights, where we require the user's logical axis + # rules to map ``"exp"`` to the EP mesh axis. The expert bias is + # similarly sharded along the expert axis. + inputs_spec = P(ep_axis, None, None) + logits_spec = P(ep_axis, None, None) + wi_spec = P(ep_axis, None, None) + wo_spec = P(ep_axis, None, None) + output_spec = P(ep_axis, None, None) + scalar_spec = P() + bias_1d_spec = P(ep_axis) + bias_2d_spec = P(ep_axis, None) + + expert_bias_value = params.get("expert_bias") + wi_0_bias_value = params.get("wi_0_bias") + wi_1_bias_value = params.get("wi_1_bias") + wo_bias_value = params.get("wo_bias") + + in_specs = [ + inputs_spec, + logits_spec, + wi_spec, + wi_spec, + wo_spec, + ] + captured = [ + inputs, + gate_logits, + params["wi_0"], + params["wi_1"], + params["wo"], + ] + if expert_bias_value is not None: + in_specs.append(bias_1d_spec) + captured.append(expert_bias_value) + if wi_0_bias_value is not None: + in_specs.extend([bias_2d_spec, bias_2d_spec, bias_2d_spec]) + captured.extend([wi_0_bias_value, wi_1_bias_value, wo_bias_value]) + + out_specs = (output_spec, scalar_spec) + + use_expert_bias = expert_bias_value is not None + use_bias = wi_0_bias_value is not None + + def _ring_fn(*args): + idx = 0 + local_inputs = args[idx]; idx += 1 + local_gate_logits = args[idx]; idx += 1 + local_wi_0 = args[idx]; idx += 1 + local_wi_1 = args[idx]; idx += 1 + local_wo = args[idx]; idx += 1 + local_expert_bias = None + if use_expert_bias: + local_expert_bias = args[idx]; idx += 1 + local_wi_0_bias = local_wi_1_bias = local_wo_bias = None + if use_bias: + local_wi_0_bias = args[idx]; idx += 1 + local_wi_1_bias = args[idx]; idx += 1 + local_wo_bias = args[idx]; idx += 1 + + shard_id = jax.lax.axis_index(ep_axis) + + # All-gather inputs and logits along the EP axis so each shard + # sees the global tokens. + gathered_inputs = jax.lax.all_gather( + local_inputs, axis_name=ep_axis, tiled=True + ) + gathered_logits = jax.lax.all_gather( + local_gate_logits, axis_name=ep_axis, tiled=True + ) + + # If the user also sharded by EP on the expert_bias, ``local_expert_bias`` + # is already the local slice; the router operates over the full + # expert axis, so all-gather to reconstruct. + global_expert_bias = None + if local_expert_bias is not None: + global_expert_bias = jax.lax.all_gather( + local_expert_bias, axis_name=ep_axis, tiled=True + ) + + batch_size = gathered_inputs.shape[0] + sequence_length = gathered_inputs.shape[1] + hidden_size = gathered_inputs.shape[2] + + inputs_2d = gathered_inputs.reshape(-1, hidden_size) + logits_2d = gathered_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map, aux_loss = self._route( + logits_2d, global_expert_bias + ) + + # Ring-of-experts roll: after rolling expert columns by + # ``-num_experts_per_shard * shard_id``, this shard's experts + # occupy slots ``[0, num_experts_per_shard)`` in ``routing_map`` + # and ``sparse_probs``. + # + # For the Triton backend we additionally mask the remote-expert + # columns to False/0 so ``token_dispatch`` never writes those + # tokens into the local permuted buffer. For the pure-JAX backend + # we leave the routing_map untouched (mirroring Maxtext): the roll + # passed to ``mt_token_dispatch`` sorts remote-expert tokens past + # the local slots, and we later zero out those garbage rows of + # ``expert_outputs`` before the combine. + roll = num_experts_per_shard * shard_id + routing_map = jnp.roll(routing_map, -roll, axis=-1) + sparse_probs = jnp.roll(sparse_probs, -roll, axis=-1) + if self.permutation_backend == "triton": + local_expert_mask = ( + jnp.arange(self.num_experts) < num_experts_per_shard + ) + routing_map = routing_map * local_expert_mask[None, :] + sparse_probs = sparse_probs * local_expert_mask[None, :].astype( + sparse_probs.dtype + ) + + # Build a reduced-expert view of the weights: the outer ``shard_map`` + # has already sliced the leading expert axis down to + # ``num_experts_per_shard`` per shard. Pass it through as-is to the + # dispatch / expert-FFN path with ``num_experts_local = num_experts_per_shard``. + local_params = { + "gate_kernel": None, # unused past gate + "wi_0": local_wi_0, + "wi_1": local_wi_1, + "wo": local_wo, + } + if use_bias: + local_params["wi_0_bias"] = local_wi_0_bias + local_params["wi_1_bias"] = local_wi_1_bias + local_params["wo_bias"] = local_wo_bias + + expert_outputs, combine_state = self._dispatch_and_expert_ffn( + inputs_2d, + sparse_probs, + routing_map, + local_params, + num_experts_local=num_experts_per_shard, + roll_to_expert_id=0, # roll is already applied on routing_map + local_tokens_per_expert_count=num_experts_per_shard, + ) + + # For the pure-JAX backend in ring-EP mode, zero out expert-output + # rows that correspond to remote experts (which ``grouped_dense`` + # leaves as garbage since ``group_sizes`` was truncated to the + # local slice). Without this, the unsort + weighted-sum in + # combine would mix garbage into every token's output. Matches + # ``moe.py:1731-1733`` in Maxtext. + if self.permutation_backend == "pure_jax": + real_mask = ( + jnp.arange(expert_outputs.shape[0]) + < combine_state["local_real_size"] + ) + expert_outputs = jnp.where( + real_mask[:, None], expert_outputs, 0 + ) + + output = self._combine( + expert_outputs, + combine_state, + batch_size=batch_size, + sequence_length=sequence_length, + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + # ``output`` is [B*num_ep, S, H] (global batch after all-gather); + # psum_scatter along EP returns the local [B, S, H] slice. + output = jax.lax.psum_scatter( + output, + ep_axis, + scatter_dimension=0, + tiled=True, + ) + + if aux_loss is None: + aux_loss = jnp.zeros((), dtype=self.dtype) + return output, aux_loss + + output, aux_loss = shard_map( + _ring_fn, + mesh=mesh, + in_specs=tuple(in_specs), + out_specs=out_specs, + check_rep=False, + )(*captured) + + if self.aux_loss_coeff <= 0.0: + aux_loss = None + return output, aux_loss + + # ------------------------------------------------------------------ + # Route + # ------------------------------------------------------------------ + + def _route( + self, + logits_2d: jnp.ndarray, + expert_bias: Optional[jnp.ndarray], + ) -> Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]]: + """Run the fused router and optional aux-loss.""" + sparse_probs, routing_map = fused_topk_with_score_function( + logits_2d, + topk=self.num_experts_per_tok, + use_pre_softmax=self.use_pre_softmax, + num_groups=self.num_groups, + group_topk=self.group_topk, + scaling_factor=self.scaling_factor, + score_function=self.score_function, + expert_bias=expert_bias, + ) + sparse_probs = sparse_probs.astype(self.dtype) + + aux_loss = None + if self.aux_loss_coeff > 0.0: + # The score-for-aux kernel runs independently (no data dependency + # on the main kernel), so XLA can overlap them on the GPU. + aux_scores, aux_routing_map = fused_topk_with_score_function( + logits_2d, + topk=self.num_experts_per_tok, + score_function=self.score_function, + compute_aux_scores=True, + ) + aux_tokens_per_expert = jnp.sum( + aux_routing_map.astype(jnp.int32), axis=0 + ) + aux_loss = fused_moe_aux_loss( + aux_scores, + aux_tokens_per_expert, + topk=self.num_experts_per_tok, + coeff=self.aux_loss_coeff, + ) + + return sparse_probs, routing_map, aux_loss + + # ------------------------------------------------------------------ + # Dispatch + expert FFN + # ------------------------------------------------------------------ + + def _dispatch_and_expert_ffn( + self, + inputs_2d: jnp.ndarray, + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + params: dict, + num_experts_local: int, + roll_to_expert_id: Optional[int], + local_tokens_per_expert_count: int, + ) -> Tuple[jnp.ndarray, dict]: + """Dispatch tokens, run the three grouped GEMMs + activation, return expert outputs. + + Returns a tuple ``(expert_outputs, combine_state)`` where + ``combine_state`` carries the per-backend state needed to rebuild the + original token ordering in :meth:`_combine`. + """ + num_tokens = inputs_2d.shape[0] + topk = self.num_experts_per_tok + + if self.permutation_backend == "pure_jax": + selected_experts, routing_weights = _extract_topk_from_routing_map( + sparse_probs, routing_map, topk + ) + sorted_inputs, perm_state, group_sizes = mt_token_dispatch( + inputs_2d, + selected_experts, + num_experts=self.num_experts, + num_experts_per_tok=topk, + align_size=self.align_size, + roll_to_expert_id=roll_to_expert_id, + use_custom_sort_vjp=self.use_custom_sort_vjp, + ) + # Slice group_sizes to just this shard's experts. When not using + # EP, ``num_experts_local == self.num_experts`` so this is a no-op. + group_sizes = group_sizes[:local_tokens_per_expert_count] + # ``local_real_size = sum(group_sizes)`` is the number of permuted + # rows that actually correspond to tokens routed to this shard's + # experts. Used by the ring-EP caller to zero out garbage rows + # before combine. + combine_state = { + "backend": "pure_jax", + "perm_state": perm_state, + "routing_weights": routing_weights, + "local_real_size": jnp.sum(group_sizes), + } + else: # "triton" + num_out_tokens = num_tokens * topk + align_size_arg = self.align_size if self.align_size > 0 else None + ( + sorted_inputs, + _permuted_probs, + row_id_map, + pad_offsets, + group_sizes, + ) = token_dispatch( + inputs_2d, + routing_map, + num_out_tokens=num_out_tokens, + probs=sparse_probs, + align_size=align_size_arg, + ) + group_sizes = group_sizes[:local_tokens_per_expert_count] + combine_state = { + "backend": "triton", + "row_id_map": row_id_map, + "pad_offsets": pad_offsets, + "merging_probs": sparse_probs, + "group_sizes": group_sizes, + } + + # ------------------------------------------------------------------ + # Expert FFN: grouped GEMMs w0, w1 + activation + w_o. + # ------------------------------------------------------------------ + wi_0 = params["wi_0"] + wi_1 = params["wi_1"] + wo = params["wo"] + + # Each grouped_dense call gets its own quantizer_set with + # ``n_groups=num_experts_local``; this matches the shape of + # ``group_sizes`` passed in and keeps the quantizer FP8 meta correctly + # sized per shard. + q_set_w0 = self.generate_quantizer_set( + postfix="_w0", n_groups=num_experts_local + ) + q_set_w1 = self.generate_quantizer_set( + postfix="_w1", n_groups=num_experts_local + ) + q_set_wo = self.generate_quantizer_set( + postfix="_wo", n_groups=num_experts_local + ) + + # Cast kernels to the sort dtype when no FP8 quantization is active + # (mirrors DenseGeneral). + if q_set_w0 == noop_quantizer_set: + wi_0 = wi_0.astype(sorted_inputs.dtype) + if q_set_w1 == noop_quantizer_set: + wi_1 = wi_1.astype(sorted_inputs.dtype) + if q_set_wo == noop_quantizer_set: + wo = wo.astype(sorted_inputs.dtype) + + # ``grouped_dense`` accepts per-expert bias of shape (G, N); it adds + # ``bias[i]`` to the ``group_sizes[i]`` rows belonging to expert ``i`` + # in the permuted layout. + wi_0_bias = params.get("wi_0_bias") if self.use_bias else None + wi_1_bias = params.get("wi_1_bias") if self.use_bias else None + wo_bias = params.get("wo_bias") if self.use_bias else None + + layer_w0 = grouped_dense( + sorted_inputs, + wi_0, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_0_bias, + quantizer_set=q_set_w0, + ) + layer_w1 = grouped_dense( + sorted_inputs, + wi_1, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_1_bias, + quantizer_set=q_set_w1, + ) + + act_fn = _get_activation_fn(self.activation_type) + intermediate = act_fn(layer_w0) * layer_w1 + + expert_outputs = grouped_dense( + intermediate, + wo, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wo_bias, + quantizer_set=q_set_wo, + ) + + return expert_outputs, combine_state + + # ------------------------------------------------------------------ + # Combine + # ------------------------------------------------------------------ + + def _combine( + self, + expert_outputs: jnp.ndarray, + combine_state: dict, + batch_size: int, + sequence_length: int, + ) -> jnp.ndarray: + if combine_state["backend"] == "pure_jax": + return mt_token_combine( + expert_outputs, + combine_state["perm_state"], + combine_state["routing_weights"], + num_experts_per_tok=self.num_experts_per_tok, + batch_size=batch_size, + sequence_length=sequence_length, + use_custom_sort_vjp=self.use_custom_sort_vjp, + ) + # triton + out_2d = token_combine( + expert_outputs, + combine_state["row_id_map"], + merging_probs=combine_state["merging_probs"], + pad_offsets=combine_state["pad_offsets"], + ) + hidden_size = out_2d.shape[-1] + return out_2d.reshape(batch_size, sequence_length, hidden_size).astype( + self.dtype + ) diff --git a/transformer_engine/jax/mt_permutation.py b/transformer_engine/jax/mt_permutation.py new file mode 100644 index 0000000000..10882501ec --- /dev/null +++ b/transformer_engine/jax/mt_permutation.py @@ -0,0 +1,356 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Pure-JAX MoE Permutation API. + +This module provides a MaxText-style, pure-JAX implementation of MoE token +dispatch / combine as an alternative to the Triton-backed primitives in +``transformer_engine.jax.permutation``. Empirically this path has been faster +than the Triton kernels on several E2E workloads. + +The core design mirrors Maxtext's ``_mt_permute`` / ``_mt_unpermute`` in +``maxtext/src/maxtext/layers/moe.py``, with alignment-padding support ported +from `nvjax-svc-0/maxtext PR #36 `_ +so each expert's group size is a multiple of ``align_size`` (required for +quantized grouped GEMM whose recipe-specific alignment must divide +``align_size``). + +When ``align_size = 0`` padding is disabled (faster for the unquantized path); +when ``align_size > 0`` a static-size padding buffer of shape +``[num_experts * (align_size - 1)]`` is appended before the sort so the overall +shape is JIT-compatible. + +The public API is: + +* :func:`mt_token_dispatch` -- pure-JAX counterpart of ``token_dispatch``. +* :func:`mt_token_combine` -- pure-JAX counterpart of ``token_combine``. +* :class:`MTPermState` -- opaque state returned by ``mt_token_dispatch`` and + consumed by ``mt_token_combine``. +""" + +from typing import NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp + +__all__ = [ + "MTPermState", + "mt_token_dispatch", + "mt_token_combine", +] + + +# ============================================================================= +# Custom-VJP argsort-based gather (``_sort_activations_custom``) +# ============================================================================= +# +# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. +# Using a custom VJP lets the backward pass exploit that inverse instead of +# relying on the compiler to discover it from the scatter-style default +# gradient of a gather, which is typically less efficient. + + +@jax.custom_vjp +def _sort_activations_custom(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: + """Sort ``inputs`` along the leading dim by ``sort_indices``.""" + return inputs[sort_indices, ...] + + +def _sort_activations_custom_fwd( + inputs: jax.Array, sort_indices: jax.Array +) -> Tuple[jax.Array, jax.Array]: + return _sort_activations_custom(inputs, sort_indices), sort_indices + + +def _sort_activations_custom_bwd( + residuals: jax.Array, grads: jax.Array +) -> Tuple[jax.Array, None]: + sort_indices = residuals + # Inverse permutation: gather-by-argsort undoes the forward gather. + return _sort_activations_custom(grads, jnp.argsort(sort_indices)), None + + +_sort_activations_custom.defvjp(_sort_activations_custom_fwd, _sort_activations_custom_bwd) + + +def _sort_activations( + inputs: jax.Array, + sort_indices: jax.Array, + use_custom_vjp: bool, +) -> jax.Array: + """Sort activations by ``sort_indices``, optionally with the custom VJP.""" + assert inputs.shape[0] == sort_indices.shape[0], ( + f"inputs.shape[0]={inputs.shape[0]} must match" + f" sort_indices.shape[0]={sort_indices.shape[0]}" + ) + with jax.named_scope("mt_sort_activations"): + if use_custom_vjp: + return _sort_activations_custom(inputs, sort_indices) + return inputs[sort_indices, ...] + + +# ============================================================================= +# Permutation state carried from dispatch to combine +# ============================================================================= + + +class MTPermState(NamedTuple): + """Opaque state produced by :func:`mt_token_dispatch`. + + Attributes + ---------- + sorted_indices : jnp.ndarray + The argsort indices used in the forward sort. Needed to reverse the + permutation in :func:`mt_token_combine`. Shape + ``[num_real_tokens + padding_size]``. + num_real_tokens : int + Number of real (non-padding) permuted tokens, i.e. + ``batch_size * sequence_length * num_experts_per_tok``. Compile-time + constant. + padding_size : int + Number of alignment-padding tokens appended to the sort buffer. Equals + ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. + Compile-time constant. + """ + + sorted_indices: jax.Array + num_real_tokens: int + padding_size: int + + +# ============================================================================= +# Dispatch (permute) +# ============================================================================= + + +def mt_token_dispatch( + inputs: jnp.ndarray, + selected_experts: jnp.ndarray, + num_experts: int, + num_experts_per_tok: int, + align_size: int = 0, + roll_to_expert_id: Optional[int] = None, + use_custom_sort_vjp: bool = True, +) -> Tuple[jnp.ndarray, MTPermState, jnp.ndarray]: + """Pure-JAX MaxText-style token dispatch. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[num_tokens, hidden_size]`` (or + ``[batch, seq, hidden]``; it will be flattened). + selected_experts : jnp.ndarray + Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or + ``[batch, seq, num_experts_per_tok]``). Integer dtype. + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k. Must equal ``selected_experts.shape[-1]``. + align_size : int, default 0 + Alignment for each expert's group size. ``0`` disables padding; a value + ``> 0`` appends a static-size padding buffer so each resulting group + size is a multiple of ``align_size``. + roll_to_expert_id : Optional[int] + If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo + ``num_experts`` before the sort (ring-of-experts EP). The returned + ``group_sizes`` is rolled to match. + use_custom_sort_vjp : bool, default True + Whether to use the custom-VJP argsort gather for the sort. + + Returns + ------- + sorted_inputs : jnp.ndarray + Permuted tokens grouped by expert, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : MTPermState + State needed by :func:`mt_token_combine`. + group_sizes : jnp.ndarray + Token count per expert, shape ``[num_experts]``. Each entry is a + multiple of ``align_size`` when ``align_size > 0``. + """ + assert num_experts_per_tok == selected_experts.shape[-1], ( + f"num_experts_per_tok={num_experts_per_tok} must match" + f" selected_experts.shape[-1]={selected_experts.shape[-1]}" + ) + assert align_size >= 0, f"align_size must be >= 0, got {align_size}" + + hidden_size = inputs.shape[-1] + # Flatten token dims. + inputs_2d = inputs.reshape(-1, hidden_size) + num_tokens = inputs_2d.shape[0] + num_real_tokens = num_tokens * num_experts_per_tok + + flatten_selected_experts = jnp.ravel(selected_experts) + + if align_size > 0: + # Per-expert token count, and how many extra tokens each expert needs + # to become aligned to ``align_size``. Using + # ``(align - count % align) % align`` gives 0 (not ``align``) when + # already aligned, so we never exceed the per-expert slot capacity of + # ``align_size - 1``. + token_count_per_expert = jnp.bincount( + flatten_selected_experts, length=num_experts + ) + padding_tokens_required_per_expert = ( + (align_size - (token_count_per_expert % align_size)) % align_size + ) + + # Build a static-size padding buffer of shape + # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot + # of ``align_size - 1`` positions (worst-case padding, which occurs + # when ``token_count[i] % align_size == 1``). Within slot ``i``, + # positions ``[0, padding_needed)`` are assigned expert ``i`` and act + # as real padding; the rest are assigned to ``num_experts - 1`` as + # overflow placeholders that keep the buffer statically sized for JIT. + max_padding_per_expert = align_size - 1 + max_total_padding_size = num_experts * max_padding_per_expert + positions = jnp.arange(max_total_padding_size) + expert_for_pos = positions // max_padding_per_expert + offset_in_slot = positions % max_padding_per_expert + padding_needed = padding_tokens_required_per_expert[expert_for_pos] + flatten_padding_selected_experts = jnp.where( + offset_in_slot < padding_needed, + expert_for_pos, + num_experts - 1, + ) + + flatten_selected_experts = jnp.concatenate( + [flatten_selected_experts, flatten_padding_selected_experts], axis=0 + ) + + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + # Pad inputs with zeros so the sort operand shape matches the expanded + # selected-experts vector. + replicated_inputs_2d = jnp.pad( + replicated_inputs_2d, + pad_width=((0, max_total_padding_size), (0, 0)), + mode="constant", + constant_values=0.0, + ) + + sorted_inputs = _sort_activations( + replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp + ) + + # Compute ``group_sizes`` directly from counts rather than via + # ``bincount(flatten_selected_experts)``: the overflow placeholder + # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the + # alignment guarantee. Direct computation gives each expert exactly + # ``ceil(count / align) * align`` tokens. + group_sizes = token_count_per_expert + padding_tokens_required_per_expert + + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = max_total_padding_size + else: + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + sorted_inputs = _sort_activations( + replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp + ) + + group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = 0 + + perm_state = MTPermState( + sorted_indices=sorted_selected_experts, + num_real_tokens=num_real_tokens, + padding_size=padding_size, + ) + return sorted_inputs, perm_state, group_sizes + + +# ============================================================================= +# Combine (unpermute + weighted sum) +# ============================================================================= + + +def mt_token_combine( + expert_outputs: jnp.ndarray, + perm_state: MTPermState, + routing_weights: jnp.ndarray, + num_experts_per_tok: int, + batch_size: int, + sequence_length: int, + use_custom_sort_vjp: bool = True, +) -> jnp.ndarray: + """Pure-JAX MaxText-style token combine. + + Reverses the permutation performed by :func:`mt_token_dispatch`, strips + any alignment-padding rows appended during dispatch, and applies a + per-token weighted sum across the top-k experts. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the expert FFN, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : MTPermState + State returned by :func:`mt_token_dispatch`. + routing_weights : jnp.ndarray + Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` + (or broadcastable to it after a ``reshape``). + num_experts_per_tok : int + Top-k. + batch_size : int + Original batch size. + sequence_length : int + Original sequence length. + use_custom_sort_vjp : bool, default True + Whether to use the custom-VJP argsort gather for the unsort. + + Returns + ------- + output : jnp.ndarray + Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. + """ + # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes + # ``input[sorted_indices]``. + unsort_intermediate = _sort_activations( + expert_outputs, + jnp.argsort(perm_state.sorted_indices), + use_custom_sort_vjp, + ) + + # Strip alignment padding tokens appended during dispatch. After unsorting, + # the first ``num_real_tokens`` rows hold the real per-(token, top-k) + # outputs; any trailing rows are padding placeholders (zeros) and must be + # discarded before the reshape below. + if perm_state.padding_size > 0: + unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] + + hidden_size = unsort_intermediate.shape[-1] + reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) + reshaped_intermediate = jnp.reshape( + unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) + ) + + # Cast weights to match intermediate dtype (weighted sum happens in + # intermediate dtype; callers can upcast before calling if higher + # precision weight-sum is desired). + reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) + with jax.named_scope("mt_weight_sum"): + output = jnp.einsum( + "BKE,BK -> BE", + reshaped_intermediate, + reshaped_weights, + ) + return output.reshape(batch_size, sequence_length, hidden_size) From f453137c82b103ec302540a7259a9a0caa9e0d03 Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 21 Apr 2026 17:24:16 -0700 Subject: [PATCH 02/18] clean up any link to Maxtext. Permutation backends. clean up foward body single GPU vs. multi GPU Signed-off-by: tdophung --- transformer_engine/jax/flax/moe.py | 492 +++++++++-------------- transformer_engine/jax/mt_permutation.py | 356 ---------------- transformer_engine/jax/permutation.py | 336 +++++++++++++++- 3 files changed, 514 insertions(+), 670 deletions(-) delete mode 100644 transformer_engine/jax/mt_permutation.py diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index ddbe687771..6673ac1a71 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -6,11 +6,8 @@ This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer that wires together TE's fused router, a selectable token-dispatch backend -(pure-JAX MaxText-style or Triton), TE's ``grouped_dense``, and optional +(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and optional ring-of-experts Expert Parallelism. - -See ``plans/te_jax_moeblock_926b7994.plan.md`` for the full design rationale -and the mapping to Maxtext's ``RoutedMoE``. """ from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -21,12 +18,17 @@ from jax.sharding import PartitionSpec as P from ..dense import grouped_dense -from ..mt_permutation import mt_token_combine, mt_token_dispatch -from ..permutation import token_combine, token_dispatch +from ..permutation import ( + _routing_map_to_selected_experts, + token_combine, + token_dispatch, + unfused_token_combine, + unfused_token_dispatch, +) from ..quantize import noop_quantizer_set from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function from ..sharding import with_sharding_constraint_by_logical_axes -from .module import TransformerEngineBase +from .module import TransformerEngineBase, _convert_to_activation_function PRNGKey = Any Shape = Tuple[int, ...] @@ -38,57 +40,6 @@ __all__ = ["MoEBlock"] -# ============================================================================= -# Helpers -# ============================================================================= - - -_ACTIVATIONS = { - "silu": jax.nn.silu, - "swish": jax.nn.silu, - "gelu": jax.nn.gelu, - "relu": jax.nn.relu, - "identity": lambda x: x, - "linear": lambda x: x, -} - - -def _get_activation_fn(name: str) -> Callable: - key = name.lower() - if key not in _ACTIVATIONS: - raise ValueError( - f"Unsupported activation_type={name!r}; supported: {sorted(_ACTIVATIONS)}" - ) - return _ACTIVATIONS[key] - - -def _extract_topk_from_routing_map( - sparse_probs: jnp.ndarray, - routing_map: jnp.ndarray, - topk: int, -) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Convert TE's ``(sparse_probs, routing_map)`` to ``(selected_experts, weights)``. - - ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` - with exactly ``topk`` ``True`` positions per row. ``sparse_probs`` is the - same-shape float tensor whose non-zero entries are the routing weights. - - The per-token top-k expert IDs are recovered as the last ``topk`` indices - of ``argsort(routing_map)`` (``False < True``), and the corresponding - weights are gathered from ``sparse_probs`` along the expert axis. - - The within-row expert ordering does not have to match the router's - top-k ordering: :func:`mt_token_dispatch` and :func:`mt_token_combine` - only require that ``selected_experts`` and ``weights`` are consistent with - each other. - """ - # Cast to int32 so argsort has a well-defined ordering. (Ascending argsort - # on 0/1 puts the ``True`` positions last; we then slice the last ``topk``.) - selected_experts = jnp.argsort(routing_map.astype(jnp.int32), axis=-1)[:, -topk:] - weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) - return selected_experts, weights - - # ============================================================================= # MoEBlock # ============================================================================= @@ -102,11 +53,11 @@ class MoEBlock(TransformerEngineBase): two-layer FFN via grouped GEMMs, activation, token combine, and optional ring-of-experts expert parallelism. - The permutation step is pluggable: the default ``permutation_backend="pure_jax"`` - uses the MaxText-style argsort-based dispatch/combine in - :mod:`transformer_engine.jax.mt_permutation`, which empirically outperforms - the Triton kernels on several E2E workloads. ``permutation_backend="triton"`` - uses TE's ``token_dispatch`` / ``token_combine`` kernels. + The permutation step is pluggable via ``permutation_backend``: + ``"pure_jax"`` (default) uses the pure-JAX argsort-based + ``unfused_token_dispatch`` / ``unfused_token_combine`` in + :mod:`transformer_engine.jax.permutation`; ``"triton"`` uses TE's fused + ``token_dispatch`` / ``token_combine`` kernels. Parameters ---------- @@ -119,9 +70,9 @@ class MoEBlock(TransformerEngineBase): activation_type : str FFN activation applied to the gate projection. Paired with the up - projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. Supported: - ``"silu"``/``"swish"`` (default), ``"gelu"``, ``"relu"``, - ``"identity"``/``"linear"``. + projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. Resolved + via :func:`flax.linen.` (``"silu"``, ``"gelu"``, ``"relu"``, + ``"swish"``, ...) plus ``"linear"`` for identity. score_function : str or ScoreFunction ``"softmax"`` (default) or ``"sigmoid"`` for :func:`fused_topk_with_score_function`. @@ -135,8 +86,8 @@ class MoEBlock(TransformerEngineBase): Scaling factor applied to output probs. use_expert_bias : bool If ``True``, registers a learnable ``expert_bias`` parameter of shape - ``[num_experts]`` and passes it to the fused router. Only valid with - ``score_function="sigmoid"`` (DeepSeek V3 loss-free load balancing). + ``[num_experts]`` and passes it to the fused router. The router + primitive validates that this is paired with ``score_function="sigmoid"``. aux_loss_coeff : float If ``> 0``, compute and return the MoE auxiliary load-balancing loss scalar via :func:`fused_moe_aux_loss`. ``0`` disables. @@ -171,21 +122,18 @@ class MoEBlock(TransformerEngineBase): the end of the forward pass). permutation_backend : str - ``"pure_jax"`` (default; faster on many E2E workloads) or ``"triton"``. + ``"pure_jax"`` (default) or ``"triton"``. align_size : int Alignment for per-expert group sizes after padding. ``0`` disables padding (faster for the unquantized path). ``>0`` is required for quantized TE grouped GEMM whose recipe-specific alignment must divide - ``align_size``. Passed through to both permutation backends. - use_custom_sort_vjp : bool - Only used when ``permutation_backend="pure_jax"``. If ``True``, uses - a custom VJP for the argsort-based gather (faster in most cases). + ``align_size``. dtype : jnp.dtype Compute and parameter dtype. kernel_init : Initializer - Initializer for all kernels. Defaults to ``variance_scaling(1.0, - 'fan_in', 'truncated_normal')``. + Initializer for all kernels (gate + per-expert FFN). Defaults to + ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax convention). use_bias : bool If ``True``, registers per-expert FFN biases ``wi_0_bias``, ``wi_1_bias``, ``wo_bias``. @@ -222,7 +170,6 @@ class MoEBlock(TransformerEngineBase): # Permutation permutation_backend: str = "pure_jax" align_size: int = 0 - use_custom_sort_vjp: bool = True # Dtypes / init / misc dtype: DType = jnp.float32 @@ -245,20 +192,6 @@ def __post_init__(self): "permutation_backend must be 'pure_jax' or 'triton'," f" got {self.permutation_backend!r}" ) - if self.use_expert_bias: - # ``fused_topk_with_score_function`` only accepts ``expert_bias`` - # under the sigmoid score function. Raise early to surface the - # misconfiguration instead of failing deep inside the kernel. - score_func = ( - self.score_function.name.lower() - if isinstance(self.score_function, ScoreFunction) - else str(self.score_function).lower() - ) - if score_func != "sigmoid": - raise ValueError( - "use_expert_bias=True requires score_function='sigmoid';" - f" got {self.score_function!r}." - ) super().__post_init__() # ------------------------------------------------------------------ @@ -330,19 +263,13 @@ def _make_params(self, hidden_size: int): # ------------------------------------------------------------------ @nn.compact - def __call__( - self, - inputs: Array, - deterministic: bool = True, - ) -> Tuple[Array, Optional[Array]]: + def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: """Run the MoE forward pass. Parameters ---------- inputs : jnp.ndarray Input tensor of shape ``[batch, sequence, hidden]``. - deterministic : bool - Reserved for future dropout-based routing; currently unused. Returns ------- @@ -352,24 +279,39 @@ def __call__( Scalar auxiliary load-balancing loss when ``aux_loss_coeff > 0``, else ``None``. """ - del deterministic # unused for now - assert inputs.ndim == 3, ( f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" ) inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) - batch_size, sequence_length, hidden_size = inputs.shape + _, _, hidden_size = inputs.shape params = self._make_params(hidden_size) - # Gate projection runs OUTSIDE the EP shard_map (mirroring Maxtext), - # so that each EP shard projects its own local slice of tokens and we - # later all-gather only the logits, not the full inputs. + # Gate runs OUTSIDE the EP shard_map below, so each EP shard projects + # its own local slice of tokens and we later all-gather only the + # smaller logits tensor instead of the full inputs. gate_logits = self._gate(inputs, params["gate_kernel"]) - if self.expert_parallelism_axis is not None: - return self._forward_ring_ep(inputs, gate_logits, params) - return self._forward_single_shard(inputs, gate_logits, params) + if self.expert_parallelism_axis is None: + # No EP: each primitive's own ``custom_partitioning`` rule handles + # DP / FSDP / TP across the mesh - no shard_map needed. + output, aux_loss = self._forward_body( + inputs, + gate_logits, + params, + num_experts_local=self.num_experts, + roll_to_expert_id=None, + ) + else: + # Ring-EP: ``_forward_body`` is wrapped in a shard_map that + # orchestrates the cross-primitive collectives (all_gather inputs + # / logits before, psum_scatter output after) which per-primitive + # ``custom_partitioning`` cannot express on its own. + output, aux_loss = self._forward_ring_ep(inputs, gate_logits, params) + + if self.aux_loss_coeff <= 0.0: + aux_loss = None + return output, aux_loss # ------------------------------------------------------------------ # Gate @@ -379,26 +321,34 @@ def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: """Linear gate projection ``inputs @ gate_kernel``. Kept as a plain matmul (not ``DenseGeneral``) so it integrates cleanly - with the EP shard_map below: the gate matmul runs in the outer - (pre-shard_map) scope and its output is all-gathered along the EP axis - inside the shard_map. + with the EP shard_map: the gate matmul runs in the outer (pre-shard_map) + scope and its output is all-gathered along the EP axis inside. """ # Cast kernel to input dtype outside FP8 scope (gate is typically BF16/FP32). kernel = gate_kernel.astype(inputs.dtype) return jnp.einsum("bsh,he->bse", inputs, kernel) # ------------------------------------------------------------------ - # Single-shard (no EP) forward + # Forward body (shared between no-EP and ring-EP paths) # ------------------------------------------------------------------ - def _forward_single_shard( + def _forward_body( self, inputs: jnp.ndarray, gate_logits: jnp.ndarray, params: dict, + num_experts_local: int, + roll_to_expert_id: Optional[int], ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - batch_size, sequence_length, hidden_size = inputs.shape + """Routing + dispatch + per-expert FFN + combine. + Used both bare (no EP) and inside the ring-EP shard_map. In the + ring-EP case ``inputs`` and ``gate_logits`` are the post-all_gather + global tensors, ``num_experts_local == num_experts // num_ep``, and + ``roll_to_expert_id`` is the offset that brings this shard's experts + into slots ``[0, num_experts_local)``. + """ + batch_size, sequence_length, hidden_size = inputs.shape inputs_2d = inputs.reshape(-1, hidden_size) logits_2d = gate_logits.reshape(-1, self.num_experts) @@ -406,16 +356,48 @@ def _forward_single_shard( logits_2d, params.get("expert_bias") ) + if roll_to_expert_id is not None: + # Rotate expert columns so this shard's experts come first. + routing_map = jnp.roll(routing_map, -roll_to_expert_id, axis=-1) + sparse_probs = jnp.roll(sparse_probs, -roll_to_expert_id, axis=-1) + if self.permutation_backend == "triton": + # Triton path: zero out remote-expert columns so the fused + # ``token_dispatch`` never writes tokens routed off-shard. + # The pure-JAX path zeroes garbage *output* rows below + # instead, since masking the routing_map directly would + # break the argsort-based permutation. + local_mask = ( + jnp.arange(self.num_experts) < num_experts_local + ) + routing_map = routing_map * local_mask + sparse_probs = sparse_probs * local_mask.astype(sparse_probs.dtype) + expert_outputs, combine_state = self._dispatch_and_expert_ffn( inputs_2d, sparse_probs, routing_map, params, - num_experts_local=self.num_experts, - roll_to_expert_id=None, - local_tokens_per_expert_count=self.num_experts, + num_experts_local=num_experts_local, + # The roll is already baked into ``routing_map``/``sparse_probs`` + # above, so the unfused dispatch must not roll again. + roll_to_expert_id=0 if roll_to_expert_id is not None else None, ) + if ( + roll_to_expert_id is not None + and self.permutation_backend == "pure_jax" + ): + # Zero the rows of ``expert_outputs`` past the real local-expert + # token count: ``grouped_dense`` leaves them as garbage because + # ``group_sizes`` was truncated to the local slice. Without this + # the unsort + weighted-sum in combine would mix garbage into + # every token's output (mirrors Maxtext's moe.py). + real_mask = ( + jnp.arange(expert_outputs.shape[0]) + < combine_state["local_real_size"] + ) + expert_outputs = jnp.where(real_mask[:, None], expert_outputs, 0) + output = self._combine( expert_outputs, combine_state, @@ -434,7 +416,7 @@ def _forward_single_shard( return output, aux_loss # ------------------------------------------------------------------ - # Ring-of-Experts EP forward + # Ring-of-Experts EP wrapper # ------------------------------------------------------------------ def _forward_ring_ep( @@ -442,22 +424,16 @@ def _forward_ring_ep( inputs: jnp.ndarray, gate_logits: jnp.ndarray, params: dict, - ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Wrap the dispatch / FFN / combine pipeline in a ring-of-experts - ``shard_map``. - - Inside the shard_map each EP shard: - 1. ``all_gather`` s the inputs and logits along the EP axis so it - sees every token globally. - 2. Routes with ``roll_to_expert_id = num_experts_per_shard * shard_id`` - so its local experts are in slots ``[0, num_experts_per_shard)``. - 3. Dispatches tokens, slicing ``group_sizes`` to the first - ``num_experts_per_shard`` entries (the rest correspond to remote - experts and should be zero after the roll/mask). - 4. Runs the per-expert FFN on its local expert slice of - ``wi_0`` / ``wi_1`` / ``wo``. - 5. Combines at the expanded-batch shape ``[B * num_ep, S, H]`` then - ``psum_scatter`` s along the EP axis to return the local slice. + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Wrap :meth:`_forward_body` in a ring-of-experts ``shard_map``. + + For each EP shard the wrapper: + 1. ``all_gather`` s the local inputs / logits / expert_bias along + the EP axis so the routing sees every token globally. + 2. Calls ``_forward_body`` with ``roll_to_expert_id = + num_experts_per_shard * shard_id`` and the EP-local weight slice. + 3. ``psum_scatter`` s the resulting ``[B*num_ep, S, H]`` output back + to the EP-sharded ``[B, S, H]`` layout. """ from jax.experimental.shard_map import shard_map @@ -474,201 +450,94 @@ def _forward_ring_ep( ) num_experts_per_shard = self.num_experts // num_ep - # in_specs / out_specs use PartitionSpec over the EP axis for inputs/ - # outputs (leading batch dim is split across EP) and ``P("exp", ...)`` - # for the expert weights, where we require the user's logical axis - # rules to map ``"exp"`` to the EP mesh axis. The expert bias is - # similarly sharded along the expert axis. - inputs_spec = P(ep_axis, None, None) - logits_spec = P(ep_axis, None, None) - wi_spec = P(ep_axis, None, None) - wo_spec = P(ep_axis, None, None) - output_spec = P(ep_axis, None, None) - scalar_spec = P() - bias_1d_spec = P(ep_axis) - bias_2d_spec = P(ep_axis, None) - - expert_bias_value = params.get("expert_bias") - wi_0_bias_value = params.get("wi_0_bias") - wi_1_bias_value = params.get("wi_1_bias") - wo_bias_value = params.get("wo_bias") - - in_specs = [ - inputs_spec, - logits_spec, - wi_spec, - wi_spec, - wo_spec, - ] - captured = [ - inputs, - gate_logits, - params["wi_0"], - params["wi_1"], - params["wo"], - ] - if expert_bias_value is not None: - in_specs.append(bias_1d_spec) - captured.append(expert_bias_value) - if wi_0_bias_value is not None: - in_specs.extend([bias_2d_spec, bias_2d_spec, bias_2d_spec]) - captured.extend([wi_0_bias_value, wi_1_bias_value, wo_bias_value]) - - out_specs = (output_spec, scalar_spec) - - use_expert_bias = expert_bias_value is not None - use_bias = wi_0_bias_value is not None - - def _ring_fn(*args): - idx = 0 - local_inputs = args[idx]; idx += 1 - local_gate_logits = args[idx]; idx += 1 - local_wi_0 = args[idx]; idx += 1 - local_wi_1 = args[idx]; idx += 1 - local_wo = args[idx]; idx += 1 - local_expert_bias = None - if use_expert_bias: - local_expert_bias = args[idx]; idx += 1 - local_wi_0_bias = local_wi_1_bias = local_wo_bias = None - if use_bias: - local_wi_0_bias = args[idx]; idx += 1 - local_wi_1_bias = args[idx]; idx += 1 - local_wo_bias = args[idx]; idx += 1 - + # Pack everything that crosses the shard_map boundary into a dict + # pytree. shard_map fully supports pytrees: ``in_specs`` must + # structurally match ``captured``, and we build them in lockstep so + # adding/removing an optional bias is a single ``dict[name] = ...``. + captured: dict = { + "inputs": inputs, + "gate_logits": gate_logits, + "wi_0": params["wi_0"], + "wi_1": params["wi_1"], + "wo": params["wo"], + } + in_specs: dict = { + "inputs": P(ep_axis, None, None), + "gate_logits": P(ep_axis, None, None), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + } + if "expert_bias" in params: + captured["expert_bias"] = params["expert_bias"] + in_specs["expert_bias"] = P(ep_axis) + if "wi_0_bias" in params: + for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): + captured[name] = params[name] + in_specs[name] = P(ep_axis, None) + + def _ring_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: shard_id = jax.lax.axis_index(ep_axis) - # All-gather inputs and logits along the EP axis so each shard - # sees the global tokens. gathered_inputs = jax.lax.all_gather( - local_inputs, axis_name=ep_axis, tiled=True + local["inputs"], axis_name=ep_axis, tiled=True ) gathered_logits = jax.lax.all_gather( - local_gate_logits, axis_name=ep_axis, tiled=True - ) - - # If the user also sharded by EP on the expert_bias, ``local_expert_bias`` - # is already the local slice; the router operates over the full - # expert axis, so all-gather to reconstruct. - global_expert_bias = None - if local_expert_bias is not None: - global_expert_bias = jax.lax.all_gather( - local_expert_bias, axis_name=ep_axis, tiled=True - ) - - batch_size = gathered_inputs.shape[0] - sequence_length = gathered_inputs.shape[1] - hidden_size = gathered_inputs.shape[2] - - inputs_2d = gathered_inputs.reshape(-1, hidden_size) - logits_2d = gathered_logits.reshape(-1, self.num_experts) - - sparse_probs, routing_map, aux_loss = self._route( - logits_2d, global_expert_bias + local["gate_logits"], axis_name=ep_axis, tiled=True ) - # Ring-of-experts roll: after rolling expert columns by - # ``-num_experts_per_shard * shard_id``, this shard's experts - # occupy slots ``[0, num_experts_per_shard)`` in ``routing_map`` - # and ``sparse_probs``. - # - # For the Triton backend we additionally mask the remote-expert - # columns to False/0 so ``token_dispatch`` never writes those - # tokens into the local permuted buffer. For the pure-JAX backend - # we leave the routing_map untouched (mirroring Maxtext): the roll - # passed to ``mt_token_dispatch`` sorts remote-expert tokens past - # the local slots, and we later zero out those garbage rows of - # ``expert_outputs`` before the combine. - roll = num_experts_per_shard * shard_id - routing_map = jnp.roll(routing_map, -roll, axis=-1) - sparse_probs = jnp.roll(sparse_probs, -roll, axis=-1) - if self.permutation_backend == "triton": - local_expert_mask = ( - jnp.arange(self.num_experts) < num_experts_per_shard - ) - routing_map = routing_map * local_expert_mask[None, :] - sparse_probs = sparse_probs * local_expert_mask[None, :].astype( - sparse_probs.dtype - ) - - # Build a reduced-expert view of the weights: the outer ``shard_map`` - # has already sliced the leading expert axis down to - # ``num_experts_per_shard`` per shard. Pass it through as-is to the - # dispatch / expert-FFN path with ``num_experts_local = num_experts_per_shard``. - local_params = { - "gate_kernel": None, # unused past gate - "wi_0": local_wi_0, - "wi_1": local_wi_1, - "wo": local_wo, + local_params: dict = { + "wi_0": local["wi_0"], + "wi_1": local["wi_1"], + "wo": local["wo"], } - if use_bias: - local_params["wi_0_bias"] = local_wi_0_bias - local_params["wi_1_bias"] = local_wi_1_bias - local_params["wo_bias"] = local_wo_bias - - expert_outputs, combine_state = self._dispatch_and_expert_ffn( - inputs_2d, - sparse_probs, - routing_map, + if "expert_bias" in local: + # The router operates over the full expert axis, so the + # EP-sharded bias must be all-gathered. + local_params["expert_bias"] = jax.lax.all_gather( + local["expert_bias"], axis_name=ep_axis, tiled=True + ) + if "wi_0_bias" in local: + local_params["wi_0_bias"] = local["wi_0_bias"] + local_params["wi_1_bias"] = local["wi_1_bias"] + local_params["wo_bias"] = local["wo_bias"] + + output, aux_loss = self._forward_body( + gathered_inputs, + gathered_logits, local_params, num_experts_local=num_experts_per_shard, - roll_to_expert_id=0, # roll is already applied on routing_map - local_tokens_per_expert_count=num_experts_per_shard, - ) - - # For the pure-JAX backend in ring-EP mode, zero out expert-output - # rows that correspond to remote experts (which ``grouped_dense`` - # leaves as garbage since ``group_sizes`` was truncated to the - # local slice). Without this, the unsort + weighted-sum in - # combine would mix garbage into every token's output. Matches - # ``moe.py:1731-1733`` in Maxtext. - if self.permutation_backend == "pure_jax": - real_mask = ( - jnp.arange(expert_outputs.shape[0]) - < combine_state["local_real_size"] - ) - expert_outputs = jnp.where( - real_mask[:, None], expert_outputs, 0 - ) - - output = self._combine( - expert_outputs, - combine_state, - batch_size=batch_size, - sequence_length=sequence_length, + roll_to_expert_id=num_experts_per_shard * shard_id, ) - if self.tensor_parallelism_axis is not None: - output = jax.lax.psum_scatter( - output, - self.tensor_parallelism_axis, - scatter_dimension=2, - tiled=True, - ) - - # ``output`` is [B*num_ep, S, H] (global batch after all-gather); + # ``output`` is [B*num_ep, S, H] (global batch after all_gather); # psum_scatter along EP returns the local [B, S, H] slice. output = jax.lax.psum_scatter( - output, - ep_axis, - scatter_dimension=0, - tiled=True, + output, ep_axis, scatter_dimension=0, tiled=True ) + # ``out_specs`` must match the returned pytree structurally, so + # always emit a real scalar for aux_loss; the outer ``__call__`` + # re-strips it to None when ``aux_loss_coeff <= 0``. if aux_loss is None: aux_loss = jnp.zeros((), dtype=self.dtype) return output, aux_loss - output, aux_loss = shard_map( + # ``check_rep=False`` disables shard_map's invariant that any output + # declared as ``P()`` is replicated across ``ep_axis``. We use + # ``axis_index(ep_axis)`` inside ``_ring_fn`` to compute a per-shard + # roll, which makes the body genuinely non-replicated and would + # otherwise (correctly) fail the check. The ``psum_scatter`` of the + # output already produces the right cross-shard semantics; this is + # the standard JAX escape hatch when collectives + per-shard logic + # coexist. + return shard_map( _ring_fn, mesh=mesh, - in_specs=tuple(in_specs), - out_specs=out_specs, + in_specs=in_specs, + out_specs=(P(ep_axis, None, None), P()), check_rep=False, - )(*captured) - - if self.aux_loss_coeff <= 0.0: - aux_loss = None - return output, aux_loss + )(captured) # ------------------------------------------------------------------ # Route @@ -726,7 +595,6 @@ def _dispatch_and_expert_ffn( params: dict, num_experts_local: int, roll_to_expert_id: Optional[int], - local_tokens_per_expert_count: int, ) -> Tuple[jnp.ndarray, dict]: """Dispatch tokens, run the three grouped GEMMs + activation, return expert outputs. @@ -738,21 +606,20 @@ def _dispatch_and_expert_ffn( topk = self.num_experts_per_tok if self.permutation_backend == "pure_jax": - selected_experts, routing_weights = _extract_topk_from_routing_map( + selected_experts, routing_weights = _routing_map_to_selected_experts( sparse_probs, routing_map, topk ) - sorted_inputs, perm_state, group_sizes = mt_token_dispatch( + sorted_inputs, perm_state, group_sizes = unfused_token_dispatch( inputs_2d, selected_experts, num_experts=self.num_experts, num_experts_per_tok=topk, align_size=self.align_size, roll_to_expert_id=roll_to_expert_id, - use_custom_sort_vjp=self.use_custom_sort_vjp, ) # Slice group_sizes to just this shard's experts. When not using # EP, ``num_experts_local == self.num_experts`` so this is a no-op. - group_sizes = group_sizes[:local_tokens_per_expert_count] + group_sizes = group_sizes[:num_experts_local] # ``local_real_size = sum(group_sizes)`` is the number of permuted # rows that actually correspond to tokens routed to this shard's # experts. Used by the ring-EP caller to zero out garbage rows @@ -779,7 +646,7 @@ def _dispatch_and_expert_ffn( probs=sparse_probs, align_size=align_size_arg, ) - group_sizes = group_sizes[:local_tokens_per_expert_count] + group_sizes = group_sizes[:num_experts_local] combine_state = { "backend": "triton", "row_id_map": row_id_map, @@ -842,7 +709,7 @@ def _dispatch_and_expert_ffn( quantizer_set=q_set_w1, ) - act_fn = _get_activation_fn(self.activation_type) + act_fn = _convert_to_activation_function(self.activation_type) intermediate = act_fn(layer_w0) * layer_w1 expert_outputs = grouped_dense( @@ -868,14 +735,13 @@ def _combine( sequence_length: int, ) -> jnp.ndarray: if combine_state["backend"] == "pure_jax": - return mt_token_combine( + return unfused_token_combine( expert_outputs, combine_state["perm_state"], combine_state["routing_weights"], num_experts_per_tok=self.num_experts_per_tok, batch_size=batch_size, sequence_length=sequence_length, - use_custom_sort_vjp=self.use_custom_sort_vjp, ) # triton out_2d = token_combine( diff --git a/transformer_engine/jax/mt_permutation.py b/transformer_engine/jax/mt_permutation.py deleted file mode 100644 index 10882501ec..0000000000 --- a/transformer_engine/jax/mt_permutation.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Pure-JAX MoE Permutation API. - -This module provides a MaxText-style, pure-JAX implementation of MoE token -dispatch / combine as an alternative to the Triton-backed primitives in -``transformer_engine.jax.permutation``. Empirically this path has been faster -than the Triton kernels on several E2E workloads. - -The core design mirrors Maxtext's ``_mt_permute`` / ``_mt_unpermute`` in -``maxtext/src/maxtext/layers/moe.py``, with alignment-padding support ported -from `nvjax-svc-0/maxtext PR #36 `_ -so each expert's group size is a multiple of ``align_size`` (required for -quantized grouped GEMM whose recipe-specific alignment must divide -``align_size``). - -When ``align_size = 0`` padding is disabled (faster for the unquantized path); -when ``align_size > 0`` a static-size padding buffer of shape -``[num_experts * (align_size - 1)]`` is appended before the sort so the overall -shape is JIT-compatible. - -The public API is: - -* :func:`mt_token_dispatch` -- pure-JAX counterpart of ``token_dispatch``. -* :func:`mt_token_combine` -- pure-JAX counterpart of ``token_combine``. -* :class:`MTPermState` -- opaque state returned by ``mt_token_dispatch`` and - consumed by ``mt_token_combine``. -""" - -from typing import NamedTuple, Optional, Tuple - -import jax -import jax.numpy as jnp - -__all__ = [ - "MTPermState", - "mt_token_dispatch", - "mt_token_combine", -] - - -# ============================================================================= -# Custom-VJP argsort-based gather (``_sort_activations_custom``) -# ============================================================================= -# -# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. -# Using a custom VJP lets the backward pass exploit that inverse instead of -# relying on the compiler to discover it from the scatter-style default -# gradient of a gather, which is typically less efficient. - - -@jax.custom_vjp -def _sort_activations_custom(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: - """Sort ``inputs`` along the leading dim by ``sort_indices``.""" - return inputs[sort_indices, ...] - - -def _sort_activations_custom_fwd( - inputs: jax.Array, sort_indices: jax.Array -) -> Tuple[jax.Array, jax.Array]: - return _sort_activations_custom(inputs, sort_indices), sort_indices - - -def _sort_activations_custom_bwd( - residuals: jax.Array, grads: jax.Array -) -> Tuple[jax.Array, None]: - sort_indices = residuals - # Inverse permutation: gather-by-argsort undoes the forward gather. - return _sort_activations_custom(grads, jnp.argsort(sort_indices)), None - - -_sort_activations_custom.defvjp(_sort_activations_custom_fwd, _sort_activations_custom_bwd) - - -def _sort_activations( - inputs: jax.Array, - sort_indices: jax.Array, - use_custom_vjp: bool, -) -> jax.Array: - """Sort activations by ``sort_indices``, optionally with the custom VJP.""" - assert inputs.shape[0] == sort_indices.shape[0], ( - f"inputs.shape[0]={inputs.shape[0]} must match" - f" sort_indices.shape[0]={sort_indices.shape[0]}" - ) - with jax.named_scope("mt_sort_activations"): - if use_custom_vjp: - return _sort_activations_custom(inputs, sort_indices) - return inputs[sort_indices, ...] - - -# ============================================================================= -# Permutation state carried from dispatch to combine -# ============================================================================= - - -class MTPermState(NamedTuple): - """Opaque state produced by :func:`mt_token_dispatch`. - - Attributes - ---------- - sorted_indices : jnp.ndarray - The argsort indices used in the forward sort. Needed to reverse the - permutation in :func:`mt_token_combine`. Shape - ``[num_real_tokens + padding_size]``. - num_real_tokens : int - Number of real (non-padding) permuted tokens, i.e. - ``batch_size * sequence_length * num_experts_per_tok``. Compile-time - constant. - padding_size : int - Number of alignment-padding tokens appended to the sort buffer. Equals - ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. - Compile-time constant. - """ - - sorted_indices: jax.Array - num_real_tokens: int - padding_size: int - - -# ============================================================================= -# Dispatch (permute) -# ============================================================================= - - -def mt_token_dispatch( - inputs: jnp.ndarray, - selected_experts: jnp.ndarray, - num_experts: int, - num_experts_per_tok: int, - align_size: int = 0, - roll_to_expert_id: Optional[int] = None, - use_custom_sort_vjp: bool = True, -) -> Tuple[jnp.ndarray, MTPermState, jnp.ndarray]: - """Pure-JAX MaxText-style token dispatch. - - Parameters - ---------- - inputs : jnp.ndarray - Input tensor of shape ``[num_tokens, hidden_size]`` (or - ``[batch, seq, hidden]``; it will be flattened). - selected_experts : jnp.ndarray - Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or - ``[batch, seq, num_experts_per_tok]``). Integer dtype. - num_experts : int - Total number of experts. - num_experts_per_tok : int - Top-k. Must equal ``selected_experts.shape[-1]``. - align_size : int, default 0 - Alignment for each expert's group size. ``0`` disables padding; a value - ``> 0`` appends a static-size padding buffer so each resulting group - size is a multiple of ``align_size``. - roll_to_expert_id : Optional[int] - If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo - ``num_experts`` before the sort (ring-of-experts EP). The returned - ``group_sizes`` is rolled to match. - use_custom_sort_vjp : bool, default True - Whether to use the custom-VJP argsort gather for the sort. - - Returns - ------- - sorted_inputs : jnp.ndarray - Permuted tokens grouped by expert, shape - ``[num_real_tokens + padding_size, hidden_size]``. - perm_state : MTPermState - State needed by :func:`mt_token_combine`. - group_sizes : jnp.ndarray - Token count per expert, shape ``[num_experts]``. Each entry is a - multiple of ``align_size`` when ``align_size > 0``. - """ - assert num_experts_per_tok == selected_experts.shape[-1], ( - f"num_experts_per_tok={num_experts_per_tok} must match" - f" selected_experts.shape[-1]={selected_experts.shape[-1]}" - ) - assert align_size >= 0, f"align_size must be >= 0, got {align_size}" - - hidden_size = inputs.shape[-1] - # Flatten token dims. - inputs_2d = inputs.reshape(-1, hidden_size) - num_tokens = inputs_2d.shape[0] - num_real_tokens = num_tokens * num_experts_per_tok - - flatten_selected_experts = jnp.ravel(selected_experts) - - if align_size > 0: - # Per-expert token count, and how many extra tokens each expert needs - # to become aligned to ``align_size``. Using - # ``(align - count % align) % align`` gives 0 (not ``align``) when - # already aligned, so we never exceed the per-expert slot capacity of - # ``align_size - 1``. - token_count_per_expert = jnp.bincount( - flatten_selected_experts, length=num_experts - ) - padding_tokens_required_per_expert = ( - (align_size - (token_count_per_expert % align_size)) % align_size - ) - - # Build a static-size padding buffer of shape - # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot - # of ``align_size - 1`` positions (worst-case padding, which occurs - # when ``token_count[i] % align_size == 1``). Within slot ``i``, - # positions ``[0, padding_needed)`` are assigned expert ``i`` and act - # as real padding; the rest are assigned to ``num_experts - 1`` as - # overflow placeholders that keep the buffer statically sized for JIT. - max_padding_per_expert = align_size - 1 - max_total_padding_size = num_experts * max_padding_per_expert - positions = jnp.arange(max_total_padding_size) - expert_for_pos = positions // max_padding_per_expert - offset_in_slot = positions % max_padding_per_expert - padding_needed = padding_tokens_required_per_expert[expert_for_pos] - flatten_padding_selected_experts = jnp.where( - offset_in_slot < padding_needed, - expert_for_pos, - num_experts - 1, - ) - - flatten_selected_experts = jnp.concatenate( - [flatten_selected_experts, flatten_padding_selected_experts], axis=0 - ) - - if roll_to_expert_id is not None: - flatten_selected_experts = ( - flatten_selected_experts - roll_to_expert_id - ) % num_experts - - sorted_selected_experts = jnp.argsort(flatten_selected_experts) - - replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) - # Pad inputs with zeros so the sort operand shape matches the expanded - # selected-experts vector. - replicated_inputs_2d = jnp.pad( - replicated_inputs_2d, - pad_width=((0, max_total_padding_size), (0, 0)), - mode="constant", - constant_values=0.0, - ) - - sorted_inputs = _sort_activations( - replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp - ) - - # Compute ``group_sizes`` directly from counts rather than via - # ``bincount(flatten_selected_experts)``: the overflow placeholder - # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the - # alignment guarantee. Direct computation gives each expert exactly - # ``ceil(count / align) * align`` tokens. - group_sizes = token_count_per_expert + padding_tokens_required_per_expert - - if roll_to_expert_id is not None: - group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) - - padding_size = max_total_padding_size - else: - if roll_to_expert_id is not None: - flatten_selected_experts = ( - flatten_selected_experts - roll_to_expert_id - ) % num_experts - - sorted_selected_experts = jnp.argsort(flatten_selected_experts) - - replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) - sorted_inputs = _sort_activations( - replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp - ) - - group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) - if roll_to_expert_id is not None: - group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) - - padding_size = 0 - - perm_state = MTPermState( - sorted_indices=sorted_selected_experts, - num_real_tokens=num_real_tokens, - padding_size=padding_size, - ) - return sorted_inputs, perm_state, group_sizes - - -# ============================================================================= -# Combine (unpermute + weighted sum) -# ============================================================================= - - -def mt_token_combine( - expert_outputs: jnp.ndarray, - perm_state: MTPermState, - routing_weights: jnp.ndarray, - num_experts_per_tok: int, - batch_size: int, - sequence_length: int, - use_custom_sort_vjp: bool = True, -) -> jnp.ndarray: - """Pure-JAX MaxText-style token combine. - - Reverses the permutation performed by :func:`mt_token_dispatch`, strips - any alignment-padding rows appended during dispatch, and applies a - per-token weighted sum across the top-k experts. - - Parameters - ---------- - expert_outputs : jnp.ndarray - Output of the expert FFN, shape - ``[num_real_tokens + padding_size, hidden_size]``. - perm_state : MTPermState - State returned by :func:`mt_token_dispatch`. - routing_weights : jnp.ndarray - Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` - (or broadcastable to it after a ``reshape``). - num_experts_per_tok : int - Top-k. - batch_size : int - Original batch size. - sequence_length : int - Original sequence length. - use_custom_sort_vjp : bool, default True - Whether to use the custom-VJP argsort gather for the unsort. - - Returns - ------- - output : jnp.ndarray - Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. - """ - # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes - # ``input[sorted_indices]``. - unsort_intermediate = _sort_activations( - expert_outputs, - jnp.argsort(perm_state.sorted_indices), - use_custom_sort_vjp, - ) - - # Strip alignment padding tokens appended during dispatch. After unsorting, - # the first ``num_real_tokens`` rows hold the real per-(token, top-k) - # outputs; any trailing rows are padding placeholders (zeros) and must be - # discarded before the reshape below. - if perm_state.padding_size > 0: - unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] - - hidden_size = unsort_intermediate.shape[-1] - reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) - reshaped_intermediate = jnp.reshape( - unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) - ) - - # Cast weights to match intermediate dtype (weighted sum happens in - # intermediate dtype; callers can upcast before calling if higher - # precision weight-sum is desired). - reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) - with jax.named_scope("mt_weight_sum"): - output = jnp.einsum( - "BKE,BK -> BE", - reshaped_intermediate, - reshaped_weights, - ) - return output.reshape(batch_size, sequence_length, hidden_size) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 81972aac0f..1a492ba186 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -7,6 +7,17 @@ This module provides high-level token dispatch and combine operations for Mixture of Experts (MoE) models with proper automatic differentiation support. +Two backends are offered: + +* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the + Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``. +* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` - + uses only ``jnp.argsort`` + gather and is therefore compiled as plain XLA. + +Both backends support optional alignment padding (``align_size > 0``) so each +expert's group size is a multiple of ``align_size``, which is required for +quantized grouped GEMMs. + Token Dispatch (Permute): - Forward: Permute tokens according to routing map (scatter to experts) - Backward: Unpermute gradients (gather from experts) @@ -17,7 +28,7 @@ """ from functools import partial -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple import jax import jax.numpy as jnp @@ -38,6 +49,9 @@ "token_dispatch", "token_combine", "sort_chunks_by_index", + "unfused_token_dispatch", + "unfused_token_combine", + "UnfusedPermState", ] @@ -655,3 +669,323 @@ def _sort_chunks_by_index_bwd_rule( _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule) + + +# ============================================================================= +# Unfused (pure-JAX) token dispatch / combine +# ============================================================================= +# +# The following implementations use only ``jnp.argsort`` + gather and compile +# to plain XLA. They are a drop-in alternative to ``token_dispatch`` / +# ``token_combine`` above, differing only in input/output conventions (the +# fused path takes ``routing_map`` and ``sparse_probs`` over all experts; the +# unfused path takes dense ``selected_experts`` and per-token ``weights`` of +# shape ``[..., topk]``). + + +# ----------------------------------------------------------------------------- +# Custom-VJP argsort-based gather. +# +# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. +# Using a custom VJP lets the backward pass exploit that inverse instead of +# relying on the compiler to discover it from the scatter-style default +# gradient of a gather, which is typically less efficient. + + +@jax.custom_vjp +def _sort_activations(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: + """Sort ``inputs`` along the leading dim by ``sort_indices``.""" + assert inputs.shape[0] == sort_indices.shape[0], ( + f"inputs.shape[0]={inputs.shape[0]} must match" + f" sort_indices.shape[0]={sort_indices.shape[0]}" + ) + with jax.named_scope("unfused_sort_activations"): + return inputs[sort_indices, ...] + + +def _sort_activations_fwd( + inputs: jax.Array, sort_indices: jax.Array +) -> Tuple[jax.Array, jax.Array]: + return _sort_activations(inputs, sort_indices), sort_indices + + +def _sort_activations_bwd( + residuals: jax.Array, grads: jax.Array +) -> Tuple[jax.Array, None]: + sort_indices = residuals + # Inverse permutation: gather-by-argsort undoes the forward gather. + return _sort_activations(grads, jnp.argsort(sort_indices)), None + + +_sort_activations.defvjp(_sort_activations_fwd, _sort_activations_bwd) + + +def _routing_map_to_selected_experts( + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + topk: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Convert ``(sparse_probs, routing_map)`` from TE's fused router to the + ``(selected_experts, weights)`` format consumed by + :func:`unfused_token_dispatch`. + + ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` + with exactly ``topk`` ``True`` positions per row. + """ + # Argsort on a bool tensor places ``True`` rows last (False=0 < True=1), + # so the last ``topk`` indices are the selected expert IDs. + selected_experts = jnp.argsort(routing_map, axis=-1)[..., -topk:] + weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + return selected_experts, weights + + +# ----------------------------------------------------------------------------- +# Permutation state carried from dispatch to combine. + + +class UnfusedPermState(NamedTuple): + """Opaque state produced by :func:`unfused_token_dispatch`. + + Attributes + ---------- + sorted_indices : jnp.ndarray + The argsort indices used in the forward sort. Needed to reverse the + permutation in :func:`unfused_token_combine`. Shape + ``[num_real_tokens + padding_size]``. + num_real_tokens : int + Number of real (non-padding) permuted tokens, i.e. + ``batch_size * sequence_length * num_experts_per_tok``. Compile-time + constant. + padding_size : int + Number of alignment-padding tokens appended to the sort buffer. Equals + ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. + Compile-time constant. + """ + + sorted_indices: jax.Array + num_real_tokens: int + padding_size: int + + +# ----------------------------------------------------------------------------- +# Dispatch (permute) + + +def unfused_token_dispatch( + inputs: jnp.ndarray, + selected_experts: jnp.ndarray, + num_experts: int, + num_experts_per_tok: int, + align_size: int = 0, + roll_to_expert_id: Optional[int] = None, +) -> Tuple[jnp.ndarray, UnfusedPermState, jnp.ndarray]: + """Pure-JAX ``argsort``-based token dispatch. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[num_tokens, hidden_size]`` (or + ``[batch, seq, hidden]``; it will be flattened). + selected_experts : jnp.ndarray + Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or + ``[batch, seq, num_experts_per_tok]``). Integer dtype. + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k. Must equal ``selected_experts.shape[-1]``. + align_size : int, default 0 + Alignment for each expert's group size. ``0`` disables padding; a value + ``> 0`` appends a static-size padding buffer so each resulting group + size is a multiple of ``align_size`` (required for quantized grouped + GEMM). + roll_to_expert_id : Optional[int] + If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo + ``num_experts`` before the sort (ring-of-experts EP). The returned + ``group_sizes`` is rolled to match. + + Returns + ------- + sorted_inputs : jnp.ndarray + Permuted tokens grouped by expert, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : UnfusedPermState + State needed by :func:`unfused_token_combine`. + group_sizes : jnp.ndarray + Token count per expert, shape ``[num_experts]``. Each entry is a + multiple of ``align_size`` when ``align_size > 0``. + """ + assert num_experts_per_tok == selected_experts.shape[-1], ( + f"num_experts_per_tok={num_experts_per_tok} must match" + f" selected_experts.shape[-1]={selected_experts.shape[-1]}" + ) + assert align_size >= 0, f"align_size must be >= 0, got {align_size}" + + hidden_size = inputs.shape[-1] + inputs_2d = inputs.reshape(-1, hidden_size) + num_tokens = inputs_2d.shape[0] + num_real_tokens = num_tokens * num_experts_per_tok + + flatten_selected_experts = jnp.ravel(selected_experts) + + if align_size > 0: + # Per-expert token count, and how many extra tokens each expert needs + # to become aligned to ``align_size``. Using + # ``(align - count % align) % align`` gives 0 (not ``align``) when + # already aligned, so we never exceed the per-expert slot capacity of + # ``align_size - 1``. + token_count_per_expert = jnp.bincount( + flatten_selected_experts, length=num_experts + ) + padding_tokens_required_per_expert = ( + (align_size - (token_count_per_expert % align_size)) % align_size + ) + + # Build a static-size padding buffer of shape + # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot + # of ``align_size - 1`` positions (worst-case padding, which occurs + # when ``token_count[i] % align_size == 1``). Within slot ``i``, + # positions ``[0, padding_needed)`` are assigned expert ``i`` and act + # as real padding; the rest are assigned to ``num_experts - 1`` as + # overflow placeholders that keep the buffer statically sized for JIT. + max_padding_per_expert = align_size - 1 + max_total_padding_size = num_experts * max_padding_per_expert + positions = jnp.arange(max_total_padding_size) + expert_for_pos = positions // max_padding_per_expert + offset_in_slot = positions % max_padding_per_expert + padding_needed = padding_tokens_required_per_expert[expert_for_pos] + flatten_padding_selected_experts = jnp.where( + offset_in_slot < padding_needed, + expert_for_pos, + num_experts - 1, + ) + + flatten_selected_experts = jnp.concatenate( + [flatten_selected_experts, flatten_padding_selected_experts], axis=0 + ) + + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + # Pad inputs with zeros so the sort operand shape matches the expanded + # selected-experts vector. + replicated_inputs_2d = jnp.pad( + replicated_inputs_2d, + pad_width=((0, max_total_padding_size), (0, 0)), + mode="constant", + constant_values=0.0, + ) + + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + # Compute ``group_sizes`` directly from counts rather than via + # ``bincount(flatten_selected_experts)``: the overflow placeholder + # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the + # alignment guarantee. Direct computation gives each expert exactly + # ``ceil(count / align) * align`` tokens. + group_sizes = token_count_per_expert + padding_tokens_required_per_expert + + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = max_total_padding_size + else: + if roll_to_expert_id is not None: + flatten_selected_experts = ( + flatten_selected_experts - roll_to_expert_id + ) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = 0 + + perm_state = UnfusedPermState( + sorted_indices=sorted_selected_experts, + num_real_tokens=num_real_tokens, + padding_size=padding_size, + ) + return sorted_inputs, perm_state, group_sizes + + +# ----------------------------------------------------------------------------- +# Combine (unpermute + weighted sum) + + +def unfused_token_combine( + expert_outputs: jnp.ndarray, + perm_state: UnfusedPermState, + routing_weights: jnp.ndarray, + num_experts_per_tok: int, + batch_size: int, + sequence_length: int, +) -> jnp.ndarray: + """Pure-JAX ``argsort``-based token combine. + + Reverses the permutation performed by :func:`unfused_token_dispatch`, + strips any alignment-padding rows appended during dispatch, and applies a + per-token weighted sum across the top-k experts. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the expert FFN, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : UnfusedPermState + State returned by :func:`unfused_token_dispatch`. + routing_weights : jnp.ndarray + Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` + (or broadcastable to it after a ``reshape``). + num_experts_per_tok : int + Top-k. + batch_size : int + Original batch size. + sequence_length : int + Original sequence length. + + Returns + ------- + output : jnp.ndarray + Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. + """ + # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes + # ``input[sorted_indices]``. + unsort_intermediate = _sort_activations( + expert_outputs, + jnp.argsort(perm_state.sorted_indices), + ) + + # Strip alignment padding tokens appended during dispatch. After unsorting, + # the first ``num_real_tokens`` rows hold the real per-(token, top-k) + # outputs; any trailing rows are padding placeholders (zeros) and must be + # discarded before the reshape below. + if perm_state.padding_size > 0: + unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] + + hidden_size = unsort_intermediate.shape[-1] + reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) + reshaped_intermediate = jnp.reshape( + unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) + ) + + # Cast weights to match intermediate dtype (weighted sum happens in + # intermediate dtype; callers can upcast before calling if higher + # precision weight-sum is desired). + reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) + with jax.named_scope("unfused_weight_sum"): + output = jnp.einsum( + "BKE,BK -> BE", + reshaped_intermediate, + reshaped_weights, + ) + return output.reshape(batch_size, sequence_length, hidden_size) From 0044bf23c74753f178ffbc7ed0fa2f845a04fe1c Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 22 Apr 2026 17:58:31 -0700 Subject: [PATCH 03/18] add distributed test. Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 143 ++++++++++++++++++++++++ tests/jax/test_moe_block.py | 23 +++- transformer_engine/jax/flax/moe.py | 24 ++-- 3 files changed, 180 insertions(+), 10 deletions(-) create mode 100644 tests/jax/test_distributed_moe_block.py diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py new file mode 100644 index 0000000000..9d9e57140f --- /dev/null +++ b/tests/jax/test_distributed_moe_block.py @@ -0,0 +1,143 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed tests for ``transformer_engine.jax.flax.MoEBlock``.""" + +import sys + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax.sharding import Mesh, PartitionSpec + +from utils import assert_allclose, is_devices_enough + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax import MeshResource, autocast + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MeshResource = MeshResource + mod.autocast = autocast + mod.MoEBlock = MoEBlock + yield + + +DTYPE = jnp.bfloat16 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs(key: jax.Array) -> jax.Array: + return jax.random.normal( + key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE + ) + + +def _unwrap_partitioned(x): + return x.value if hasattr(x, "value") else x + + +@pytest.mark.triton +class TestDistributedMoEBlock: + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_ep2_fsdp2_matches_single_device(self, permutation_backend): + if not is_devices_enough(4): + pytest.skip("MoE distributed test requires 4 devices for EP=2 x FSDP=2.") + + key = jax.random.PRNGKey(11) + init_key, data_key = jax.random.split(key) + inputs = _make_inputs(data_key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + + single_block = MoEBlock(**base_kwargs) + + def loss_fn(block, variables, x): + output, aux_loss = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if aux_loss is not None: + loss = loss + aux_loss.astype(jnp.float32) + return loss, (output, aux_loss) + + with autocast(enabled=False, mesh_resource=MeshResource()): + single_variables = single_block.init(init_key, inputs) + (single_loss, (single_output, single_aux)), single_grads = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(single_block, single_variables, inputs) + + devices = np.asarray(jax.devices()[:4]).reshape(2, 2) + mesh = Mesh(devices, ("ep", "fsdp")) + # FSDP-style sharding: weights are sharded on a *non-contracting* + # weight axis (gathered before the GEMM); activations stay sharded on + # the *batch* axis throughout - the same fsdp mesh axis is reused for + # both. The TE primitives' custom_partitioning rules expect activations + # FSDP-sharded on batch, so we declare ("batch", "fsdp") AND pass + # ``input_axes=("batch", None, None)`` to enforce it on the inputs to + # the block. ("embed", "fsdp") shards the weight's hidden dim, which + # is gathered inside grouped_dense's custom_partitioning before GEMM + # (no reshard of activations needed because their layout is unchanged). + logical_axis_rules = ( + ("exp", "ep"), + ("batch", "fsdp"), + ("embed", "fsdp"), + ) + sharded_block = MoEBlock( + expert_parallelism_axis="ep", + mesh=mesh, + input_axes=("batch", None, None), + **base_kwargs, + ) + + with mesh, autocast(enabled=False, mesh_resource=MeshResource(fsdp_resource="fsdp")): + with nn.logical_axis_rules(logical_axis_rules): + sharded_variables = sharded_block.init(init_key, inputs) + (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( + jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( + sharded_block, sharded_variables, inputs + ) + ) + + wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) + wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) + wo = _unwrap_partitioned(sharded_variables["params"]["wo"]) + assert wi_0.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wi_1.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wo.sharding.spec == PartitionSpec("ep", None, "fsdp") + + assert_allclose(sharded_output, single_output, dtype=DTYPE, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_loss, single_loss, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_aux, single_aux, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + grad_single = _unwrap_partitioned(single_grads["params"][name]) + grad_sharded = _unwrap_partitioned(sharded_grads["params"][name]) + assert_allclose( + grad_sharded, + grad_single, + dtype=DTYPE, + atol=1e-1, + rtol=1e-1, + err_msg=f"Distributed gradient mismatch for {name}", + ) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index 458d674c7d..45cce2a60c 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -79,6 +79,11 @@ def _init_and_apply( return variables, output, aux_loss +def _unwrap_partitioned(x): + """Strip Flax logical-partition wrappers for numeric assertions.""" + return x.value if hasattr(x, "value") else x + + # ----------------------------------------------------------------------------- # Tests # ----------------------------------------------------------------------------- @@ -132,7 +137,7 @@ def loss_fn(variables, inputs): grads = jax.grad(loss_fn)(variables, inputs) # All trainable kernels should receive a non-trivial gradient. for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g = grads["params"][name] + g = _unwrap_partitioned(grads["params"][name]) assert jnp.all(jnp.isfinite(g)), f"{name} gradient has NaN/Inf" assert jnp.any(g != 0.0), f"{name} gradient is identically zero" @@ -183,8 +188,8 @@ def loss_fn(block, variables, inputs): assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g_pj = grads_pj["params"][name] - g_tr = grads_tr["params"][name] + g_pj = _unwrap_partitioned(grads_pj["params"][name]) + g_tr = _unwrap_partitioned(grads_tr["params"][name]) assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), ( f"Gradient for {name} differs across backends: max diff" f" {jnp.max(jnp.abs(g_pj - g_tr))}" @@ -238,6 +243,18 @@ def test_group_topk_deepseek(self, permutation_backend): assert output.shape == inputs.shape assert jnp.all(jnp.isfinite(output)) + @pytest.mark.xfail( + reason=( + "TE grouped_dense FFI currently asserts sum(group_sizes) == M " + "(see csrc/extensions/gemm.cpp). With align_size > 0 the dispatch " + "buffer is padded to a static worst-case size, so M can exceed " + "sum(group_sizes). The MoE block deliberately does not fold the " + "gap into a single expert (that would create per-shard load " + "imbalance under EP). Re-enable once the FFI check is relaxed to " + "M >= sum(group_sizes)." + ), + strict=False, + ) def test_align_size_equivalence_pure_jax(self): """For the pure-JAX backend, ``align_size > 0`` must not change the numerical output of the forward pass: padding tokens contribute zero diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 6673ac1a71..5f257dc577 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -425,7 +425,7 @@ def _forward_ring_ep( gate_logits: jnp.ndarray, params: dict, ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Wrap :meth:`_forward_body` in a ring-of-experts ``shard_map``. + """Wrap ``_forward_body`` in a ring-of-experts ``shard_map``. For each EP shard the wrapper: 1. ``all_gather`` s the local inputs / logits / expert_bias along @@ -566,7 +566,7 @@ def _route( # The score-for-aux kernel runs independently (no data dependency # on the main kernel), so XLA can overlap them on the GPU. aux_scores, aux_routing_map = fused_topk_with_score_function( - logits_2d, + logits_2d.astype(jnp.float32), topk=self.num_experts_per_tok, score_function=self.score_function, compute_aux_scores=True, @@ -575,7 +575,7 @@ def _route( aux_routing_map.astype(jnp.int32), axis=0 ) aux_loss = fused_moe_aux_loss( - aux_scores, + aux_scores.astype(jnp.float32), aux_tokens_per_expert, topk=self.num_experts_per_tok, coeff=self.aux_loss_coeff, @@ -619,11 +619,21 @@ def _dispatch_and_expert_ffn( ) # Slice group_sizes to just this shard's experts. When not using # EP, ``num_experts_local == self.num_experts`` so this is a no-op. + # + # NOTE on padded buffers (``align_size > 0``): + # ``unfused_token_dispatch`` pads ``sorted_inputs`` to a static + # worst-case row count so JIT shape inference is happy. The + # returned ``group_sizes`` deliberately tracks only real + real + # alignment-padding tokens; the remaining rows are zero-input + # placeholders that ``grouped_dense`` does not need to touch. + # + # TE's ``grouped_dense`` FFI today asserts strictly + # ``sum(group_sizes) == sorted_inputs.shape[0]``. When that + # assertion is relaxed to ``>=`` (the GEMM only iterates over the + # first ``sum(group_sizes)`` rows anyway), this code works as-is. + # Folding the gap into a single expert would create a per-shard + # load imbalance and is intentionally avoided here. group_sizes = group_sizes[:num_experts_local] - # ``local_real_size = sum(group_sizes)`` is the number of permuted - # rows that actually correspond to tokens routed to this shard's - # experts. Used by the ring-EP caller to zero out garbage rows - # before combine. combine_state = { "backend": "pure_jax", "perm_state": perm_state, From d78bc01660bdb2ce8b7c15affa6d303816c6e3d8 Mon Sep 17 00:00:00 2001 From: tdophung Date: Wed, 29 Apr 2026 18:02:18 -0700 Subject: [PATCH 04/18] refactor to a2a from roe Signed-off-by: tdophung --- tests/jax/test_moe_block.py | 18 +- transformer_engine/jax/flax/moe.py | 945 +++++++++++++++----------- transformer_engine/jax/permutation.py | 336 +++++++++ 3 files changed, 908 insertions(+), 391 deletions(-) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index 45cce2a60c..39a6bfd592 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -245,13 +245,17 @@ def test_group_topk_deepseek(self, permutation_backend): @pytest.mark.xfail( reason=( - "TE grouped_dense FFI currently asserts sum(group_sizes) == M " - "(see csrc/extensions/gemm.cpp). With align_size > 0 the dispatch " - "buffer is padded to a static worst-case size, so M can exceed " - "sum(group_sizes). The MoE block deliberately does not fold the " - "gap into a single expert (that would create per-shard load " - "imbalance under EP). Re-enable once the FFI check is relaxed to " - "M >= sum(group_sizes)." + "TE grouped_dense FFI asserts sum(group_sizes) == M at " + "transformer_engine/jax/csrc/extensions/gemm.cpp:1029. With " + "align_size > 0 both backends produce a buffer where M >= " + "sum(group_sizes) (the slack is structural padding for JIT). " + "The kernel itself iterates over per-expert m_i from " + "group_sizes via nvte_multi_tensor_gemm and never reads past " + "sum(group_sizes), so relaxing that assertion to " + "`m >= sum_group_sizes` is the cleanest fix. The MoE block " + "deliberately does not fold the gap into a single expert " + "(that would create per-shard load imbalance under EP). " + "Re-enable once the FFI check is relaxed." ), strict=False, ) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 5f257dc577..690d804e38 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -6,8 +6,54 @@ This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer that wires together TE's fused router, a selectable token-dispatch backend -(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and optional -ring-of-experts Expert Parallelism. +(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and an +optional ragged-all-to-all (A2A / A2Av) expert-parallelism strategy. + +Architecture +------------ + +The MoEBlock is decomposed into orthogonal stages so the EP wrapper can +inject collectives between them: + +* ``_route``: gate logits -> top-k routing decisions (+ aux loss). +* ``_global_permute``: scatter tokens to experts; produces + ``[num_tokens*topk + maybe_padding, hidden]`` and + per-expert ``group_sizes`` of length ``num_experts``. +* ``_expert_ffn``: three ``grouped_dense`` calls + activation. Operates + on whatever ``(rows, group_sizes, n_groups)`` it is + handed -- agnostic to whether ``n_groups`` is the + global expert count (no-EP) or the local expert + count (A2A-EP). +* ``_global_combine``: inverse of ``_global_permute`` -- gather + weighted + sum across top-k experts. + +Two top-level forward variants compose those stages: + +* ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE + primitive's ``custom_partitioning`` rule handles + DP / FSDP / TP automatically. +* ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and inserts + ``all_gather(group_sizes)`` + forward + ``ragged_all_to_all`` + local permute around the + FFN, plus their inverses afterwards. This is the + only place ``shard_map`` is used; A2A is the + canonical EP strategy because the in-flight NCCL + EP component will require this same data layout. + +Note on ``align_size > 0`` +-------------------------- + +Both permutation backends pad each expert's group to a multiple of +``align_size`` when requested, which is what CUBLASLt's grouped GEMM wants +for FP8 shape selection. The pure-JAX backend additionally appends a +zero-input padding tail to keep the buffer statically sized for JIT, so +``sum(group_sizes) <= sorted_inputs.shape[0]`` strictly. TE's +``grouped_dense`` FFI today asserts ``m == sum(group_sizes)`` at +``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``; relaxing that +check to ``m >= sum(group_sizes)`` (the kernel itself only iterates over +``sum(group_sizes)`` rows via ``nvte_multi_tensor_gemm``) is the cleanest +way to support ``align_size > 0`` end-to-end. Until that lands the +``align_size > 0`` tests stay xfail. """ from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -20,6 +66,10 @@ from ..dense import grouped_dense from ..permutation import ( _routing_map_to_selected_experts, + compute_ragged_all_to_all_params, + compute_reverse_ragged_all_to_all_params, + local_permute_after_a2a, + local_unpermute_before_a2a, token_combine, token_dispatch, unfused_token_combine, @@ -49,15 +99,28 @@ class MoEBlock(TransformerEngineBase): """Mixture-of-Experts Flax Linen block. Encapsulates the full MoE forward pass: gate projection, fused top-k - routing, optional auxiliary load-balancing loss, token dispatch, per-expert - two-layer FFN via grouped GEMMs, activation, token combine, and optional - ring-of-experts expert parallelism. - - The permutation step is pluggable via ``permutation_backend``: - ``"pure_jax"`` (default) uses the pure-JAX argsort-based - ``unfused_token_dispatch`` / ``unfused_token_combine`` in - :mod:`transformer_engine.jax.permutation`; ``"triton"`` uses TE's fused - ``token_dispatch`` / ``token_combine`` kernels. + routing, optional auxiliary load-balancing loss, token dispatch, + per-expert two-layer FFN via grouped GEMMs, activation, token combine, + and optional ragged-all-to-all expert parallelism. + + Two permutation backends are pluggable via ``permutation_backend``: + + * ``"pure_jax"`` (default) -- argsort-based + :func:`~transformer_engine.jax.permutation.unfused_token_dispatch` / + :func:`~transformer_engine.jax.permutation.unfused_token_combine`. + Faster than Triton in profiling for DeepSeek-style configs. + * ``"triton"`` -- TE's fused + :func:`~transformer_engine.jax.permutation.token_dispatch` / + :func:`~transformer_engine.jax.permutation.token_combine` Triton + kernels. + + Expert parallelism (``expert_parallelism_axis is not None``) uses the + **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes its + own tokens globally over all experts, then a forward + ``ragged_all_to_all`` exchanges per-expert chunks so each shard ends up + holding only the tokens for its local experts; after the FFN a reverse + ``ragged_all_to_all`` returns each shard's outputs to it. This matches + the layout the in-flight NCCL EP component expects. Parameters ---------- @@ -70,70 +133,72 @@ class MoEBlock(TransformerEngineBase): activation_type : str FFN activation applied to the gate projection. Paired with the up - projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. Resolved - via :func:`flax.linen.` (``"silu"``, ``"gelu"``, ``"relu"``, - ``"swish"``, ...) plus ``"linear"`` for identity. + projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. + Resolved via :func:`flax.linen.` (``"silu"``, ``"gelu"``, + ``"relu"``, ``"swish"``, ...) plus ``"linear"`` for identity. score_function : str or ScoreFunction - ``"softmax"`` (default) or ``"sigmoid"`` for :func:`fused_topk_with_score_function`. + ``"softmax"`` (default) or ``"sigmoid"`` for + :func:`fused_topk_with_score_function`. use_pre_softmax : bool Apply softmax before top-k when ``score_function="softmax"``. num_groups : int - Number of routing groups for grouped top-k (DeepSeek). ``<=0`` disables. + Number of routing groups for grouped top-k (DeepSeek). ``<=0`` + disables. group_topk : int Top-k at the group level. ``<=0`` disables. scaling_factor : float Scaling factor applied to output probs. use_expert_bias : bool - If ``True``, registers a learnable ``expert_bias`` parameter of shape - ``[num_experts]`` and passes it to the fused router. The router - primitive validates that this is paired with ``score_function="sigmoid"``. + If ``True``, registers a learnable ``expert_bias`` parameter of + shape ``[num_experts]`` and passes it to the fused router. The + router primitive validates that this is paired with + ``score_function="sigmoid"``. aux_loss_coeff : float - If ``> 0``, compute and return the MoE auxiliary load-balancing loss - scalar via :func:`fused_moe_aux_loss`. ``0`` disables. + If ``> 0``, compute and return the MoE auxiliary load-balancing + loss scalar via :func:`fused_moe_aux_loss`. ``0`` disables. gate_kernel_axes : tuple[str, ...] Logical partitioning axes for the gate kernel of shape ``[hidden, num_experts]``. wi_kernel_axes : tuple[str, ...] Logical partitioning axes for the ``wi_0`` and ``wi_1`` kernels of - shape ``[num_experts, hidden, intermediate]``. Default: + shape ``[num_experts, hidden, intermediate]``. Default ``("exp", "embed", "mlp")``. wo_kernel_axes : tuple[str, ...] Logical partitioning axes for the ``wo`` kernel of shape - ``[num_experts, intermediate, hidden]``. Default: + ``[num_experts, intermediate, hidden]``. Default ``("exp", "mlp", "embed")``. input_axes : tuple[str, ...] Logical axes used to constrain the input activation sharding at the block boundary. ``()`` (default) means no constraint. expert_parallelism_axis : Optional[str] - Mesh axis along which experts are split. When set, the forward pass - is wrapped in :func:`jax.experimental.shard_map.shard_map` that - implements the ring-of-experts EP strategy: ``all_gather`` on inputs - and gate logits, local routing + dispatch + FFN + combine, then - ``psum_scatter`` on the output. When ``None`` (default), no - ``shard_map`` wrapper is used; each primitive's ``custom_partitioning`` - rule handles DP/FSDP/TP automatically. + Mesh axis along which experts are split. When set, the forward + pass is wrapped in :func:`jax.shard_map` that implements the + ragged-all-to-all EP strategy. When ``None`` (default), no + ``shard_map`` wrapper is used; each TE primitive's + ``custom_partitioning`` rule handles DP / FSDP / TP automatically. tensor_parallelism_axis : Optional[str] Mesh axis for tensor parallelism on the FFN intermediate dim. When set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed - along this axis (inside the ``shard_map`` when EP is enabled, else at - the end of the forward pass). + along this axis. permutation_backend : str ``"pure_jax"`` (default) or ``"triton"``. align_size : int Alignment for per-expert group sizes after padding. ``0`` disables - padding (faster for the unquantized path). ``>0`` is required for - quantized TE grouped GEMM whose recipe-specific alignment must divide - ``align_size``. + padding (the only supported configuration end-to-end today). ``>0`` + is required for quantized TE grouped GEMM whose recipe-specific + alignment must divide ``align_size``; see the module docstring for + the FFI assertion that currently blocks ``>0`` for both backends. dtype : jnp.dtype Compute and parameter dtype. kernel_init : Initializer Initializer for all kernels (gate + per-expert FFN). Defaults to - ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax convention). + ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax + convention). use_bias : bool If ``True``, registers per-expert FFN biases ``wi_0_bias``, ``wi_1_bias``, ``wo_bias``. @@ -198,7 +263,7 @@ def __post_init__(self): # Parameter registration # ------------------------------------------------------------------ - def _make_params(self, hidden_size: int): + def _make_params(self, hidden_size: int) -> dict: """Register module parameters and return them as a dict.""" gate_kernel = self.param( "gate_kernel", @@ -224,7 +289,7 @@ def _make_params(self, hidden_size: int): (self.num_experts, self.intermediate_size, hidden_size), self.dtype, ) - params = { + params: dict = { "gate_kernel": gate_kernel, "wi_0": wi_0, "wi_1": wi_1, @@ -276,8 +341,8 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: output : jnp.ndarray Output tensor of shape ``[batch, sequence, hidden]``. aux_loss : Optional[jnp.ndarray] - Scalar auxiliary load-balancing loss when ``aux_loss_coeff > 0``, - else ``None``. + Scalar auxiliary load-balancing loss when + ``aux_loss_coeff > 0``, else ``None``. """ assert inputs.ndim == 3, ( f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" @@ -287,27 +352,15 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: _, _, hidden_size = inputs.shape params = self._make_params(hidden_size) - # Gate runs OUTSIDE the EP shard_map below, so each EP shard projects - # its own local slice of tokens and we later all-gather only the - # smaller logits tensor instead of the full inputs. + # The gate runs OUTSIDE any EP shard_map: under EP each shard + # projects only its local slice of tokens, producing local gate + # logits with the same per-shard layout as ``inputs``. gate_logits = self._gate(inputs, params["gate_kernel"]) if self.expert_parallelism_axis is None: - # No EP: each primitive's own ``custom_partitioning`` rule handles - # DP / FSDP / TP across the mesh - no shard_map needed. - output, aux_loss = self._forward_body( - inputs, - gate_logits, - params, - num_experts_local=self.num_experts, - roll_to_expert_id=None, - ) + output, aux_loss = self._forward_no_ep(inputs, gate_logits, params) else: - # Ring-EP: ``_forward_body`` is wrapped in a shard_map that - # orchestrates the cross-primitive collectives (all_gather inputs - # / logits before, psum_scatter output after) which per-primitive - # ``custom_partitioning`` cannot express on its own. - output, aux_loss = self._forward_ring_ep(inputs, gate_logits, params) + output, aux_loss = self._forward_a2a_ep(inputs, gate_logits, params) if self.aux_loss_coeff <= 0.0: aux_loss = None @@ -320,235 +373,31 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: """Linear gate projection ``inputs @ gate_kernel``. - Kept as a plain matmul (not ``DenseGeneral``) so it integrates cleanly - with the EP shard_map: the gate matmul runs in the outer (pre-shard_map) - scope and its output is all-gathered along the EP axis inside. + Kept as a plain ``einsum`` (not ``DenseGeneral``) so it composes + cleanly with the EP shard_map: the gate runs in the outer + (pre-shard_map) scope and its output passes through the + ``shard_map`` boundary unchanged. """ - # Cast kernel to input dtype outside FP8 scope (gate is typically BF16/FP32). kernel = gate_kernel.astype(inputs.dtype) return jnp.einsum("bsh,he->bse", inputs, kernel) - # ------------------------------------------------------------------ - # Forward body (shared between no-EP and ring-EP paths) - # ------------------------------------------------------------------ - - def _forward_body( - self, - inputs: jnp.ndarray, - gate_logits: jnp.ndarray, - params: dict, - num_experts_local: int, - roll_to_expert_id: Optional[int], - ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Routing + dispatch + per-expert FFN + combine. - - Used both bare (no EP) and inside the ring-EP shard_map. In the - ring-EP case ``inputs`` and ``gate_logits`` are the post-all_gather - global tensors, ``num_experts_local == num_experts // num_ep``, and - ``roll_to_expert_id`` is the offset that brings this shard's experts - into slots ``[0, num_experts_local)``. - """ - batch_size, sequence_length, hidden_size = inputs.shape - inputs_2d = inputs.reshape(-1, hidden_size) - logits_2d = gate_logits.reshape(-1, self.num_experts) - - sparse_probs, routing_map, aux_loss = self._route( - logits_2d, params.get("expert_bias") - ) - - if roll_to_expert_id is not None: - # Rotate expert columns so this shard's experts come first. - routing_map = jnp.roll(routing_map, -roll_to_expert_id, axis=-1) - sparse_probs = jnp.roll(sparse_probs, -roll_to_expert_id, axis=-1) - if self.permutation_backend == "triton": - # Triton path: zero out remote-expert columns so the fused - # ``token_dispatch`` never writes tokens routed off-shard. - # The pure-JAX path zeroes garbage *output* rows below - # instead, since masking the routing_map directly would - # break the argsort-based permutation. - local_mask = ( - jnp.arange(self.num_experts) < num_experts_local - ) - routing_map = routing_map * local_mask - sparse_probs = sparse_probs * local_mask.astype(sparse_probs.dtype) - - expert_outputs, combine_state = self._dispatch_and_expert_ffn( - inputs_2d, - sparse_probs, - routing_map, - params, - num_experts_local=num_experts_local, - # The roll is already baked into ``routing_map``/``sparse_probs`` - # above, so the unfused dispatch must not roll again. - roll_to_expert_id=0 if roll_to_expert_id is not None else None, - ) - - if ( - roll_to_expert_id is not None - and self.permutation_backend == "pure_jax" - ): - # Zero the rows of ``expert_outputs`` past the real local-expert - # token count: ``grouped_dense`` leaves them as garbage because - # ``group_sizes`` was truncated to the local slice. Without this - # the unsort + weighted-sum in combine would mix garbage into - # every token's output (mirrors Maxtext's moe.py). - real_mask = ( - jnp.arange(expert_outputs.shape[0]) - < combine_state["local_real_size"] - ) - expert_outputs = jnp.where(real_mask[:, None], expert_outputs, 0) - - output = self._combine( - expert_outputs, - combine_state, - batch_size=batch_size, - sequence_length=sequence_length, - ) - - if self.tensor_parallelism_axis is not None: - output = jax.lax.psum_scatter( - output, - self.tensor_parallelism_axis, - scatter_dimension=2, - tiled=True, - ) - - return output, aux_loss - - # ------------------------------------------------------------------ - # Ring-of-Experts EP wrapper - # ------------------------------------------------------------------ - - def _forward_ring_ep( - self, - inputs: jnp.ndarray, - gate_logits: jnp.ndarray, - params: dict, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Wrap ``_forward_body`` in a ring-of-experts ``shard_map``. - - For each EP shard the wrapper: - 1. ``all_gather`` s the local inputs / logits / expert_bias along - the EP axis so the routing sees every token globally. - 2. Calls ``_forward_body`` with ``roll_to_expert_id = - num_experts_per_shard * shard_id`` and the EP-local weight slice. - 3. ``psum_scatter`` s the resulting ``[B*num_ep, S, H]`` output back - to the EP-sharded ``[B, S, H]`` layout. - """ - from jax.experimental.shard_map import shard_map - - ep_axis = self.expert_parallelism_axis - if self.mesh is None: - raise ValueError( - "MoEBlock.expert_parallelism_axis is set; `mesh` must also be" - " provided so the ring-of-experts shard_map can be built." - ) - mesh = self.mesh - num_ep = mesh.shape[ep_axis] - assert self.num_experts % num_ep == 0, ( - f"num_experts={self.num_experts} must be divisible by EP size={num_ep}" - ) - num_experts_per_shard = self.num_experts // num_ep - - # Pack everything that crosses the shard_map boundary into a dict - # pytree. shard_map fully supports pytrees: ``in_specs`` must - # structurally match ``captured``, and we build them in lockstep so - # adding/removing an optional bias is a single ``dict[name] = ...``. - captured: dict = { - "inputs": inputs, - "gate_logits": gate_logits, - "wi_0": params["wi_0"], - "wi_1": params["wi_1"], - "wo": params["wo"], - } - in_specs: dict = { - "inputs": P(ep_axis, None, None), - "gate_logits": P(ep_axis, None, None), - "wi_0": P(ep_axis, None, None), - "wi_1": P(ep_axis, None, None), - "wo": P(ep_axis, None, None), - } - if "expert_bias" in params: - captured["expert_bias"] = params["expert_bias"] - in_specs["expert_bias"] = P(ep_axis) - if "wi_0_bias" in params: - for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): - captured[name] = params[name] - in_specs[name] = P(ep_axis, None) - - def _ring_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: - shard_id = jax.lax.axis_index(ep_axis) - - gathered_inputs = jax.lax.all_gather( - local["inputs"], axis_name=ep_axis, tiled=True - ) - gathered_logits = jax.lax.all_gather( - local["gate_logits"], axis_name=ep_axis, tiled=True - ) - - local_params: dict = { - "wi_0": local["wi_0"], - "wi_1": local["wi_1"], - "wo": local["wo"], - } - if "expert_bias" in local: - # The router operates over the full expert axis, so the - # EP-sharded bias must be all-gathered. - local_params["expert_bias"] = jax.lax.all_gather( - local["expert_bias"], axis_name=ep_axis, tiled=True - ) - if "wi_0_bias" in local: - local_params["wi_0_bias"] = local["wi_0_bias"] - local_params["wi_1_bias"] = local["wi_1_bias"] - local_params["wo_bias"] = local["wo_bias"] - - output, aux_loss = self._forward_body( - gathered_inputs, - gathered_logits, - local_params, - num_experts_local=num_experts_per_shard, - roll_to_expert_id=num_experts_per_shard * shard_id, - ) - - # ``output`` is [B*num_ep, S, H] (global batch after all_gather); - # psum_scatter along EP returns the local [B, S, H] slice. - output = jax.lax.psum_scatter( - output, ep_axis, scatter_dimension=0, tiled=True - ) - - # ``out_specs`` must match the returned pytree structurally, so - # always emit a real scalar for aux_loss; the outer ``__call__`` - # re-strips it to None when ``aux_loss_coeff <= 0``. - if aux_loss is None: - aux_loss = jnp.zeros((), dtype=self.dtype) - return output, aux_loss - - # ``check_rep=False`` disables shard_map's invariant that any output - # declared as ``P()`` is replicated across ``ep_axis``. We use - # ``axis_index(ep_axis)`` inside ``_ring_fn`` to compute a per-shard - # roll, which makes the body genuinely non-replicated and would - # otherwise (correctly) fail the check. The ``psum_scatter`` of the - # output already produces the right cross-shard semantics; this is - # the standard JAX escape hatch when collectives + per-shard logic - # coexist. - return shard_map( - _ring_fn, - mesh=mesh, - in_specs=in_specs, - out_specs=(P(ep_axis, None, None), P()), - check_rep=False, - )(captured) - # ------------------------------------------------------------------ # Route # ------------------------------------------------------------------ - - def _route( + # + # The router is split into two pieces so the EP path can compute + # aux_loss over global (cross-shard) statistics without re-running + # the main top-k path. ``_route_topk`` returns the per-token routing + # decisions (used by ``_global_permute``) and ``_compute_aux_loss`` + # returns the scalar load-balancing loss given the (possibly + # gathered) logits. + + def _route_topk( self, logits_2d: jnp.ndarray, expert_bias: Optional[jnp.ndarray], - ) -> Tuple[jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]]: - """Run the fused router and optional aux-loss.""" + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Run the fused router top-k selection.""" sparse_probs, routing_map = fused_topk_with_score_function( logits_2d, topk=self.num_experts_per_tok, @@ -560,47 +409,73 @@ def _route( expert_bias=expert_bias, ) sparse_probs = sparse_probs.astype(self.dtype) + return sparse_probs, routing_map - aux_loss = None - if self.aux_loss_coeff > 0.0: - # The score-for-aux kernel runs independently (no data dependency - # on the main kernel), so XLA can overlap them on the GPU. - aux_scores, aux_routing_map = fused_topk_with_score_function( - logits_2d.astype(jnp.float32), - topk=self.num_experts_per_tok, - score_function=self.score_function, - compute_aux_scores=True, - ) - aux_tokens_per_expert = jnp.sum( - aux_routing_map.astype(jnp.int32), axis=0 - ) - aux_loss = fused_moe_aux_loss( - aux_scores.astype(jnp.float32), - aux_tokens_per_expert, - topk=self.num_experts_per_tok, - coeff=self.aux_loss_coeff, - ) - - return sparse_probs, routing_map, aux_loss + def _compute_aux_loss( + self, + logits_2d: jnp.ndarray, + ) -> Optional[jnp.ndarray]: + """Compute the MoE auxiliary load-balancing loss. + + The score-for-aux kernel has no data dependency on the main + routing kernel, so XLA can overlap them on the GPU. + + ``logits_2d`` should be the *full* logits tensor over the global + token batch -- under EP the caller is responsible for + :func:`jax.lax.all_gather` ing the logits before calling this so + the aux_loss formula + ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])`` + sees the global ``T`` and the global ``tokens_per_expert``. + """ + if self.aux_loss_coeff <= 0.0: + return None + aux_scores, aux_routing_map = fused_topk_with_score_function( + logits_2d.astype(jnp.float32), + topk=self.num_experts_per_tok, + score_function=self.score_function, + compute_aux_scores=True, + ) + aux_tokens_per_expert = jnp.sum( + aux_routing_map.astype(jnp.int32), axis=0 + ) + return fused_moe_aux_loss( + aux_scores.astype(jnp.float32), + aux_tokens_per_expert, + topk=self.num_experts_per_tok, + coeff=self.aux_loss_coeff, + ) # ------------------------------------------------------------------ - # Dispatch + expert FFN + # Global permute (route -> token dispatch) # ------------------------------------------------------------------ - def _dispatch_and_expert_ffn( + def _global_permute( self, inputs_2d: jnp.ndarray, sparse_probs: jnp.ndarray, routing_map: jnp.ndarray, - params: dict, - num_experts_local: int, - roll_to_expert_id: Optional[int], - ) -> Tuple[jnp.ndarray, dict]: - """Dispatch tokens, run the three grouped GEMMs + activation, return expert outputs. - - Returns a tuple ``(expert_outputs, combine_state)`` where - ``combine_state`` carries the per-backend state needed to rebuild the - original token ordering in :meth:`_combine`. + ) -> dict: + """Dispatch tokens to the global expert axis. + + Returns a permutation-result dict suitable both for the no-EP + forward (where the same buffer feeds ``_expert_ffn`` directly) and + for the A2A-EP path (where the buffer is sliced + sent over the EP + axis before the FFN). The dict carries the per-backend opaque + state needed to invert the dispatch in :meth:`_global_combine`. + + The output dict layout is:: + + { + "backend": "pure_jax" | "triton", + "sorted_inputs": [buffer_size, hidden], + "group_sizes": [num_experts], # per-expert, + # length == E always. + "perm_state": UnfusedPermState | None, # pure_jax + "row_id_map": jnp.ndarray | None, # triton + "pad_offsets": jnp.ndarray | None, # triton + "routing_weights": jnp.ndarray | None, # pure_jax + "merging_probs": jnp.ndarray | None, # triton + } """ num_tokens = inputs_2d.shape[0] topk = self.num_experts_per_tok @@ -615,79 +490,90 @@ def _dispatch_and_expert_ffn( num_experts=self.num_experts, num_experts_per_tok=topk, align_size=self.align_size, - roll_to_expert_id=roll_to_expert_id, ) - # Slice group_sizes to just this shard's experts. When not using - # EP, ``num_experts_local == self.num_experts`` so this is a no-op. - # - # NOTE on padded buffers (``align_size > 0``): - # ``unfused_token_dispatch`` pads ``sorted_inputs`` to a static - # worst-case row count so JIT shape inference is happy. The - # returned ``group_sizes`` deliberately tracks only real + real - # alignment-padding tokens; the remaining rows are zero-input - # placeholders that ``grouped_dense`` does not need to touch. - # - # TE's ``grouped_dense`` FFI today asserts strictly - # ``sum(group_sizes) == sorted_inputs.shape[0]``. When that - # assertion is relaxed to ``>=`` (the GEMM only iterates over the - # first ``sum(group_sizes)`` rows anyway), this code works as-is. - # Folding the gap into a single expert would create a per-shard - # load imbalance and is intentionally avoided here. - group_sizes = group_sizes[:num_experts_local] - combine_state = { + return { "backend": "pure_jax", + "sorted_inputs": sorted_inputs, + "group_sizes": group_sizes, "perm_state": perm_state, "routing_weights": routing_weights, - "local_real_size": jnp.sum(group_sizes), - } - else: # "triton" - num_out_tokens = num_tokens * topk - align_size_arg = self.align_size if self.align_size > 0 else None - ( - sorted_inputs, - _permuted_probs, - row_id_map, - pad_offsets, - group_sizes, - ) = token_dispatch( - inputs_2d, - routing_map, - num_out_tokens=num_out_tokens, - probs=sparse_probs, - align_size=align_size_arg, - ) - group_sizes = group_sizes[:num_experts_local] - combine_state = { - "backend": "triton", - "row_id_map": row_id_map, - "pad_offsets": pad_offsets, - "merging_probs": sparse_probs, - "group_sizes": group_sizes, } - # ------------------------------------------------------------------ - # Expert FFN: grouped GEMMs w0, w1 + activation + w_o. - # ------------------------------------------------------------------ + # triton + num_out_tokens = num_tokens * topk + align_size_arg = self.align_size if self.align_size > 0 else None + ( + sorted_inputs, + _permuted_probs, + row_id_map, + pad_offsets, + group_sizes, + ) = token_dispatch( + inputs_2d, + routing_map, + num_out_tokens=num_out_tokens, + probs=sparse_probs, + align_size=align_size_arg, + ) + return { + "backend": "triton", + "sorted_inputs": sorted_inputs, + "group_sizes": group_sizes, + "row_id_map": row_id_map, + "pad_offsets": pad_offsets, + "merging_probs": sparse_probs, + } + + # ------------------------------------------------------------------ + # Expert FFN (three grouped_dense calls + activation) + # ------------------------------------------------------------------ + + def _expert_ffn( + self, + sorted_inputs: jnp.ndarray, + group_sizes: jnp.ndarray, + params: dict, + n_groups: int, + ) -> jnp.ndarray: + """Run the per-expert SwiGLU-style FFN over a permuted buffer. + + Parameters + ---------- + sorted_inputs : jnp.ndarray + Permuted tokens of shape ``[buffer_size, hidden]`` (rows + grouped by expert). + group_sizes : jnp.ndarray + Per-group token counts of shape ``[n_groups]``. + ``sum(group_sizes)`` must equal ``buffer_size`` (TE + ``grouped_dense`` FFI assertion at + ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``). + params : dict + Block parameters from :meth:`_make_params`. Reads ``wi_0``, + ``wi_1``, ``wo``, and the optional bias entries. + n_groups : int + Number of expert groups. Equals ``self.num_experts`` for the + no-EP path and ``num_experts // num_ep`` for the A2A-EP path. + Used to size the per-call quantizer set so the FP8 metadata + tensors match ``group_sizes``. + + Returns + ------- + expert_outputs : jnp.ndarray + ``[buffer_size, hidden]``. + """ wi_0 = params["wi_0"] wi_1 = params["wi_1"] wo = params["wo"] # Each grouped_dense call gets its own quantizer_set with - # ``n_groups=num_experts_local``; this matches the shape of - # ``group_sizes`` passed in and keeps the quantizer FP8 meta correctly - # sized per shard. - q_set_w0 = self.generate_quantizer_set( - postfix="_w0", n_groups=num_experts_local - ) - q_set_w1 = self.generate_quantizer_set( - postfix="_w1", n_groups=num_experts_local - ) - q_set_wo = self.generate_quantizer_set( - postfix="_wo", n_groups=num_experts_local - ) - - # Cast kernels to the sort dtype when no FP8 quantization is active - # (mirrors DenseGeneral). + # n_groups matching ``group_sizes``; this keeps the FP8 meta + # tensors correctly sized in both no-EP and A2A-EP cases. + q_set_w0 = self.generate_quantizer_set(postfix="_w0", n_groups=n_groups) + q_set_w1 = self.generate_quantizer_set(postfix="_w1", n_groups=n_groups) + q_set_wo = self.generate_quantizer_set(postfix="_wo", n_groups=n_groups) + + # Cast kernels to the activation dtype when no FP8 quantization + # is active (mirrors DenseGeneral). if q_set_w0 == noop_quantizer_set: wi_0 = wi_0.astype(sorted_inputs.dtype) if q_set_w1 == noop_quantizer_set: @@ -695,9 +581,9 @@ def _dispatch_and_expert_ffn( if q_set_wo == noop_quantizer_set: wo = wo.astype(sorted_inputs.dtype) - # ``grouped_dense`` accepts per-expert bias of shape (G, N); it adds - # ``bias[i]`` to the ``group_sizes[i]`` rows belonging to expert ``i`` - # in the permuted layout. + # ``grouped_dense`` accepts per-expert bias of shape (G, N); it + # adds ``bias[i]`` to the ``group_sizes[i]`` rows belonging to + # expert ``i`` in the permuted layout. wi_0_bias = params.get("wi_0_bias") if self.use_bias else None wi_1_bias = params.get("wi_1_bias") if self.use_bias else None wo_bias = params.get("wo_bias") if self.use_bias else None @@ -730,25 +616,30 @@ def _dispatch_and_expert_ffn( bias=wo_bias, quantizer_set=q_set_wo, ) - - return expert_outputs, combine_state + return expert_outputs # ------------------------------------------------------------------ - # Combine + # Global combine (token combine -> back to [B, S, H]) # ------------------------------------------------------------------ - def _combine( + def _global_combine( self, expert_outputs: jnp.ndarray, - combine_state: dict, + perm_result: dict, batch_size: int, sequence_length: int, ) -> jnp.ndarray: - if combine_state["backend"] == "pure_jax": + """Inverse of :meth:`_global_permute`. + + Gathers per-expert outputs back into ``[batch, sequence, hidden]`` + and applies the per-token weighted sum across the top-k experts. + """ + backend = perm_result["backend"] + if backend == "pure_jax": return unfused_token_combine( expert_outputs, - combine_state["perm_state"], - combine_state["routing_weights"], + perm_result["perm_state"], + perm_result["routing_weights"], num_experts_per_tok=self.num_experts_per_tok, batch_size=batch_size, sequence_length=sequence_length, @@ -756,11 +647,297 @@ def _combine( # triton out_2d = token_combine( expert_outputs, - combine_state["row_id_map"], - merging_probs=combine_state["merging_probs"], - pad_offsets=combine_state["pad_offsets"], + perm_result["row_id_map"], + merging_probs=perm_result["merging_probs"], + pad_offsets=perm_result["pad_offsets"], ) hidden_size = out_2d.shape[-1] return out_2d.reshape(batch_size, sequence_length, hidden_size).astype( self.dtype ) + + # ------------------------------------------------------------------ + # No-EP forward + # ------------------------------------------------------------------ + + def _forward_no_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Single-shard or DP/FSDP/TP forward (no shard_map wrapper). + + DP / FSDP / TP all flow through each TE primitive's + ``custom_partitioning`` rule -- there is no cross-primitive + collective that the rules cannot express on their own, so a + ``shard_map`` is unnecessary here. + """ + batch_size, sequence_length, hidden_size = inputs.shape + inputs_2d = inputs.reshape(-1, hidden_size) + logits_2d = gate_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map = self._route_topk( + logits_2d, params.get("expert_bias") + ) + aux_loss = self._compute_aux_loss(logits_2d) + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + expert_outputs = self._expert_ffn( + perm["sorted_inputs"], + perm["group_sizes"], + params, + n_groups=self.num_experts, + ) + output = self._global_combine( + expert_outputs, perm, batch_size, sequence_length + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + return output, aux_loss + + # ------------------------------------------------------------------ + # A2A (ragged-all-to-all) EP forward + # ------------------------------------------------------------------ + + def _forward_a2a_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + params: dict, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Wrap the body in a ``shard_map`` that runs a forward + ``ragged_all_to_all`` (A2A / A2Av) around the FFN. + + For each EP shard the wrapper: + + 1. Routes the shard's local tokens **globally** over all + ``num_experts`` experts (no roll, no local-mask -- every shard + sees the full expert axis). + 2. ``all_gather`` s its per-expert ``group_sizes`` so all shards + know the complete ``[num_ep, num_experts]`` token-count matrix. + 3. Forward ``ragged_all_to_all`` over the EP axis: each shard + sends per-expert chunks to the shard that owns those experts, + and receives chunks for its own ``num_experts // num_ep`` + local experts from every other shard. + 4. Reorders the received buffer from ``(source_shard, expert)`` + to ``(expert, source_shard)`` ordering so each local expert's + tokens are contiguous. + 5. Runs the three ``grouped_dense`` calls + activation over the + ``E_local``-group buffer. + 6. Reverses the local reorder. + 7. Reverse ``ragged_all_to_all`` over EP returns each shard's + token outputs to it. + 8. Inverts the global permute and applies the top-k weighted sum. + """ + from jax.experimental.shard_map import shard_map + + ep_axis = self.expert_parallelism_axis + if self.mesh is None: + raise ValueError( + "MoEBlock.expert_parallelism_axis is set; `mesh` must also" + " be provided so the EP shard_map can be built." + ) + mesh = self.mesh + num_ep = mesh.shape[ep_axis] + assert self.num_experts % num_ep == 0, ( + f"num_experts={self.num_experts} must be divisible by EP" + f" size={num_ep}" + ) + num_experts_local = self.num_experts // num_ep + + # Pre-compute the worst-case A2A receive buffer size (compile-time + # constant). Each shard contributes ``b_l*S*topk = B*S*topk/num_ep`` + # token-expert pairs across all experts; the worst case for one + # shard is "every global pair lands on this shard's local + # experts" -- ``num_ep * (B*S*topk/num_ep) = B*S*topk`` rows. JIT + # needs this static, so we use the global ``batch_size`` from the + # outer scope (sharded layouts don't change it). + global_batch_size, sequence_length, _hidden = inputs.shape + topk = self.num_experts_per_tok + recv_buffer_rows = global_batch_size * sequence_length * topk + + # Pack everything that crosses the shard_map boundary into a dict + # pytree. shard_map fully supports pytrees: ``in_specs`` must + # structurally match ``captured`` and we build them in lockstep + # so adding/removing an optional bias is one ``dict[name] = ...``. + captured: dict = { + "inputs": inputs, + "gate_logits": gate_logits, + "wi_0": params["wi_0"], + "wi_1": params["wi_1"], + "wo": params["wo"], + } + in_specs: dict = { + "inputs": P(ep_axis, None, None), + "gate_logits": P(ep_axis, None, None), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + } + if "expert_bias" in params: + captured["expert_bias"] = params["expert_bias"] + in_specs["expert_bias"] = P(ep_axis) + if "wi_0_bias" in params: + for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): + captured[name] = params[name] + in_specs[name] = P(ep_axis, None) + + def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: + shard_id = jax.lax.axis_index(ep_axis) + + # -- Stage 1: per-shard route + global permute over all E -- + # Inside the shard_map body each input has its EP axis already + # consumed, so ``local_inputs.shape == [B/num_ep, S, H]``. + local_inputs = local["inputs"] + local_logits = local["gate_logits"] + local_b, local_s, local_h = local_inputs.shape + inputs_2d = local_inputs.reshape(-1, local_h) + logits_2d = local_logits.reshape(-1, self.num_experts) + + # The router operates over the full expert axis, so the + # EP-sharded ``expert_bias`` (in_spec ``P(ep_axis)``) must be + # all-gathered before being passed in. + if "expert_bias" in local: + full_expert_bias = jax.lax.all_gather( + local["expert_bias"], axis_name=ep_axis, tiled=True + ) + else: + full_expert_bias = None + sparse_probs, routing_map = self._route_topk( + logits_2d, full_expert_bias + ) + + # aux_loss must see the global token batch and the global + # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( + # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable + # (the sum_t * tokens product is data-dependent across + # shards). Cheapest fix: gather logits along the EP axis and + # run the aux-loss kernel on the global tensor. The aux + # branch has no data dependency on the main routing path so + # XLA can overlap the two on the GPU. + if self.aux_loss_coeff > 0.0: + global_logits_2d = jax.lax.all_gather( + logits_2d, axis_name=ep_axis, axis=0, tiled=True + ) + aux_loss = self._compute_aux_loss(global_logits_2d) + else: + aux_loss = None + + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + global_group_sizes = perm["group_sizes"] # [E] + + # -- Stage 2: gather per-expert counts across the EP axis -- + all_shards_tokens_per_expert = jax.lax.all_gather( + global_group_sizes[None, :], + axis_name=ep_axis, + axis=0, + tiled=True, + ) # [num_ep, num_experts] + + # -- Stage 3: forward ragged_all_to_all over EP -- + in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + recv_buf = jnp.zeros( + (recv_buffer_rows, local_h), + dtype=perm["sorted_inputs"].dtype, + ) + x_recv = jax.lax.ragged_all_to_all( + perm["sorted_inputs"], + recv_buf, + in_off, + send_sz, + out_off, + recv_sz, + axis_name=ep_axis, + ) + + # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) + sorted_x, local_group_sizes, local_perm_state = ( + local_permute_after_a2a( + x_recv, + all_shards_tokens_per_expert, + shard_id, + num_ep, + ) + ) + + # -- Stage 5: per-expert FFN (E_local groups) -- + local_params: dict = { + "wi_0": local["wi_0"], + "wi_1": local["wi_1"], + "wo": local["wo"], + } + if "wi_0_bias" in local: + local_params["wi_0_bias"] = local["wi_0_bias"] + local_params["wi_1_bias"] = local["wi_1_bias"] + local_params["wo_bias"] = local["wo_bias"] + expert_outputs = self._expert_ffn( + sorted_x, + local_group_sizes, + local_params, + n_groups=num_experts_local, + ) + + # -- Stage 6: invert local permute -- + x_send_back = local_unpermute_before_a2a( + expert_outputs, local_perm_state + ) + + # -- Stage 7: reverse ragged_all_to_all over EP -- + in_off_r, send_sz_r, out_off_r, recv_sz_r = ( + compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + ) + send_back_buf = jnp.zeros_like(perm["sorted_inputs"]) + y_back = jax.lax.ragged_all_to_all( + x_send_back, + send_back_buf, + in_off_r, + send_sz_r, + out_off_r, + recv_sz_r, + axis_name=ep_axis, + ) + + # -- Stage 8: invert global permute, weighted sum over top-k -- + output = self._global_combine( + y_back, perm, batch_size=local_b, sequence_length=local_s + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + # ``out_specs`` must match the returned pytree structurally, + # so always emit a real scalar for aux_loss; the outer + # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. + if aux_loss is None: + aux_loss = jnp.zeros((), dtype=self.dtype) + return output, aux_loss + + # ``check_rep=False`` disables shard_map's invariant that any + # output declared as ``P()`` is replicated across ``ep_axis``. + # We use ``axis_index(ep_axis)`` inside ``_a2a_fn`` so the body + # is genuinely non-replicated, which would otherwise (correctly) + # fail the check. ``ragged_all_to_all`` already produces the + # right cross-shard semantics; this is the standard JAX escape + # hatch when collectives + per-shard logic coexist. + return shard_map( + _a2a_fn, + mesh=mesh, + in_specs=in_specs, + out_specs=(P(ep_axis, None, None), P()), + check_rep=False, + )(captured) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 1a492ba186..f4599a7b8f 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -52,6 +52,11 @@ "unfused_token_dispatch", "unfused_token_combine", "UnfusedPermState", + # Ragged-all-to-all expert-parallelism helpers + "compute_ragged_all_to_all_params", + "compute_reverse_ragged_all_to_all_params", + "local_permute_after_a2a", + "local_unpermute_before_a2a", ] @@ -989,3 +994,334 @@ def unfused_token_combine( reshaped_weights, ) return output.reshape(batch_size, sequence_length, hidden_size) + + +# ============================================================================= +# Ragged-all-to-all expert-parallelism helpers +# ============================================================================= +# +# These helpers support the ragged-all-to-all (A2A / A2Av) EP strategy used by +# :class:`transformer_engine.jax.flax.MoEBlock`. The forward EP path looks +# like:: +# +# route -> global_permute -> AG(group_sizes, ep) +# -> ragged_all_to_all(fwd, ep) +# -> local_permute_after_a2a +# -> grouped_dense x3 + activation +# -> local_unpermute_before_a2a +# -> ragged_all_to_all(reverse, ep) +# -> global_combine +# +# The two ``compute_*_ragged_all_to_all_params`` functions translate +# ``all_shards_tokens_per_expert`` (an EP-axis ``all_gather`` of each shard's +# global ``group_sizes``) into the four ``ragged_all_to_all`` arguments +# (``input_offsets``, ``send_sizes``, ``output_offsets``, ``recv_sizes``). +# ``shard_id`` may be a traced value (e.g. from :func:`jax.lax.axis_index`), +# which is why every slice into ``all_shards_tokens_per_expert`` uses +# :func:`jax.lax.dynamic_slice`. +# +# These functions are pure JAX (no MaxText / TE dependencies) and equivalent +# to :func:`maxtext.layers.te_permutation.compute_ragged_all_to_all_params` +# / :func:`compute_reverse_ragged_all_to_all_params`. + + +def compute_ragged_all_to_all_params( + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Forward-direction ragged_all_to_all parameters. + + Computes the four index/size arrays that :func:`jax.lax.ragged_all_to_all` + consumes for the **forward** EP shuffle, where each shard sends its + expert-grouped tokens to the shard that owns those experts. + + Parameters + ---------- + all_shards_tokens_per_expert : jnp.ndarray + Per-shard, per-expert token counts gathered across the EP axis. Shape + ``[num_expert_shards, num_experts]`` and integer dtype. + shard_id : jnp.ndarray + Index of the current shard along the EP axis (typically + :func:`jax.lax.axis_index` of the EP axis). Must be a 0-d integer. + num_expert_shards : int + Static EP-axis size. Must match + ``all_shards_tokens_per_expert.shape[0]``. + + Returns + ------- + input_offsets : jnp.ndarray + Shape ``[num_expert_shards]``. Cumulative ``send_sizes`` (with a + leading 0) -- where in the local source buffer each destination + shard's chunk begins. + send_sizes : jnp.ndarray + Shape ``[num_expert_shards]``. ``send_sizes[i]`` is the number of + tokens this shard sends to shard ``i`` (= the sum of token counts + for the experts owned by shard ``i``). + output_offsets : jnp.ndarray + Shape ``[num_expert_shards]``. ``output_offsets[i]`` is the row in + shard ``i``'s receive buffer where this shard's contribution should + land. Sender-side semantics, per :func:`jax.lax.ragged_all_to_all`. + recv_sizes : jnp.ndarray + Shape ``[num_expert_shards]``. ``recv_sizes[i]`` is the number of + tokens shard ``i`` sends to this shard. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert num_experts % num_expert_shards == 0, ( + f"num_experts={num_experts} must be divisible by num_expert_shards" + f"={num_expert_shards}" + ) + local_expert_size = num_experts // num_expert_shards + + # This shard's row of the gathered table, reshaped so axis 0 indexes the + # destination shard and axis 1 indexes its local experts. + local_tokens_per_expert = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_reshaped = local_tokens_per_expert.reshape( + num_expert_shards, local_expert_size + ) + + # send_sizes[i] = sum of token counts for shard i's experts in our buffer. + send_sizes = jnp.sum(local_reshaped, axis=1) + input_offsets = jnp.concatenate( + [ + jnp.array([0], dtype=send_sizes.dtype), + jnp.cumsum(send_sizes)[:-1], + ] + ) + + # recv_sizes[i] = how many tokens shard i sends to this shard, i.e. the + # sum across our local-expert columns of shard i's row. + local_expert_start = shard_id * local_expert_size + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + recv_sizes = jnp.sum(local_expert_columns, axis=1) + + # output_offsets uses sender-side semantics for ragged_all_to_all: + # output_offsets[j] = row in shard j's buffer where THIS shard's chunk + # should be placed. That's the cumulative sum (over source shards 0..j-1) + # of how many tokens those earlier source shards already sent to shard j. + sends_to_target = jnp.sum( + all_shards_tokens_per_expert.reshape( + num_expert_shards, num_expert_shards, local_expert_size + ), + axis=2, + ) # [src_shard, dst_shard] + zero_row = jnp.zeros((1, num_expert_shards), dtype=sends_to_target.dtype) + cumulated = jnp.cumsum( + jnp.concatenate([zero_row, sends_to_target], axis=0), + axis=0, + dtype=sends_to_target.dtype, + ) # [src_shard + 1, dst_shard]; row r = total sent by sources 0..r-1 + output_offsets = jax.lax.dynamic_slice( + cumulated, + start_indices=(shard_id, 0), + slice_sizes=(1, num_expert_shards), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +def compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Reverse-direction ragged_all_to_all parameters. + + Mirror of :func:`compute_ragged_all_to_all_params` for the **reverse** + EP shuffle that returns expert outputs to their source shards. The + sender / receiver roles are swapped: what we received in the forward + shuffle we now send back, and vice versa. + + Parameters and shapes are identical to + :func:`compute_ragged_all_to_all_params`. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert num_experts % num_expert_shards == 0, ( + f"num_experts={num_experts} must be divisible by num_expert_shards" + f"={num_expert_shards}" + ) + local_expert_size = num_experts // num_expert_shards + + local_expert_start = shard_id * local_expert_size + + # In reverse, what we received becomes what we send. send_sizes[i] is how + # many tokens we send back to source shard i (= what shard i originally + # sent us, summed across our local experts). + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + send_sizes = jnp.sum(local_expert_columns, axis=1) + input_offsets = jnp.concatenate( + [ + jnp.array([0], dtype=send_sizes.dtype), + jnp.cumsum(send_sizes)[:-1], + ] + ) + + # recv_sizes[i] = how many tokens we receive back from shard i (= what + # we originally sent to shard i in the forward). + local_tokens_per_expert = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_reshaped = local_tokens_per_expert.reshape( + num_expert_shards, local_expert_size + ) + recv_sizes = jnp.sum(local_reshaped, axis=1) + + # output_offsets: the reverse sends-to-target matrix is the transpose of + # the forward one (row i = what shard i sends in reverse = what shard i + # received in forward). Cumsum down source-shard axis, then index our row. + fwd_sends_to = jnp.sum( + all_shards_tokens_per_expert.reshape( + num_expert_shards, num_expert_shards, local_expert_size + ), + axis=2, + ) # forward: [src, dst] + rev_sends_to = jnp.transpose(fwd_sends_to) # reverse: [src, dst] + zero_row = jnp.zeros((1, num_expert_shards), dtype=rev_sends_to.dtype) + rev_cumulated = jnp.cumsum( + jnp.concatenate([zero_row, rev_sends_to], axis=0), + axis=0, + dtype=rev_sends_to.dtype, + ) + output_offsets = jax.lax.dynamic_slice( + rev_cumulated, + start_indices=(shard_id, 0), + slice_sizes=(1, num_expert_shards), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +# ----------------------------------------------------------------------------- +# Local permute / unpermute +# ----------------------------------------------------------------------------- +# +# After the forward ragged_all_to_all the receive buffer is laid out as +# ``[from_shard_0_chunk | from_shard_1_chunk | ... ]`` and within each chunk +# tokens are sorted by local-expert id. To feed ``grouped_dense`` we want +# ``[expert_0_block | expert_1_block | ... ]`` where each expert's block +# contains tokens from every source shard. ``local_permute_after_a2a`` +# performs that reorder; ``local_unpermute_before_a2a`` undoes it before the +# reverse ragged_all_to_all. +# +# Implementation uses :func:`sort_chunks_by_index`, which is Triton-backed +# (see ``transformer_engine.jax.triton_extensions.permutation``) and has a +# paired custom-VJP backward. There is no pure-JAX alternative here -- the +# global :func:`unfused_token_dispatch` / :func:`token_dispatch` choice is +# unaffected by this; only the (small) post-A2A chunk reorder uses Triton +# unconditionally. + + +def local_permute_after_a2a( + x_recv: jnp.ndarray, + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: + """Reorder tokens received via ragged_all_to_all so each local expert's + tokens are contiguous. + + This is the EP-side complement to the global :func:`token_dispatch` / + :func:`unfused_token_dispatch`. Internally uses + :func:`sort_chunks_by_index` (Triton-backed) for both the forward sort + and -- via :func:`local_unpermute_before_a2a` -- the inverse. + + Parameters + ---------- + x_recv : jnp.ndarray + Output of the forward ``ragged_all_to_all`` of shape + ``[buffer_size, hidden_size]``. Layout: source-shard major, then + local-expert id within each source chunk. + all_shards_tokens_per_expert : jnp.ndarray + Per-shard, per-expert token counts of shape + ``[num_expert_shards, num_experts]``. + shard_id : jnp.ndarray + Current EP shard index (typically a traced + :func:`jax.lax.axis_index`). + num_expert_shards : int + Static EP-axis size. + + Returns + ------- + sorted_x : jnp.ndarray + Tokens reordered into expert-major layout. Same shape as ``x_recv``. + local_group_sizes : jnp.ndarray + Per-local-expert token counts of shape ``[local_expert_size]``. + state : dict + Opaque state for :func:`local_unpermute_before_a2a`. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert num_experts % num_expert_shards == 0, ( + f"num_experts={num_experts} must be divisible by num_expert_shards" + f"={num_expert_shards}" + ) + local_expert_size = num_experts // num_expert_shards + local_expert_start = shard_id * local_expert_size + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + + # Flat sizes in source-major order, matching the receive buffer layout: + # [(s0,e0), (s0,e1), ..., (s1,e0), (s1,e1), ...] + split_sizes = local_expert_columns.reshape(-1) + + # Permutation that maps source-major -> expert-major: + # original index = s * E_local + e + # target index = e * num_shards + s + indices_matrix = jnp.arange( + num_expert_shards * local_expert_size, dtype=jnp.int32 + ).reshape(num_expert_shards, local_expert_size) + sorted_chunk_indices = indices_matrix.T.reshape(-1) + + sorted_x, _ = sort_chunks_by_index(x_recv, split_sizes, sorted_chunk_indices) + sorted_split_sizes = split_sizes[sorted_chunk_indices] + inverse_chunk_indices = jnp.argsort(sorted_chunk_indices) + local_group_sizes = jnp.sum(local_expert_columns, axis=0) + state = { + "sorted_split_sizes": sorted_split_sizes, + "inverse_chunk_indices": inverse_chunk_indices, + } + return sorted_x, local_group_sizes, state + + +def local_unpermute_before_a2a( + expert_outputs: jnp.ndarray, + state: dict, +) -> jnp.ndarray: + """Inverse of :func:`local_permute_after_a2a`. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the local expert FFN of shape ``[buffer_size, hidden_size]``, + in expert-major layout. + state : dict + Opaque state returned by :func:`local_permute_after_a2a`. + + Returns + ------- + unsorted_x : jnp.ndarray + Tokens reordered back into source-shard-major layout, ready for the + reverse ``ragged_all_to_all``. Same shape as ``expert_outputs``. + """ + out, _ = sort_chunks_by_index( + expert_outputs, + state["sorted_split_sizes"], + state["inverse_chunk_indices"], + ) + return out From 6f87629844f1fab722689681e81fc098475d011a Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 30 Apr 2026 14:08:55 -0700 Subject: [PATCH 05/18] fix test_distributed issues with unpopulated LogicallyPartition pytree and single device initial params in the MoEBlock. Tests should pass now Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 29 ++++++++++++++++++++++++- transformer_engine/jax/flax/moe.py | 2 +- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index 9d9e57140f..1c7b99cda4 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -112,7 +112,34 @@ def loss_fn(block, variables, x): with mesh, autocast(enabled=False, mesh_resource=MeshResource(fsdp_resource="fsdp")): with nn.logical_axis_rules(logical_axis_rules): - sharded_variables = sharded_block.init(init_key, inputs) + # ``MoEBlock`` registers params via ``with_logical_partitioning`` + # which only attaches LogicallyPartitioned metadata; the + # underlying jax.Array stays single-device unless ``init`` + # is run inside ``jax.jit`` with ``out_shardings``. Use the + # canonical Flax-Linen pattern (mirrors + # ``examples/jax/encoder/test_model_parallel_encoder.py``): + # 1. ``jax.eval_shape`` to trace abstract variables (keeps + # the LogicallyPartitioned wrappers; only the inner + # arrays become ShapeDtypeStruct); + # 2. ``nn.get_partition_spec`` to extract a tree of logical + # PartitionSpecs from those wrappers (treats + # LogicallyPartitioned as a leaf); + # 3. ``nn.logical_to_mesh_sharding`` to resolve those + # logical specs to NamedShardings via the active rules; + # 4. ``jax.jit(init, out_shardings=...)`` to actually + # place the params on-device with those shardings. + abstract_variables = jax.eval_shape( + sharded_block.init, init_key, inputs + ) + logical_partition_spec = nn.get_partition_spec( + abstract_variables + ) + out_shardings = nn.logical_to_mesh_sharding( + logical_partition_spec, mesh, logical_axis_rules + ) + sharded_variables = jax.jit( + sharded_block.init, out_shardings=out_shardings + )(init_key, inputs) (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( sharded_block, sharded_variables, inputs diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 690d804e38..050cbe84d0 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -937,7 +937,7 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: return shard_map( _a2a_fn, mesh=mesh, - in_specs=in_specs, + in_specs=(in_specs,), out_specs=(P(ep_axis, None, None), P()), check_rep=False, )(captured) From 6aeb491fa86a3afa33bb327f5d90f3a26a456e3d Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 5 May 2026 14:45:36 -0700 Subject: [PATCH 06/18] add option to choose weight fsdp sharding axis Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 10 +++- transformer_engine/jax/flax/moe.py | 78 ++++++++++++++++++++----- 2 files changed, 71 insertions(+), 17 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index 1c7b99cda4..3cd902aa88 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -34,7 +34,9 @@ def _inject_moe(request): DTYPE = jnp.bfloat16 -BATCH_SIZE = 2 +# Must be divisible by ep*fsdp = 4 so the batch dim can be sharded over +# the full ('ep','fsdp') axis tuple under Experiment 3. +BATCH_SIZE = 4 SEQUENCE_LENGTH = 16 HIDDEN_SIZE = 64 INTERMEDIATE_SIZE = 128 @@ -103,8 +105,14 @@ def loss_fn(block, variables, x): ("batch", "fsdp"), ("embed", "fsdp"), ) + # ``data_parallelism_axes=("fsdp",)`` opts in to the true-FSDP + # behavior: the ``shard_map``'s in_specs/out_specs become + # ``P(("ep","fsdp"), None, None)`` for the batch dim, so each + # device owns ``B/(ep*fsdp)`` unique tokens (no redundant compute + # across fsdp peers within an ep group). sharded_block = MoEBlock( expert_parallelism_axis="ep", + data_parallelism_axes=("fsdp",), mesh=mesh, input_axes=("batch", None, None), **base_kwargs, diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 050cbe84d0..bfa00d3827 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -179,6 +179,18 @@ class MoEBlock(TransformerEngineBase): ragged-all-to-all EP strategy. When ``None`` (default), no ``shard_map`` wrapper is used; each TE primitive's ``custom_partitioning`` rule handles DP / FSDP / TP automatically. + data_parallelism_axes : tuple[str, ...] + Additional mesh axes that the input *batch* dim is sharded over + IN ADDITION to ``expert_parallelism_axis``. Setting this to e.g. + ``("fsdp",)`` makes the ``shard_map`` ``in_specs`` for the batch + dim become ``P(("ep", "fsdp"), None, None)`` -- giving each + device a unique slice of the batch (true FSDP) instead of + replicating the per-ep-shard batch across fsdp peers. + Routing is unaffected: ``axis_index("ep")`` still controls the + ragged-all-to-all; the extra fsdp peers within an ep group send + and receive their own batch slices in lockstep. Default ``()`` + preserves legacy ZeRO-1-style behavior (activations replicated + on fsdp within an ep group). tensor_parallelism_axis : Optional[str] Mesh axis for tensor parallelism on the FFN intermediate dim. When set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed @@ -227,6 +239,7 @@ class MoEBlock(TransformerEngineBase): # Parallelism expert_parallelism_axis: Optional[str] = None + data_parallelism_axes: Tuple[str, ...] = () tensor_parallelism_axis: Optional[str] = None # ``jax.sharding.Mesh`` to use when ``expert_parallelism_axis`` is set. # Required for the ``shard_map`` wrapper; ignored otherwise. @@ -751,16 +764,42 @@ def _forward_a2a_ep( ) num_experts_local = self.num_experts // num_ep - # Pre-compute the worst-case A2A receive buffer size (compile-time - # constant). Each shard contributes ``b_l*S*topk = B*S*topk/num_ep`` - # token-expert pairs across all experts; the worst case for one - # shard is "every global pair lands on this shard's local - # experts" -- ``num_ep * (B*S*topk/num_ep) = B*S*topk`` rows. JIT - # needs this static, so we use the global ``batch_size`` from the - # outer scope (sharded layouts don't change it). + # Compose the BATCH sharding axis tuple. ``ep`` is always part of + # the batch axis (so ragged_all_to_all has data to route); any + # ``data_parallelism_axes`` are added on top so the per-device + # batch slice is genuinely unique (true FSDP / DP). + # Examples: + # data_parallelism_axes=() -> P('ep', None, None) + # data_parallelism_axes=('fsdp',) -> P(('ep','fsdp'), None, None) + # data_parallelism_axes=('fsdp','data') -> P(('ep','fsdp','data'), ...) + for ax in self.data_parallelism_axes: + if ax not in mesh.shape: + raise ValueError( + f"data_parallelism_axes contains {ax!r} but mesh has" + f" axes {tuple(mesh.shape.keys())}" + ) + if len(self.data_parallelism_axes) == 0: + batch_pspec_axis: Any = ep_axis + else: + batch_pspec_axis = (ep_axis, *self.data_parallelism_axes) + # The size by which the per-device batch is divided BEYOND ep. + # Used to tighten the worst-case ragged_all_to_all recv buffer: + # at most ``num_ep`` peers each send their entire local + # ``B/(num_ep*dp_size)*S*topk`` token-expert pairs, so the worst + # recv per device is ``num_ep * B/(num_ep*dp_size)*S*topk + # = B/dp_size * S * topk``. + dp_size = 1 + for ax in self.data_parallelism_axes: + dp_size *= mesh.shape[ax] + global_batch_size, sequence_length, _hidden = inputs.shape topk = self.num_experts_per_tok - recv_buffer_rows = global_batch_size * sequence_length * topk + if global_batch_size % dp_size != 0: + raise ValueError( + f"batch={global_batch_size} not divisible by" + f" prod(data_parallelism_axes)={dp_size}" + ) + recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk # Pack everything that crosses the shard_map boundary into a dict # pytree. shard_map fully supports pytrees: ``in_specs`` must @@ -774,8 +813,8 @@ def _forward_a2a_ep( "wo": params["wo"], } in_specs: dict = { - "inputs": P(ep_axis, None, None), - "gate_logits": P(ep_axis, None, None), + "inputs": P(batch_pspec_axis, None, None), + "gate_logits": P(batch_pspec_axis, None, None), "wi_0": P(ep_axis, None, None), "wi_1": P(ep_axis, None, None), "wo": P(ep_axis, None, None), @@ -817,13 +856,20 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable # (the sum_t * tokens product is data-dependent across - # shards). Cheapest fix: gather logits along the EP axis and - # run the aux-loss kernel on the global tensor. The aux - # branch has no data dependency on the main routing path so - # XLA can overlap the two on the GPU. + # shards). Cheapest fix: gather logits along ALL batch + # axes (ep + any DP axes) so the kernel sees the full + # token set. The aux branch has no data dependency on the + # main routing path so XLA can overlap the two on the GPU. if self.aux_loss_coeff > 0.0: + # ``axis_name`` accepts a tuple ⇒ a single all_gather + # over the cartesian product of axes; XLA may lower + # this to one multi-axis collective or split it. + if len(self.data_parallelism_axes) == 0: + aux_gather_axes: Any = ep_axis + else: + aux_gather_axes = (ep_axis, *self.data_parallelism_axes) global_logits_2d = jax.lax.all_gather( - logits_2d, axis_name=ep_axis, axis=0, tiled=True + logits_2d, axis_name=aux_gather_axes, axis=0, tiled=True ) aux_loss = self._compute_aux_loss(global_logits_2d) else: @@ -938,6 +984,6 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: _a2a_fn, mesh=mesh, in_specs=(in_specs,), - out_specs=(P(ep_axis, None, None), P()), + out_specs=(P(batch_pspec_axis, None, None), P()), check_rep=False, )(captured) From 25e1eb80614666678ad6c288286c0d3ce4c943b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 May 2026 23:45:05 +0000 Subject: [PATCH 07/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_moe_block.py | 24 ++++------ tests/jax/test_moe_block.py | 29 ++++++------ transformer_engine/jax/flax/moe.py | 62 +++++++++---------------- transformer_engine/jax/permutation.py | 62 +++++++++---------------- 4 files changed, 65 insertions(+), 112 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index 3cd902aa88..b50cec686b 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -45,9 +45,7 @@ def _inject_moe(request): def _make_inputs(key: jax.Array) -> jax.Array: - return jax.random.normal( - key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE - ) + return jax.random.normal(key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE) def _unwrap_partitioned(x): @@ -136,23 +134,17 @@ def loss_fn(block, variables, x): # logical specs to NamedShardings via the active rules; # 4. ``jax.jit(init, out_shardings=...)`` to actually # place the params on-device with those shardings. - abstract_variables = jax.eval_shape( - sharded_block.init, init_key, inputs - ) - logical_partition_spec = nn.get_partition_spec( - abstract_variables - ) + abstract_variables = jax.eval_shape(sharded_block.init, init_key, inputs) + logical_partition_spec = nn.get_partition_spec(abstract_variables) out_shardings = nn.logical_to_mesh_sharding( logical_partition_spec, mesh, logical_axis_rules ) - sharded_variables = jax.jit( - sharded_block.init, out_shardings=out_shardings - )(init_key, inputs) - (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( - jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( - sharded_block, sharded_variables, inputs - ) + sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)( + init_key, inputs ) + (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = jax.value_and_grad( + loss_fn, argnums=1, has_aux=True + )(sharded_block, sharded_variables, inputs) wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index 39a6bfd592..743e4aba69 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -64,9 +64,7 @@ def _inject_moe(request): def _make_inputs( key: jax.Array, batch_size: int = BATCH_SIZE, sequence_length: int = SEQUENCE_LENGTH ) -> jax.Array: - return jax.random.normal( - key, (batch_size, sequence_length, HIDDEN_SIZE), dtype=DTYPE - ) + return jax.random.normal(key, (batch_size, sequence_length, HIDDEN_SIZE), dtype=DTYPE) def _init_and_apply( @@ -108,9 +106,9 @@ def test_forward_shape_and_finite(self, permutation_backend): inputs = _make_inputs(data_key) _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) - assert output.shape == inputs.shape, ( - f"Unexpected output shape {output.shape} for backend {permutation_backend}" - ) + assert ( + output.shape == inputs.shape + ), f"Unexpected output shape {output.shape} for backend {permutation_backend}" assert output.dtype == inputs.dtype assert jnp.all(jnp.isfinite(output)), "Output contains NaN/Inf" assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0" @@ -171,20 +169,19 @@ def loss_fn(block, variables, inputs): output, _ = block.apply(variables, inputs) return jnp.mean(output.astype(jnp.float32) ** 2), output - (loss_pj, out_pj), grads_pj = jax.value_and_grad( - loss_fn, argnums=1, has_aux=True - )(pure_block, variables, inputs) - (loss_tr, out_tr), grads_tr = jax.value_and_grad( - loss_fn, argnums=1, has_aux=True - )(triton_block, variables, inputs) + (loss_pj, out_pj), grads_pj = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( + pure_block, variables, inputs + ) + (loss_tr, out_tr), grads_tr = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( + triton_block, variables, inputs + ) # BF16 tolerances: outputs come out of the grouped-GEMM + weighted # sum so they accumulate error; we use ~2 ULPs worth of slack. atol_out, rtol_out = 5e-2, 5e-2 - assert jnp.allclose(out_pj, out_tr, atol=atol_out, rtol=rtol_out), ( - f"Forward outputs differ across backends: max diff" - f" {jnp.max(jnp.abs(out_pj - out_tr))}" - ) + assert jnp.allclose( + out_pj, out_tr, atol=atol_out, rtol=rtol_out + ), f"Forward outputs differ across backends: max diff {jnp.max(jnp.abs(out_pj - out_tr))}" assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index bfa00d3827..853d22679f 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -357,9 +357,9 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: Scalar auxiliary load-balancing loss when ``aux_loss_coeff > 0``, else ``None``. """ - assert inputs.ndim == 3, ( - f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" - ) + assert ( + inputs.ndim == 3 + ), f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) _, _, hidden_size = inputs.shape @@ -448,9 +448,7 @@ def _compute_aux_loss( score_function=self.score_function, compute_aux_scores=True, ) - aux_tokens_per_expert = jnp.sum( - aux_routing_map.astype(jnp.int32), axis=0 - ) + aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0) return fused_moe_aux_loss( aux_scores.astype(jnp.float32), aux_tokens_per_expert, @@ -665,9 +663,7 @@ def _global_combine( pad_offsets=perm_result["pad_offsets"], ) hidden_size = out_2d.shape[-1] - return out_2d.reshape(batch_size, sequence_length, hidden_size).astype( - self.dtype - ) + return out_2d.reshape(batch_size, sequence_length, hidden_size).astype(self.dtype) # ------------------------------------------------------------------ # No-EP forward @@ -690,9 +686,7 @@ def _forward_no_ep( inputs_2d = inputs.reshape(-1, hidden_size) logits_2d = gate_logits.reshape(-1, self.num_experts) - sparse_probs, routing_map = self._route_topk( - logits_2d, params.get("expert_bias") - ) + sparse_probs, routing_map = self._route_topk(logits_2d, params.get("expert_bias")) aux_loss = self._compute_aux_loss(logits_2d) perm = self._global_permute(inputs_2d, sparse_probs, routing_map) expert_outputs = self._expert_ffn( @@ -701,9 +695,7 @@ def _forward_no_ep( params, n_groups=self.num_experts, ) - output = self._global_combine( - expert_outputs, perm, batch_size, sequence_length - ) + output = self._global_combine(expert_outputs, perm, batch_size, sequence_length) if self.tensor_parallelism_axis is not None: output = jax.lax.psum_scatter( @@ -758,10 +750,9 @@ def _forward_a2a_ep( ) mesh = self.mesh num_ep = mesh.shape[ep_axis] - assert self.num_experts % num_ep == 0, ( - f"num_experts={self.num_experts} must be divisible by EP" - f" size={num_ep}" - ) + assert ( + self.num_experts % num_ep == 0 + ), f"num_experts={self.num_experts} must be divisible by EP size={num_ep}" num_experts_local = self.num_experts // num_ep # Compose the BATCH sharding axis tuple. ``ep`` is always part of @@ -796,8 +787,7 @@ def _forward_a2a_ep( topk = self.num_experts_per_tok if global_batch_size % dp_size != 0: raise ValueError( - f"batch={global_batch_size} not divisible by" - f" prod(data_parallelism_axes)={dp_size}" + f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}" ) recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk @@ -848,9 +838,7 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: ) else: full_expert_bias = None - sparse_probs, routing_map = self._route_topk( - logits_2d, full_expert_bias - ) + sparse_probs, routing_map = self._route_topk(logits_2d, full_expert_bias) # aux_loss must see the global token batch and the global # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( @@ -905,13 +893,11 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: ) # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) - sorted_x, local_group_sizes, local_perm_state = ( - local_permute_after_a2a( - x_recv, - all_shards_tokens_per_expert, - shard_id, - num_ep, - ) + sorted_x, local_group_sizes, local_perm_state = local_permute_after_a2a( + x_recv, + all_shards_tokens_per_expert, + shard_id, + num_ep, ) # -- Stage 5: per-expert FFN (E_local groups) -- @@ -932,15 +918,11 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: ) # -- Stage 6: invert local permute -- - x_send_back = local_unpermute_before_a2a( - expert_outputs, local_perm_state - ) + x_send_back = local_unpermute_before_a2a(expert_outputs, local_perm_state) # -- Stage 7: reverse ragged_all_to_all over EP -- - in_off_r, send_sz_r, out_off_r, recv_sz_r = ( - compute_reverse_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep - ) + in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep ) send_back_buf = jnp.zeros_like(perm["sorted_inputs"]) y_back = jax.lax.ragged_all_to_all( @@ -954,9 +936,7 @@ def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: ) # -- Stage 8: invert global permute, weighted sum over top-k -- - output = self._global_combine( - y_back, perm, batch_size=local_b, sequence_length=local_s - ) + output = self._global_combine(y_back, perm, batch_size=local_b, sequence_length=local_s) if self.tensor_parallelism_axis is not None: output = jax.lax.psum_scatter( diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index f4599a7b8f..cad31faaf2 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -700,10 +700,9 @@ def _sort_chunks_by_index_bwd_rule( @jax.custom_vjp def _sort_activations(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: """Sort ``inputs`` along the leading dim by ``sort_indices``.""" - assert inputs.shape[0] == sort_indices.shape[0], ( - f"inputs.shape[0]={inputs.shape[0]} must match" - f" sort_indices.shape[0]={sort_indices.shape[0]}" - ) + assert ( + inputs.shape[0] == sort_indices.shape[0] + ), f"inputs.shape[0]={inputs.shape[0]} must match sort_indices.shape[0]={sort_indices.shape[0]}" with jax.named_scope("unfused_sort_activations"): return inputs[sort_indices, ...] @@ -714,9 +713,7 @@ def _sort_activations_fwd( return _sort_activations(inputs, sort_indices), sort_indices -def _sort_activations_bwd( - residuals: jax.Array, grads: jax.Array -) -> Tuple[jax.Array, None]: +def _sort_activations_bwd(residuals: jax.Array, grads: jax.Array) -> Tuple[jax.Array, None]: sort_indices = residuals # Inverse permutation: gather-by-argsort undoes the forward gather. return _sort_activations(grads, jnp.argsort(sort_indices)), None @@ -838,12 +835,10 @@ def unfused_token_dispatch( # ``(align - count % align) % align`` gives 0 (not ``align``) when # already aligned, so we never exceed the per-expert slot capacity of # ``align_size - 1``. - token_count_per_expert = jnp.bincount( - flatten_selected_experts, length=num_experts - ) + token_count_per_expert = jnp.bincount(flatten_selected_experts, length=num_experts) padding_tokens_required_per_expert = ( - (align_size - (token_count_per_expert % align_size)) % align_size - ) + align_size - (token_count_per_expert % align_size) + ) % align_size # Build a static-size padding buffer of shape # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot @@ -869,9 +864,7 @@ def unfused_token_dispatch( ) if roll_to_expert_id is not None: - flatten_selected_experts = ( - flatten_selected_experts - roll_to_expert_id - ) % num_experts + flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % num_experts sorted_selected_experts = jnp.argsort(flatten_selected_experts) @@ -900,9 +893,7 @@ def unfused_token_dispatch( padding_size = max_total_padding_size else: if roll_to_expert_id is not None: - flatten_selected_experts = ( - flatten_selected_experts - roll_to_expert_id - ) % num_experts + flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % num_experts sorted_selected_experts = jnp.argsort(flatten_selected_experts) @@ -1067,10 +1058,9 @@ def compute_ragged_all_to_all_params( tokens shard ``i`` sends to this shard. """ num_experts = all_shards_tokens_per_expert.shape[1] - assert num_experts % num_expert_shards == 0, ( - f"num_experts={num_experts} must be divisible by num_expert_shards" - f"={num_expert_shards}" - ) + assert ( + num_experts % num_expert_shards == 0 + ), f"num_experts={num_experts} must be divisible by num_expert_shards={num_expert_shards}" local_expert_size = num_experts // num_expert_shards # This shard's row of the gathered table, reshaped so axis 0 indexes the @@ -1080,9 +1070,7 @@ def compute_ragged_all_to_all_params( start_indices=(shard_id, 0), slice_sizes=(1, num_experts), ).squeeze(0) - local_reshaped = local_tokens_per_expert.reshape( - num_expert_shards, local_expert_size - ) + local_reshaped = local_tokens_per_expert.reshape(num_expert_shards, local_expert_size) # send_sizes[i] = sum of token counts for shard i's experts in our buffer. send_sizes = jnp.sum(local_reshaped, axis=1) @@ -1144,10 +1132,9 @@ def compute_reverse_ragged_all_to_all_params( :func:`compute_ragged_all_to_all_params`. """ num_experts = all_shards_tokens_per_expert.shape[1] - assert num_experts % num_expert_shards == 0, ( - f"num_experts={num_experts} must be divisible by num_expert_shards" - f"={num_expert_shards}" - ) + assert ( + num_experts % num_expert_shards == 0 + ), f"num_experts={num_experts} must be divisible by num_expert_shards={num_expert_shards}" local_expert_size = num_experts // num_expert_shards local_expert_start = shard_id * local_expert_size @@ -1175,9 +1162,7 @@ def compute_reverse_ragged_all_to_all_params( start_indices=(shard_id, 0), slice_sizes=(1, num_experts), ).squeeze(0) - local_reshaped = local_tokens_per_expert.reshape( - num_expert_shards, local_expert_size - ) + local_reshaped = local_tokens_per_expert.reshape(num_expert_shards, local_expert_size) recv_sizes = jnp.sum(local_reshaped, axis=1) # output_offsets: the reverse sends-to-target matrix is the transpose of @@ -1264,10 +1249,9 @@ def local_permute_after_a2a( Opaque state for :func:`local_unpermute_before_a2a`. """ num_experts = all_shards_tokens_per_expert.shape[1] - assert num_experts % num_expert_shards == 0, ( - f"num_experts={num_experts} must be divisible by num_expert_shards" - f"={num_expert_shards}" - ) + assert ( + num_experts % num_expert_shards == 0 + ), f"num_experts={num_experts} must be divisible by num_expert_shards={num_expert_shards}" local_expert_size = num_experts // num_expert_shards local_expert_start = shard_id * local_expert_size local_expert_columns = jax.lax.dynamic_slice( @@ -1283,9 +1267,9 @@ def local_permute_after_a2a( # Permutation that maps source-major -> expert-major: # original index = s * E_local + e # target index = e * num_shards + s - indices_matrix = jnp.arange( - num_expert_shards * local_expert_size, dtype=jnp.int32 - ).reshape(num_expert_shards, local_expert_size) + indices_matrix = jnp.arange(num_expert_shards * local_expert_size, dtype=jnp.int32).reshape( + num_expert_shards, local_expert_size + ) sorted_chunk_indices = indices_matrix.T.reshape(-1) sorted_x, _ = sort_chunks_by_index(x_recv, split_sizes, sorted_chunk_indices) From d7fef5a30bf4d6b05d3c8b84533a1ce0b6cd92f7 Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 5 May 2026 17:46:33 -0700 Subject: [PATCH 08/18] address greptile comments Signed-off-by: tdophung --- transformer_engine/jax/flax/moe.py | 13 +++++++++---- transformer_engine/jax/permutation.py | 3 ++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 853d22679f..6f3986e9b3 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -65,7 +65,7 @@ from ..dense import grouped_dense from ..permutation import ( - _routing_map_to_selected_experts, + routing_map_to_selected_experts, compute_ragged_all_to_all_params, compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, @@ -492,7 +492,7 @@ def _global_permute( topk = self.num_experts_per_tok if self.permutation_backend == "pure_jax": - selected_experts, routing_weights = _routing_map_to_selected_experts( + selected_experts, routing_weights = routing_map_to_selected_experts( sparse_probs, routing_map, topk ) sorted_inputs, perm_state, group_sizes = unfused_token_dispatch( @@ -715,7 +715,7 @@ def _forward_a2a_ep( inputs: jnp.ndarray, gate_logits: jnp.ndarray, params: dict, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """Wrap the body in a ``shard_map`` that runs a forward ``ragged_all_to_all`` (A2A / A2Av) around the FFN. @@ -785,7 +785,12 @@ def _forward_a2a_ep( global_batch_size, sequence_length, _hidden = inputs.shape topk = self.num_experts_per_tok - if global_batch_size % dp_size != 0: + # The shard_map's ``in_specs=P((ep, *dp_axes), ...)`` requires the + # batch dim to be divisible by ``num_ep * dp_size``; check upfront + # here for a clearer error than the one shard_map would raise at + # trace time. + batch_divisor = num_ep * dp_size + if global_batch_size % batch_divisor != 0: raise ValueError( f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}" ) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index cad31faaf2..ba271b6b86 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -57,6 +57,7 @@ "compute_reverse_ragged_all_to_all_params", "local_permute_after_a2a", "local_unpermute_before_a2a", + "routing_map_to_selected_experts", ] @@ -722,7 +723,7 @@ def _sort_activations_bwd(residuals: jax.Array, grads: jax.Array) -> Tuple[jax.A _sort_activations.defvjp(_sort_activations_fwd, _sort_activations_bwd) -def _routing_map_to_selected_experts( +def routing_map_to_selected_experts( sparse_probs: jnp.ndarray, routing_map: jnp.ndarray, topk: int, From 3a517083836ffcab60a6d037334cfde5b1b75f4d Mon Sep 17 00:00:00 2001 From: JAX Toolbox Date: Thu, 7 May 2026 15:18:44 -0700 Subject: [PATCH 09/18] address jeremys comments + relax the sum(group_size) <= dim_m constraint in C++ files, make FP8 works. Tested with current scaling Signed-off-by: JAX Toolbox --- tests/jax/test_distributed_moe_block.py | 35 +- tests/jax/test_moe_block.py | 190 ++++- .../common/util/multi_stream.cpp | 66 +- transformer_engine/jax/cpp_extensions/gemm.py | 9 +- .../jax/csrc/extensions/gemm.cpp | 14 +- .../jax/csrc/extensions/quantization.cpp | 14 +- transformer_engine/jax/flax/moe.py | 731 +++++++++++------- transformer_engine/jax/permutation.py | 60 +- 8 files changed, 751 insertions(+), 368 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index b50cec686b..bb15ed8c95 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -74,18 +74,29 @@ def test_ep2_fsdp2_matches_single_device(self, permutation_backend): single_block = MoEBlock(**base_kwargs) - def loss_fn(block, variables, x): - output, aux_loss = block.apply(variables, x) - loss = jnp.mean(output.astype(jnp.float32) ** 2) - if aux_loss is not None: - loss = loss + aux_loss.astype(jnp.float32) - return loss, (output, aux_loss) + def _make_loss_and_grad(block): + """Build a jitted ``value_and_grad`` over ``(variables, x)``. + + Capturing ``block`` in a closure (so it isn't a jit input) + sidesteps having to mark it as static -- Flax modules are + registered pytrees but they carry Python-level config that + jit treats as part of the trace. + """ + + def loss_fn(variables, x): + output, aux_loss = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if aux_loss is not None: + loss = loss + aux_loss.astype(jnp.float32) + return loss, (output, aux_loss) + + return jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) with autocast(enabled=False, mesh_resource=MeshResource()): single_variables = single_block.init(init_key, inputs) - (single_loss, (single_output, single_aux)), single_grads = jax.value_and_grad( - loss_fn, argnums=1, has_aux=True - )(single_block, single_variables, inputs) + (single_loss, (single_output, single_aux)), single_grads = _make_loss_and_grad( + single_block + )(single_variables, inputs) devices = np.asarray(jax.devices()[:4]).reshape(2, 2) mesh = Mesh(devices, ("ep", "fsdp")) @@ -142,9 +153,9 @@ def loss_fn(block, variables, x): sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)( init_key, inputs ) - (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = jax.value_and_grad( - loss_fn, argnums=1, has_aux=True - )(sharded_block, sharded_variables, inputs) + (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( + _make_loss_and_grad(sharded_block)(sharded_variables, inputs) + ) wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index 743e4aba69..ed5e0529c5 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -114,7 +114,7 @@ def test_forward_shape_and_finite(self, permutation_backend): assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0" @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) - def test_backward_grad(self, permutation_backend): + def test_backward_grad_is_finite_and_nonzero(self, permutation_backend): key = jax.random.PRNGKey(1) init_key, data_key = jax.random.split(key) @@ -184,12 +184,24 @@ def loss_fn(block, variables, inputs): ), f"Forward outputs differ across backends: max diff {jnp.max(jnp.abs(out_pj - out_tr))}" assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) + # The two backends share the routing path (same fused top-k) and + # the same expert FFN; the only difference is the order of the + # gather + scatter ops in dispatch/combine. Under bf16 with these + # small shapes, observed grad max-abs-diff is on the order of a + # few-units-of-bf16-eps (~1e-2). 5e-2 / 5e-2 leaves headroom for + # accumulation jitter without masking real divergence. If this + # tightens too far on a particular GPU, print + # ``jnp.max(jnp.abs(g_pj - g_tr))`` from the failing assertion + # and bump to the next safe value with a comment recording the + # measured gap. + atol_grad, rtol_grad = 5e-2, 5e-2 for name in ("gate_kernel", "wi_0", "wi_1", "wo"): g_pj = _unwrap_partitioned(grads_pj["params"][name]) g_tr = _unwrap_partitioned(grads_tr["params"][name]) - assert jnp.allclose(g_pj, g_tr, atol=1e-1, rtol=1e-1), ( + assert jnp.allclose(g_pj, g_tr, atol=atol_grad, rtol=rtol_grad), ( f"Gradient for {name} differs across backends: max diff" - f" {jnp.max(jnp.abs(g_pj - g_tr))}" + f" {jnp.max(jnp.abs(g_pj - g_tr))} (atol={atol_grad}," + f" rtol={rtol_grad})" ) @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) @@ -215,6 +227,134 @@ def test_aux_loss_returned(self, permutation_backend): # With uniform-ish routing the loss should be small-positive, not huge. assert jnp.abs(aux_loss) < 1e2 + def test_aux_loss_uses_real_routing_under_group_topk(self): + """Regression test for PR #2912 review (greptile P1). + + Under DeepSeek-style ``num_groups`` / ``group_topk`` routing, + the auxiliary load-balancing loss must be computed using the + per-expert token counts from the *real* routing_map (post + grouping), not from the clean top-k that the + ``compute_aux_scores=True`` kernel returns. Otherwise the aux + objective trains against the wrong distribution. + + We compute three values: + * ``corrected_ref`` -- ``fused_moe_aux_loss(aux_scores, + tokens_from_real_routing_map, ...)`` (what the block + should produce after the fix). + * ``buggy_ref`` -- ``fused_moe_aux_loss(aux_scores, + tokens_from_aux_routing_map, ...)`` (what the block used + to produce before the fix). + * ``block_aux_loss`` -- what the block actually produces. + + Block must match the corrected reference. We also assert that + the corrected and buggy references differ for this config so + the test is not vacuously satisfied by them coinciding. + """ + from transformer_engine.jax.router import ( + fused_moe_aux_loss, + fused_topk_with_score_function, + ) + + key = jax.random.PRNGKey(7) + init_key, data_key = jax.random.split(key) + + # Pick a config that *reliably* exercises grouped-vs-clean + # divergence: with ``group_topk=1`` only ONE group's experts + # can be selected by grouped routing, so the routing diverges + # from a plain top-k whenever the global top-K experts are + # spread across multiple groups (which is almost always the + # case for random init + ``num_experts_per_tok > 1``). + num_groups = 2 + group_topk = 1 + aux_loss_coeff = 1e-2 + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend="pure_jax", + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + aux_loss_coeff=aux_loss_coeff, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + _output, block_aux_loss = block.apply(variables, inputs) + + assert block_aux_loss is not None + + # Reproduce the gating GEMM and routing externally so we can + # build the references against the same logits the block sees. + gate_kernel = _unwrap_partitioned(variables["params"]["gate_kernel"]) + gate_kernel = gate_kernel.astype(inputs.dtype) + logits = jnp.einsum("bsh,he->bse", inputs, gate_kernel) + logits_2d = logits.reshape(-1, NUM_EXPERTS) + + # Real routing (with grouping). This is what _route_topk + # would produce inside the block. + _, real_routing_map = fused_topk_with_score_function( + logits_2d, + topk=NUM_EXPERTS_PER_TOK, + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + ) + real_tokens = jnp.sum(real_routing_map.astype(jnp.int32), axis=0) + + # Aux scores + the (clean topk) aux_routing_map that the old + # buggy code used for tokens_per_expert. + aux_scores, aux_routing_map = fused_topk_with_score_function( + logits_2d.astype(jnp.float32), + topk=NUM_EXPERTS_PER_TOK, + score_function="sigmoid", + compute_aux_scores=True, + ) + buggy_tokens = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0) + + corrected_ref = fused_moe_aux_loss( + aux_scores.astype(jnp.float32), + real_tokens, + topk=NUM_EXPERTS_PER_TOK, + coeff=aux_loss_coeff, + ) + buggy_ref = fused_moe_aux_loss( + aux_scores.astype(jnp.float32), + buggy_tokens, + topk=NUM_EXPERTS_PER_TOK, + coeff=aux_loss_coeff, + ) + + # Sanity: the test config must actually exercise the bug + # (otherwise both references coincide and the assertion below + # would silently pass even with the old code). + assert not jnp.allclose(real_tokens, buggy_tokens), ( + "Test config does not exercise grouped-topk vs clean-topk" + " divergence; pick a config where they differ" + ) + + assert jnp.allclose(block_aux_loss, corrected_ref, atol=1e-5, rtol=1e-5), ( + f"Block aux_loss {block_aux_loss} does not match" + f" real-routing reference {corrected_ref}" + ) + # The corrected and buggy refs can be numerically close + # (only the mis-routed tokens contribute to the difference), + # so assert that the block is *strictly closer* to the + # corrected ref than to the buggy one. This catches the + # regression robustly even when the absolute gap between + # corrected_ref and buggy_ref is sub-tolerance. + diff_to_corrected = jnp.abs(block_aux_loss - corrected_ref) + diff_to_buggy = jnp.abs(block_aux_loss - buggy_ref) + gap = jnp.abs(corrected_ref - buggy_ref) + assert diff_to_corrected < diff_to_buggy, ( + f"Block aux_loss {block_aux_loss} is closer to the *old" + f" buggy* reference ({buggy_ref}, diff={diff_to_buggy})" + f" than to the corrected reference ({corrected_ref}," + f" diff={diff_to_corrected}); the regression has" + f" reappeared. corrected-buggy gap = {gap}" + ) + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) def test_group_topk_deepseek(self, permutation_backend): """Exercise DeepSeek-style grouped top-k routing.""" @@ -240,28 +380,26 @@ def test_group_topk_deepseek(self, permutation_backend): assert output.shape == inputs.shape assert jnp.all(jnp.isfinite(output)) - @pytest.mark.xfail( - reason=( - "TE grouped_dense FFI asserts sum(group_sizes) == M at " - "transformer_engine/jax/csrc/extensions/gemm.cpp:1029. With " - "align_size > 0 both backends produce a buffer where M >= " - "sum(group_sizes) (the slack is structural padding for JIT). " - "The kernel itself iterates over per-expert m_i from " - "group_sizes via nvte_multi_tensor_gemm and never reads past " - "sum(group_sizes), so relaxing that assertion to " - "`m >= sum_group_sizes` is the cleanest fix. The MoE block " - "deliberately does not fold the gap into a single expert " - "(that would create per-shard load imbalance under EP). " - "Re-enable once the FFI check is relaxed." - ), - strict=False, - ) - def test_align_size_equivalence_pure_jax(self): + def test_align_size_equivalence_pure_jax(self, monkeypatch): """For the pure-JAX backend, ``align_size > 0`` must not change the numerical output of the forward pass: padding tokens contribute zero to every expert GEMM output (their input rows are zeros) and are stripped before the weighted sum. + + Why the env knob: the V1 TE grouped GEMM FFI asserts + ``sum(group_sizes) == M`` at + ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``. With + ``align_size > 0`` the pure-JAX backend produces a buffer where + ``M >= sum(group_sizes)`` (the slack is structural padding for + JIT). The V2 grouped GEMM relaxes that assertion to + ``M >= sum(group_sizes)`` and is selected when + ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1``. If V2 isn't supported on + this hardware / for this dtype, the dispatch raises a + ``RuntimeError`` whose message is matched here so the test + ``skip``-s instead of failing. """ + monkeypatch.setenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "1") + key = jax.random.PRNGKey(5) init_key, data_key = jax.random.split(key) @@ -275,10 +413,16 @@ def test_align_size_equivalence_pure_jax(self): block_no_pad = MoEBlock(align_size=0, **base_kwargs) block_pad = MoEBlock(align_size=16, **base_kwargs) inputs = _make_inputs(data_key) - variables = block_no_pad.init(init_key, inputs) - out_no_pad, _ = block_no_pad.apply(variables, inputs) - out_pad, _ = block_pad.apply(variables, inputs) + try: + variables = block_no_pad.init(init_key, inputs) + out_no_pad, _ = block_no_pad.apply(variables, inputs) + out_pad, _ = block_pad.apply(variables, inputs) + except RuntimeError as exc: + if "V2 grouped GEMM is not supported" in str(exc): + pytest.skip(f"V2 grouped GEMM unavailable on this hardware: {exc}") + raise + assert jnp.allclose(out_no_pad, out_pad, atol=5e-2, rtol=5e-2), ( "align_size > 0 must not change pure_jax forward output; max diff" f" {jnp.max(jnp.abs(out_no_pad - out_pad))}" diff --git a/transformer_engine/common/util/multi_stream.cpp b/transformer_engine/common/util/multi_stream.cpp index 6b19f36741..ec341abc68 100644 --- a/transformer_engine/common/util/multi_stream.cpp +++ b/transformer_engine/common/util/multi_stream.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include "cuda_runtime.h" @@ -19,18 +20,54 @@ namespace transformer_engine::detail { +namespace { + +// CUDA streams and events are device-bound: a stream / event created +// on device A cannot be recorded into / waited on from device B +// (CUDA returns ``cudaErrorInvalidResourceHandle``). The previous +// implementation used ``std::call_once`` to lazily create one +// process-global vector of streams + one of events, which works for +// the single-device case (PyTorch eager / single-host single-device +// JAX) but breaks for single-process *multi*-device JAX: the first +// worker thread to win the ``call_once`` would create streams / +// events on its own device, and subsequent calls from other devices +// would receive those same handles and fail at ``cudaEventRecord``. +// +// We now key the cache on the active CUDA device. Each device gets +// its own ``num_compute_streams`` streams and events, created lazily +// the first time a thread on that device asks for one. +template +auto& per_device_pool(CreateFn&& create) { + static std::mutex mu; + using PoolT = decltype(std::vector{create()}); + static std::unordered_map pools; + int device; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + std::lock_guard lock(mu); + auto it = pools.find(device); + if (it == pools.end()) { + const size_t num_streams = nvte_get_num_compute_streams(); + PoolT v; + v.reserve(num_streams); + for (size_t i = 0; i < num_streams; i++) { + v.push_back(create()); + } + it = pools.emplace(device, std::move(v)).first; + } + return it->second; +} + +} // namespace + cudaStream_t get_compute_stream(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); - static std::vector streams(num_streams); - static std::once_flag stream_init_flag; - auto init = [&]() { - for (size_t i = 0; i < num_streams; i++) { - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1)); - } - }; - std::call_once(stream_init_flag, init); + auto& streams = per_device_pool([] { + cudaStream_t s; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&s, cudaStreamNonBlocking, -1)); + return s; + }); return streams[idx]; } @@ -38,14 +75,11 @@ cudaEvent_t get_compute_stream_event(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); - static std::vector events(num_streams); - static std::once_flag event_init_flag; - auto init = [&]() { - for (size_t i = 0; i < num_streams; i++) { - NVTE_CHECK_CUDA(cudaEventCreate(&events[i])); - } - }; - std::call_once(event_init_flag, init); + auto& events = per_device_pool([] { + cudaEvent_t e; + NVTE_CHECK_CUDA(cudaEventCreate(&e)); + return e; + }); return events[idx]; } diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ff6d07986..94b2de9573 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2024,9 +2024,14 @@ def grouped_gemm_copy_group_sizes( return out -@cache def _should_enforce_v2_grouped_gemm() -> bool: - """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" + """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM. + + Not cached so tests can flip the env var with ``monkeypatch.setenv`` + and have it picked up on the next call. This is called only on + grouped-GEMM dispatch (not in any tight loop), so the per-call + ``getenv`` cost is negligible. + """ val = os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") try: return bool(int(val)) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6ca907032c..8a807cbdcc 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -1157,12 +1157,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type cudaStreamSynchronize(stream); } size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + // Allow callers to pass an LHS/RHS that is at least as large as the active + // ragged region (sum_group_sizes). This supports ragged-all-to-all flows + // where the recv buffer is over-allocated to a worst-case size and only + // the first sum_group_sizes rows along the ragged dim are populated; the + // trailing slack rows are not consumed by the per-group GEMMs (which key + // off group_sizes). if (!is_rhs_ragged) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); + NVTE_CHECK(sum_group_sizes <= m, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes, + " must be <= M = ", m); } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); + NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes, + " must be <= K = ", k); } } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 650139a61c..871abb5634 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -383,9 +383,17 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty cudaStreamSynchronize(stream); size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, - "Unexpected group_sizes! Got %zu (M=%zu, input_dims[0] = %zu)", sum_group_sizes, m, - input_dims[0]); + // Allow callers to pass an input that is at least as large as the active + // ragged region (sum_group_sizes). This supports ragged-all-to-all flows + // where the recv buffer is over-allocated to a worst-case size and only the + // first sum_group_sizes rows are populated; the trailing slack rows are + // simply not quantized (and not consumed by the downstream grouped GEMM + // which is also keyed on group_sizes). + // For flatten_axis==1, m == input_dims[0]; for flatten_axis>1, the per-group + // tile is dim_list_host[i] * non_group_m, so the binding dim is input_dims[0]. + NVTE_CHECK(sum_group_sizes <= input_dims[0], + "Unexpected group_sizes! sum(group_sizes)=%zu must be <= input_dims[0]=%zu (M=%zu)", + sum_group_sizes, input_dims[0], m); if (is_delayed_scaling) { NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 6f3986e9b3..a882ddfce6 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -6,7 +6,7 @@ This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer that wires together TE's fused router, a selectable token-dispatch backend -(pure-JAX ``unfused_*`` or fused Triton), TE's ``grouped_dense``, and an +(``pure_jax`` or ``triton``), TE's ``grouped_dense``, and an optional ragged-all-to-all (A2A / A2Av) expert-parallelism strategy. Architecture @@ -56,11 +56,12 @@ ``align_size > 0`` tests stay xfail. """ +from functools import partial from typing import Any, Callable, NewType, Optional, Tuple, Union import jax import jax.numpy as jnp -from flax import linen as nn +from flax import linen as nn, struct as flax_struct from jax.sharding import PartitionSpec as P from ..dense import grouped_dense @@ -70,10 +71,11 @@ compute_reverse_ragged_all_to_all_params, local_permute_after_a2a, local_unpermute_before_a2a, + PureJaxPermState, + pure_jax_token_combine, + pure_jax_token_dispatch, token_combine, token_dispatch, - unfused_token_combine, - unfused_token_dispatch, ) from ..quantize import noop_quantizer_set from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function @@ -87,7 +89,36 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] -__all__ = ["MoEBlock"] +__all__ = ["GlobalPermuteResult", "MoEBlock"] + + +# ============================================================================= +# GlobalPermuteResult +# ============================================================================= +# +# Output of :meth:`MoEBlock._global_permute`. Carried as a pytree (so it +# crosses ``jax.shard_map`` / ``jax.value_and_grad`` boundaries +# transparently) and consumed by :meth:`MoEBlock._global_combine`. The +# fields populated depend on the permutation backend; the unused fields +# stay ``None``. +# +# Per-backend payloads (anything else is ``None``): +# pure_jax: ``perm_state``, ``routing_weights`` +# triton: ``row_id_map``, ``pad_offsets``, ``merging_probs`` + + +@flax_struct.dataclass +class GlobalPermuteResult: + """Result of :meth:`MoEBlock._global_permute`.""" + + sorted_inputs: jnp.ndarray + group_sizes: jnp.ndarray + perm_state: Optional[PureJaxPermState] = None + routing_weights: Optional[jnp.ndarray] = None + row_id_map: Optional[jnp.ndarray] = None + pad_offsets: Optional[jnp.ndarray] = None + merging_probs: Optional[jnp.ndarray] = None + backend: str = flax_struct.field(pytree_node=False, default="pure_jax") # ============================================================================= @@ -106,8 +137,8 @@ class MoEBlock(TransformerEngineBase): Two permutation backends are pluggable via ``permutation_backend``: * ``"pure_jax"`` (default) -- argsort-based - :func:`~transformer_engine.jax.permutation.unfused_token_dispatch` / - :func:`~transformer_engine.jax.permutation.unfused_token_combine`. + :func:`~transformer_engine.jax.permutation.pure_jax_token_dispatch` / + :func:`~transformer_engine.jax.permutation.pure_jax_token_combine`. Faster than Triton in profiling for DeepSeek-style configs. * ``"triton"`` -- TE's fused :func:`~transformer_engine.jax.permutation.token_dispatch` / @@ -273,17 +304,46 @@ def __post_init__(self): super().__post_init__() # ------------------------------------------------------------------ - # Parameter registration + # Entry point # ------------------------------------------------------------------ - def _make_params(self, hidden_size: int) -> dict: - """Register module parameters and return them as a dict.""" - gate_kernel = self.param( - "gate_kernel", - nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), - (hidden_size, self.num_experts), - self.dtype, - ) + @nn.compact + def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: + """Run the MoE forward pass. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[batch, sequence, hidden]``. + + Returns + ------- + output : jnp.ndarray + Output tensor of shape ``[batch, sequence, hidden]``. + aux_loss : Optional[jnp.ndarray] + Scalar auxiliary load-balancing loss when + ``aux_loss_coeff > 0``, else ``None``. + """ + assert ( + inputs.ndim == 3 + ), f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" + inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) + + _, _, hidden_size = inputs.shape + + # Param registrations are inlined here (not in a helper) so each + # ``self.param`` lives close to the rest of the entry point. + # Note: under EP the FFN weights and ``expert_bias`` are + # consumed *inside* a ``shard_map`` body. Flax's ``self.param`` + # must run OUTSIDE any JAX transform that would alter the + # variable scope (``shard_map`` does), so the registrations stay + # here in ``__call__`` and the values are passed down explicitly + # via ``in_specs``. ``_gate`` is called outside ``shard_map`` in + # both paths, so its kernel is registered inline inside + # ``_gate`` itself rather than here. + + gate_logits = self._gate(inputs) + wi_0 = self.param( "wi_0", nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), @@ -302,78 +362,59 @@ def _make_params(self, hidden_size: int) -> dict: (self.num_experts, self.intermediate_size, hidden_size), self.dtype, ) - params: dict = { - "gate_kernel": gate_kernel, - "wi_0": wi_0, - "wi_1": wi_1, - "wo": wo, - } + wi_0_bias = wi_1_bias = wo_bias = None if self.use_bias: - params["wi_0_bias"] = self.param( + wi_0_bias = self.param( "wi_0_bias", nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), (self.num_experts, self.intermediate_size), self.dtype, ) - params["wi_1_bias"] = self.param( + wi_1_bias = self.param( "wi_1_bias", nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), (self.num_experts, self.intermediate_size), self.dtype, ) - params["wo_bias"] = self.param( + wo_bias = self.param( "wo_bias", nn.with_logical_partitioning(self.bias_init, ("exp", "embed")), (self.num_experts, hidden_size), self.dtype, ) + expert_bias = None if self.use_expert_bias: - params["expert_bias"] = self.param( + expert_bias = self.param( "expert_bias", nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), (self.num_experts,), self.dtype, ) - return params - - # ------------------------------------------------------------------ - # Entry point - # ------------------------------------------------------------------ - - @nn.compact - def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: - """Run the MoE forward pass. - - Parameters - ---------- - inputs : jnp.ndarray - Input tensor of shape ``[batch, sequence, hidden]``. - - Returns - ------- - output : jnp.ndarray - Output tensor of shape ``[batch, sequence, hidden]``. - aux_loss : Optional[jnp.ndarray] - Scalar auxiliary load-balancing loss when - ``aux_loss_coeff > 0``, else ``None``. - """ - assert ( - inputs.ndim == 3 - ), f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" - inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) - - _, _, hidden_size = inputs.shape - params = self._make_params(hidden_size) - - # The gate runs OUTSIDE any EP shard_map: under EP each shard - # projects only its local slice of tokens, producing local gate - # logits with the same per-shard layout as ``inputs``. - gate_logits = self._gate(inputs, params["gate_kernel"]) if self.expert_parallelism_axis is None: - output, aux_loss = self._forward_no_ep(inputs, gate_logits, params) + output, aux_loss = self._forward_no_ep( + inputs, + gate_logits, + wi_0=wi_0, + wi_1=wi_1, + wo=wo, + wi_0_bias=wi_0_bias, + wi_1_bias=wi_1_bias, + wo_bias=wo_bias, + expert_bias=expert_bias, + ) else: - output, aux_loss = self._forward_a2a_ep(inputs, gate_logits, params) + output, aux_loss = self._forward_a2a_ep( + inputs, + gate_logits, + wi_0=wi_0, + wi_1=wi_1, + wo=wo, + wi_0_bias=wi_0_bias, + wi_1_bias=wi_1_bias, + wo_bias=wo_bias, + expert_bias=expert_bias, + ) if self.aux_loss_coeff <= 0.0: aux_loss = None @@ -383,14 +424,34 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: # Gate # ------------------------------------------------------------------ - def _gate(self, inputs: jnp.ndarray, gate_kernel: jnp.ndarray) -> jnp.ndarray: + def _gate(self, inputs: jnp.ndarray) -> jnp.ndarray: """Linear gate projection ``inputs @ gate_kernel``. Kept as a plain ``einsum`` (not ``DenseGeneral``) so it composes cleanly with the EP shard_map: the gate runs in the outer (pre-shard_map) scope and its output passes through the - ``shard_map`` boundary unchanged. + ``shard_map`` boundary unchanged. Because the gate runs outside + any ``shard_map`` body in both EP and no-EP forwards, the + ``gate_kernel`` parameter is registered inline here. + + The gating GEMM is intentionally kept in ``self.dtype`` (typically + ``bfloat16``) and is **not** autocast to FP8 even when the caller + wraps the block in :func:`transformer_engine.jax.autocast`. Two + reasons: (1) the GEMM is tiny (``H * E`` with ``E`` small) and + contributes well under 1% of the block's compute, so quantization + savings are marginal; (2) the resulting logits feed a top-k + + softmax (or sigmoid) routing decision that is sensitive to + quantization noise -- routing flips at low-confidence tokens + could materially hurt model quality. To override, wrap the call + site in your own ``autocast`` and manually replace this method. """ + hidden_size = inputs.shape[-1] + gate_kernel = self.param( + "gate_kernel", + nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), + (hidden_size, self.num_experts), + self.dtype, + ) kernel = gate_kernel.astype(inputs.dtype) return jnp.einsum("bsh,he->bse", inputs, kernel) @@ -427,31 +488,48 @@ def _route_topk( def _compute_aux_loss( self, logits_2d: jnp.ndarray, + tokens_per_expert: jnp.ndarray, ) -> Optional[jnp.ndarray]: """Compute the MoE auxiliary load-balancing loss. - The score-for-aux kernel has no data dependency on the main - routing kernel, so XLA can overlap them on the GPU. + The score-for-aux kernel reads only ``logits_2d`` and the final + reduction reads only the (already-computed) ``tokens_per_expert``, + so the aux scores can run concurrently with the main routing + path on the GPU. ``logits_2d`` should be the *full* logits tensor over the global token batch -- under EP the caller is responsible for :func:`jax.lax.all_gather` ing the logits before calling this so the aux_loss formula ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])`` - sees the global ``T`` and the global ``tokens_per_expert``. + sees the global ``T``. + + ``tokens_per_expert`` must be the per-expert token-assignment + count from the *actual* routing decision -- i.e. derived from + ``_route_topk``'s ``routing_map``, not recomputed from a clean + top-k. This matters under DeepSeek-style routing + (``num_groups > 0`` / ``group_topk > 0``) where the + post-grouping routing differs from a plain top-k. Under EP the + caller is responsible for summing over all (ep + dp) shards + first so the count is global. """ if self.aux_loss_coeff <= 0.0: return None - aux_scores, aux_routing_map = fused_topk_with_score_function( + # The "compute_aux_scores=True" kernel intentionally ignores + # num_groups/group_topk/expert_bias and returns the dense + # post-score-function scores over all experts. Those scores are + # what the aux-loss formula expects (raw scoring, no grouping + # bias); the routing decisions used for ``tokens_per_expert`` + # come from the caller-supplied real ``routing_map``. + aux_scores, _ = fused_topk_with_score_function( logits_2d.astype(jnp.float32), topk=self.num_experts_per_tok, score_function=self.score_function, compute_aux_scores=True, ) - aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0) return fused_moe_aux_loss( aux_scores.astype(jnp.float32), - aux_tokens_per_expert, + tokens_per_expert.astype(jnp.int32), topk=self.num_experts_per_tok, coeff=self.aux_loss_coeff, ) @@ -465,28 +543,15 @@ def _global_permute( inputs_2d: jnp.ndarray, sparse_probs: jnp.ndarray, routing_map: jnp.ndarray, - ) -> dict: + ) -> GlobalPermuteResult: """Dispatch tokens to the global expert axis. - Returns a permutation-result dict suitable both for the no-EP - forward (where the same buffer feeds ``_expert_ffn`` directly) and - for the A2A-EP path (where the buffer is sliced + sent over the EP - axis before the FFN). The dict carries the per-backend opaque - state needed to invert the dispatch in :meth:`_global_combine`. - - The output dict layout is:: - - { - "backend": "pure_jax" | "triton", - "sorted_inputs": [buffer_size, hidden], - "group_sizes": [num_experts], # per-expert, - # length == E always. - "perm_state": UnfusedPermState | None, # pure_jax - "row_id_map": jnp.ndarray | None, # triton - "pad_offsets": jnp.ndarray | None, # triton - "routing_weights": jnp.ndarray | None, # pure_jax - "merging_probs": jnp.ndarray | None, # triton - } + Returns a :class:`GlobalPermuteResult` suitable both for the + no-EP forward (where the same buffer feeds ``_expert_ffn`` + directly) and for the A2A-EP path (where the buffer is sliced + + sent over the EP axis before the FFN). The result carries the + per-backend opaque state needed to invert the dispatch in + :meth:`_global_combine`. """ num_tokens = inputs_2d.shape[0] topk = self.num_experts_per_tok @@ -495,20 +560,20 @@ def _global_permute( selected_experts, routing_weights = routing_map_to_selected_experts( sparse_probs, routing_map, topk ) - sorted_inputs, perm_state, group_sizes = unfused_token_dispatch( + sorted_inputs, perm_state, group_sizes = pure_jax_token_dispatch( inputs_2d, selected_experts, num_experts=self.num_experts, num_experts_per_tok=topk, align_size=self.align_size, ) - return { - "backend": "pure_jax", - "sorted_inputs": sorted_inputs, - "group_sizes": group_sizes, - "perm_state": perm_state, - "routing_weights": routing_weights, - } + return GlobalPermuteResult( + backend="pure_jax", + sorted_inputs=sorted_inputs, + group_sizes=group_sizes, + perm_state=perm_state, + routing_weights=routing_weights, + ) # triton num_out_tokens = num_tokens * topk @@ -526,14 +591,14 @@ def _global_permute( probs=sparse_probs, align_size=align_size_arg, ) - return { - "backend": "triton", - "sorted_inputs": sorted_inputs, - "group_sizes": group_sizes, - "row_id_map": row_id_map, - "pad_offsets": pad_offsets, - "merging_probs": sparse_probs, - } + return GlobalPermuteResult( + backend="triton", + sorted_inputs=sorted_inputs, + group_sizes=group_sizes, + row_id_map=row_id_map, + pad_offsets=pad_offsets, + merging_probs=sparse_probs, + ) # ------------------------------------------------------------------ # Expert FFN (three grouped_dense calls + activation) @@ -543,11 +608,21 @@ def _expert_ffn( self, sorted_inputs: jnp.ndarray, group_sizes: jnp.ndarray, - params: dict, n_groups: int, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray] = None, + wi_1_bias: Optional[jnp.ndarray] = None, + wo_bias: Optional[jnp.ndarray] = None, ) -> jnp.ndarray: """Run the per-expert SwiGLU-style FFN over a permuted buffer. + All ``wi_*`` / ``wo`` weights and the optional biases are passed + in as explicit args (rather than registered inline here) because + in the EP path this method runs *inside* a ``shard_map`` body + and Flax param registration must happen outside that scope. + Parameters ---------- sorted_inputs : jnp.ndarray @@ -558,24 +633,26 @@ def _expert_ffn( ``sum(group_sizes)`` must equal ``buffer_size`` (TE ``grouped_dense`` FFI assertion at ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``). - params : dict - Block parameters from :meth:`_make_params`. Reads ``wi_0``, - ``wi_1``, ``wo``, and the optional bias entries. n_groups : int Number of expert groups. Equals ``self.num_experts`` for the no-EP path and ``num_experts // num_ep`` for the A2A-EP path. Used to size the per-call quantizer set so the FP8 metadata tensors match ``group_sizes``. + wi_0, wi_1, wo : jnp.ndarray + Expert weight tensors. Shapes (no-EP): + ``(num_experts, hidden, intermediate)`` for wi_*, + ``(num_experts, intermediate, hidden)`` for wo. Under EP + the leading expert dim is sliced to ``num_experts // num_ep``. + wi_0_bias, wi_1_bias, wo_bias : Optional[jnp.ndarray] + Optional per-expert biases (shape ``(num_experts, N)``); + ``grouped_dense`` adds ``bias[i]`` to the rows belonging to + expert ``i`` in the permuted layout. Returns ------- expert_outputs : jnp.ndarray ``[buffer_size, hidden]``. """ - wi_0 = params["wi_0"] - wi_1 = params["wi_1"] - wo = params["wo"] - # Each grouped_dense call gets its own quantizer_set with # n_groups matching ``group_sizes``; this keeps the FP8 meta # tensors correctly sized in both no-EP and A2A-EP cases. @@ -592,13 +669,6 @@ def _expert_ffn( if q_set_wo == noop_quantizer_set: wo = wo.astype(sorted_inputs.dtype) - # ``grouped_dense`` accepts per-expert bias of shape (G, N); it - # adds ``bias[i]`` to the ``group_sizes[i]`` rows belonging to - # expert ``i`` in the permuted layout. - wi_0_bias = params.get("wi_0_bias") if self.use_bias else None - wi_1_bias = params.get("wi_1_bias") if self.use_bias else None - wo_bias = params.get("wo_bias") if self.use_bias else None - layer_w0 = grouped_dense( sorted_inputs, wi_0, @@ -636,7 +706,7 @@ def _expert_ffn( def _global_combine( self, expert_outputs: jnp.ndarray, - perm_result: dict, + perm_result: GlobalPermuteResult, batch_size: int, sequence_length: int, ) -> jnp.ndarray: @@ -645,12 +715,11 @@ def _global_combine( Gathers per-expert outputs back into ``[batch, sequence, hidden]`` and applies the per-token weighted sum across the top-k experts. """ - backend = perm_result["backend"] - if backend == "pure_jax": - return unfused_token_combine( + if perm_result.backend == "pure_jax": + return pure_jax_token_combine( expert_outputs, - perm_result["perm_state"], - perm_result["routing_weights"], + perm_result.perm_state, + perm_result.routing_weights, num_experts_per_tok=self.num_experts_per_tok, batch_size=batch_size, sequence_length=sequence_length, @@ -658,9 +727,9 @@ def _global_combine( # triton out_2d = token_combine( expert_outputs, - perm_result["row_id_map"], - merging_probs=perm_result["merging_probs"], - pad_offsets=perm_result["pad_offsets"], + perm_result.row_id_map, + merging_probs=perm_result.merging_probs, + pad_offsets=perm_result.pad_offsets, ) hidden_size = out_2d.shape[-1] return out_2d.reshape(batch_size, sequence_length, hidden_size).astype(self.dtype) @@ -673,7 +742,14 @@ def _forward_no_ep( self, inputs: jnp.ndarray, gate_logits: jnp.ndarray, - params: dict, + *, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray] = None, + wi_1_bias: Optional[jnp.ndarray] = None, + wo_bias: Optional[jnp.ndarray] = None, + expert_bias: Optional[jnp.ndarray] = None, ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """Single-shard or DP/FSDP/TP forward (no shard_map wrapper). @@ -681,19 +757,55 @@ def _forward_no_ep( ``custom_partitioning`` rule -- there is no cross-primitive collective that the rules cannot express on their own, so a ``shard_map`` is unnecessary here. + + Sharding contract for callers + ----------------------------- + + On this no-EP path the grouped quantize and grouped GEMMs run + in the caller's outer SPMD context (no ``shard_map`` boundary). + Their custom_partitioning rules read sharding from each input's + ``NamedSharding`` and propagate consistent shardings on outputs. + Concretely: + + * ``inputs`` should be FSDP/DP-sharded on the batch dim + (``input_axes`` in :class:`MoEBlock` enforces this via a + logical ``with_sharding_constraint``). + * ``wi_*`` / ``wo`` weights should carry the logical axes + ``wi_kernel_axes`` / ``wo_kernel_axes`` so FSDP shards a + weight non-contracting dim, gathered inside ``grouped_dense`` + before the GEMM. + * The wgrad reduce-scatter (when FSDP is active) is emitted by + ``grouped_dense_bwd``'s partitioning rule; no explicit + collective is needed here. + + Without those shardings the grouped GEMM falls back to + replicated-everywhere semantics (legal but defeats FSDP/DP). + Tested in ``tests/jax/test_distributed_moe_block.py`` for the + EP=2 + FSDP=2 case; the no-EP + FSDP-only case shares the same + infra and is covered when ``expert_parallelism_axis`` is left + ``None`` in that test. """ batch_size, sequence_length, hidden_size = inputs.shape inputs_2d = inputs.reshape(-1, hidden_size) logits_2d = gate_logits.reshape(-1, self.num_experts) - sparse_probs, routing_map = self._route_topk(logits_2d, params.get("expert_bias")) - aux_loss = self._compute_aux_loss(logits_2d) + sparse_probs, routing_map = self._route_topk(logits_2d, expert_bias) + # ``tokens_per_expert`` MUST come from the real routing_map so the + # aux-loss objective matches actual routing decisions under + # DeepSeek-style num_groups/group_topk routing. + tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0) + aux_loss = self._compute_aux_loss(logits_2d, tokens_per_expert) perm = self._global_permute(inputs_2d, sparse_probs, routing_map) expert_outputs = self._expert_ffn( - perm["sorted_inputs"], - perm["group_sizes"], - params, + perm.sorted_inputs, + perm.group_sizes, n_groups=self.num_experts, + wi_0=wi_0, + wi_1=wi_1, + wo=wo, + wi_0_bias=wi_0_bias, + wi_1_bias=wi_1_bias, + wo_bias=wo_bias, ) output = self._global_combine(expert_outputs, perm, batch_size, sequence_length) @@ -714,7 +826,14 @@ def _forward_a2a_ep( self, inputs: jnp.ndarray, gate_logits: jnp.ndarray, - params: dict, + *, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray] = None, + wi_1_bias: Optional[jnp.ndarray] = None, + wo_bias: Optional[jnp.ndarray] = None, + expert_bias: Optional[jnp.ndarray] = None, ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """Wrap the body in a ``shard_map`` that runs a forward ``ragged_all_to_all`` (A2A / A2Av) around the FFN. @@ -800,12 +919,15 @@ def _forward_a2a_ep( # pytree. shard_map fully supports pytrees: ``in_specs`` must # structurally match ``captured`` and we build them in lockstep # so adding/removing an optional bias is one ``dict[name] = ...``. + # Params must be packed here (rather than passed inline by + # ``self.param`` inside the body) because Flax variable scopes + # must not be entered from inside a JAX transform's body. captured: dict = { "inputs": inputs, "gate_logits": gate_logits, - "wi_0": params["wi_0"], - "wi_1": params["wi_1"], - "wo": params["wo"], + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, } in_specs: dict = { "inputs": P(batch_pspec_axis, None, None), @@ -814,161 +936,208 @@ def _forward_a2a_ep( "wi_1": P(ep_axis, None, None), "wo": P(ep_axis, None, None), } - if "expert_bias" in params: - captured["expert_bias"] = params["expert_bias"] + if expert_bias is not None: + captured["expert_bias"] = expert_bias in_specs["expert_bias"] = P(ep_axis) - if "wi_0_bias" in params: + if wi_0_bias is not None: + captured["wi_0_bias"] = wi_0_bias + captured["wi_1_bias"] = wi_1_bias + captured["wo_bias"] = wo_bias for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): - captured[name] = params[name] in_specs[name] = P(ep_axis, None) - def _a2a_fn(local: dict) -> Tuple[jnp.ndarray, jnp.ndarray]: - shard_id = jax.lax.axis_index(ep_axis) - - # -- Stage 1: per-shard route + global permute over all E -- - # Inside the shard_map body each input has its EP axis already - # consumed, so ``local_inputs.shape == [B/num_ep, S, H]``. - local_inputs = local["inputs"] - local_logits = local["gate_logits"] - local_b, local_s, local_h = local_inputs.shape - inputs_2d = local_inputs.reshape(-1, local_h) - logits_2d = local_logits.reshape(-1, self.num_experts) - - # The router operates over the full expert axis, so the - # EP-sharded ``expert_bias`` (in_spec ``P(ep_axis)``) must be - # all-gathered before being passed in. - if "expert_bias" in local: - full_expert_bias = jax.lax.all_gather( - local["expert_bias"], axis_name=ep_axis, tiled=True - ) - else: - full_expert_bias = None - sparse_probs, routing_map = self._route_topk(logits_2d, full_expert_bias) - - # aux_loss must see the global token batch and the global - # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( - # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable - # (the sum_t * tokens product is data-dependent across - # shards). Cheapest fix: gather logits along ALL batch - # axes (ep + any DP axes) so the kernel sees the full - # token set. The aux branch has no data dependency on the - # main routing path so XLA can overlap the two on the GPU. - if self.aux_loss_coeff > 0.0: - # ``axis_name`` accepts a tuple ⇒ a single all_gather - # over the cartesian product of axes; XLA may lower - # this to one multi-axis collective or split it. - if len(self.data_parallelism_axes) == 0: - aux_gather_axes: Any = ep_axis - else: - aux_gather_axes = (ep_axis, *self.data_parallelism_axes) - global_logits_2d = jax.lax.all_gather( - logits_2d, axis_name=aux_gather_axes, axis=0, tiled=True - ) - aux_loss = self._compute_aux_loss(global_logits_2d) - else: - aux_loss = None + a2a_body = partial( + self._a2a_body, + ep_axis=ep_axis, + num_ep=num_ep, + num_experts_local=num_experts_local, + recv_buffer_rows=recv_buffer_rows, + ) - perm = self._global_permute(inputs_2d, sparse_probs, routing_map) - global_group_sizes = perm["group_sizes"] # [E] + # ``check_rep=False`` disables shard_map's invariant that any + # output declared as ``P()`` is replicated across ``ep_axis``. + # We use ``axis_index(ep_axis)`` inside ``_a2a_body`` so the + # body is genuinely non-replicated, which would otherwise + # (correctly) fail the check. ``ragged_all_to_all`` already + # produces the right cross-shard semantics; this is the standard + # JAX escape hatch when collectives + per-shard logic coexist. + return shard_map( + a2a_body, + mesh=mesh, + in_specs=(in_specs,), + out_specs=(P(batch_pspec_axis, None, None), P()), + check_rep=False, + )(captured) - # -- Stage 2: gather per-expert counts across the EP axis -- - all_shards_tokens_per_expert = jax.lax.all_gather( - global_group_sizes[None, :], - axis_name=ep_axis, - axis=0, - tiled=True, - ) # [num_ep, num_experts] + # ------------------------------------------------------------------ + # Body of the per-shard A2A-EP forward (extracted from + # :meth:`_forward_a2a_ep` for readability). Runs *inside* the + # ``shard_map`` and is therefore in EP-manual mode: collectives over + # ``ep_axis`` are explicit, the rest of the mesh stays in auto mode. + # ------------------------------------------------------------------ - # -- Stage 3: forward ragged_all_to_all over EP -- - in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep + def _a2a_body( + self, + local: dict, + *, + ep_axis: str, + num_ep: int, + num_experts_local: int, + recv_buffer_rows: int, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + shard_id = jax.lax.axis_index(ep_axis) + + # -- Stage 1: per-shard route + global permute over all E -- + # Inside the shard_map body each input has its EP axis already + # consumed, so ``local_inputs.shape == [B/num_ep, S, H]``. + local_inputs = local["inputs"] + local_logits = local["gate_logits"] + local_b, local_s, local_h = local_inputs.shape + inputs_2d = local_inputs.reshape(-1, local_h) + logits_2d = local_logits.reshape(-1, self.num_experts) + + # The router operates over the full expert axis, so the + # EP-sharded ``expert_bias`` (in_spec ``P(ep_axis)``) must be + # all-gathered before being passed in. + if "expert_bias" in local: + full_expert_bias = jax.lax.all_gather( + local["expert_bias"], axis_name=ep_axis, tiled=True ) - recv_buf = jnp.zeros( - (recv_buffer_rows, local_h), - dtype=perm["sorted_inputs"].dtype, + else: + full_expert_bias = None + sparse_probs, routing_map = self._route_topk(logits_2d, full_expert_bias) + + # aux_loss must see the global token batch and the global + # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( + # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable + # (the sum_t * tokens product is data-dependent across + # shards). We need a *single* collective: + # * ``all_gather`` logits over (ep + any DP axes) so both + # (a) the score-for-aux kernel and (b) a re-run of + # ``_route_topk`` see the full token batch. The re-run + # gives us the global per-expert token count directly, + # avoiding a separate ``psum``. Two consecutive global + # collectives over the same replica group at the very + # start of the program have been observed to deadlock + # under FP8 autocast on some XLA + NCCL combinations, + # so we keep this branch to one collective. + # The aux branch has no data dependency on the main routing + # path beyond what is already gathered, so XLA can overlap + # the two routings on the GPU. + if self.aux_loss_coeff > 0.0: + # ``axis_name`` accepts a tuple ⇒ a single collective + # over the cartesian product of axes; XLA may lower + # this to one multi-axis op or split it. + if len(self.data_parallelism_axes) == 0: + aux_collective_axes: Any = ep_axis + else: + aux_collective_axes = (ep_axis, *self.data_parallelism_axes) + global_logits_2d = jax.lax.all_gather( + logits_2d, axis_name=aux_collective_axes, axis=0, tiled=True ) - x_recv = jax.lax.ragged_all_to_all( - perm["sorted_inputs"], - recv_buf, - in_off, - send_sz, - out_off, - recv_sz, - axis_name=ep_axis, + # Re-run topk on the gathered logits to obtain the + # *global* routing_map post-grouping (respects + # num_groups/group_topk/expert_bias just like the local + # routing). Summing over the global token dim gives the + # exact same counts as ``psum(local_tokens_per_expert)`` + # without an extra collective. The duplicate topk + # compute is small relative to the FFNs. + _, global_routing_map = self._route_topk( + global_logits_2d, full_expert_bias ) - - # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) - sorted_x, local_group_sizes, local_perm_state = local_permute_after_a2a( - x_recv, - all_shards_tokens_per_expert, - shard_id, - num_ep, + global_tokens_per_expert = jnp.sum( + global_routing_map.astype(jnp.int32), axis=0 ) - - # -- Stage 5: per-expert FFN (E_local groups) -- - local_params: dict = { - "wi_0": local["wi_0"], - "wi_1": local["wi_1"], - "wo": local["wo"], - } - if "wi_0_bias" in local: - local_params["wi_0_bias"] = local["wi_0_bias"] - local_params["wi_1_bias"] = local["wi_1_bias"] - local_params["wo_bias"] = local["wo_bias"] - expert_outputs = self._expert_ffn( - sorted_x, - local_group_sizes, - local_params, - n_groups=num_experts_local, + aux_loss = self._compute_aux_loss( + global_logits_2d, global_tokens_per_expert ) + else: + aux_loss = None - # -- Stage 6: invert local permute -- - x_send_back = local_unpermute_before_a2a(expert_outputs, local_perm_state) + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + global_group_sizes = perm.group_sizes # [E] + + # -- Stage 2: gather per-expert counts across the EP axis -- + all_shards_tokens_per_expert = jax.lax.all_gather( + global_group_sizes[None, :], + axis_name=ep_axis, + axis=0, + tiled=True, + ) # [num_ep, num_experts] + + # -- Stage 3: forward ragged_all_to_all over EP -- + in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + recv_buf = jnp.zeros( + (recv_buffer_rows, local_h), + dtype=perm.sorted_inputs.dtype, + ) + x_recv = jax.lax.ragged_all_to_all( + perm.sorted_inputs, + recv_buf, + in_off, + send_sz, + out_off, + recv_sz, + axis_name=ep_axis, + ) - # -- Stage 7: reverse ragged_all_to_all over EP -- - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep - ) - send_back_buf = jnp.zeros_like(perm["sorted_inputs"]) - y_back = jax.lax.ragged_all_to_all( - x_send_back, - send_back_buf, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) + # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) + sorted_x, local_group_sizes, local_perm_state = local_permute_after_a2a( + x_recv, + all_shards_tokens_per_expert, + shard_id, + num_ep, + ) - # -- Stage 8: invert global permute, weighted sum over top-k -- - output = self._global_combine(y_back, perm, batch_size=local_b, sequence_length=local_s) + # -- Stage 5: per-expert FFN (E_local groups) -- + expert_outputs = self._expert_ffn( + sorted_x, + local_group_sizes, + n_groups=num_experts_local, + wi_0=local["wi_0"], + wi_1=local["wi_1"], + wo=local["wo"], + wi_0_bias=local.get("wi_0_bias"), + wi_1_bias=local.get("wi_1_bias"), + wo_bias=local.get("wo_bias"), + ) - if self.tensor_parallelism_axis is not None: - output = jax.lax.psum_scatter( - output, - self.tensor_parallelism_axis, - scatter_dimension=2, - tiled=True, - ) + # -- Stage 6: invert local permute -- + x_send_back = local_unpermute_before_a2a(expert_outputs, local_perm_state) - # ``out_specs`` must match the returned pytree structurally, - # so always emit a real scalar for aux_loss; the outer - # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. - if aux_loss is None: - aux_loss = jnp.zeros((), dtype=self.dtype) - return output, aux_loss + # -- Stage 7: reverse ragged_all_to_all over EP -- + in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + send_back_buf = jnp.zeros_like(perm.sorted_inputs) + y_back = jax.lax.ragged_all_to_all( + x_send_back, + send_back_buf, + in_off_r, + send_sz_r, + out_off_r, + recv_sz_r, + axis_name=ep_axis, + ) - # ``check_rep=False`` disables shard_map's invariant that any - # output declared as ``P()`` is replicated across ``ep_axis``. - # We use ``axis_index(ep_axis)`` inside ``_a2a_fn`` so the body - # is genuinely non-replicated, which would otherwise (correctly) - # fail the check. ``ragged_all_to_all`` already produces the - # right cross-shard semantics; this is the standard JAX escape - # hatch when collectives + per-shard logic coexist. - return shard_map( - _a2a_fn, - mesh=mesh, - in_specs=(in_specs,), - out_specs=(P(batch_pspec_axis, None, None), P()), - check_rep=False, - )(captured) + # -- Stage 8: invert global permute, weighted sum over top-k -- + output = self._global_combine( + y_back, perm, batch_size=local_b, sequence_length=local_s + ) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + # ``out_specs`` must match the returned pytree structurally, + # so always emit a real scalar for aux_loss; the outer + # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. + if aux_loss is None: + aux_loss = jnp.zeros((), dtype=self.dtype) + return output, aux_loss diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index ba271b6b86..9fbaf64736 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -9,10 +9,12 @@ Two backends are offered: -* Fused, Triton-backed ``token_dispatch`` / ``token_combine`` - uses the +* Triton-backed ``token_dispatch`` / ``token_combine`` - uses the Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``. -* Unfused, pure-JAX ``unfused_token_dispatch`` / ``unfused_token_combine`` - - uses only ``jnp.argsort`` + gather and is therefore compiled as plain XLA. +* Pure-JAX ``pure_jax_token_dispatch`` / ``pure_jax_token_combine`` - uses + only ``jnp.argsort`` + gather and is therefore compiled as plain XLA. + Despite the name, this path is often *faster* than the Triton kernels in + current testing because XLA can fuse the ops with surrounding work. Both backends support optional alignment padding (``align_size > 0``) so each expert's group size is a multiple of ``align_size``, which is required for @@ -49,9 +51,9 @@ "token_dispatch", "token_combine", "sort_chunks_by_index", - "unfused_token_dispatch", - "unfused_token_combine", - "UnfusedPermState", + "pure_jax_token_dispatch", + "pure_jax_token_combine", + "PureJaxPermState", # Ragged-all-to-all expert-parallelism helpers "compute_ragged_all_to_all_params", "compute_reverse_ragged_all_to_all_params", @@ -678,15 +680,19 @@ def _sort_chunks_by_index_bwd_rule( # ============================================================================= -# Unfused (pure-JAX) token dispatch / combine +# Pure-JAX token dispatch / combine # ============================================================================= # # The following implementations use only ``jnp.argsort`` + gather and compile # to plain XLA. They are a drop-in alternative to ``token_dispatch`` / # ``token_combine`` above, differing only in input/output conventions (the -# fused path takes ``routing_map`` and ``sparse_probs`` over all experts; the -# unfused path takes dense ``selected_experts`` and per-token ``weights`` of +# Triton path takes ``routing_map`` and ``sparse_probs`` over all experts; the +# pure-JAX path takes dense ``selected_experts`` and per-token ``weights`` of # shape ``[..., topk]``). +# +# Note: despite Triton being fused and pure-JAX being a sequence of XLA ops, +# the pure-JAX backend is often *faster* in current testing because XLA can +# fuse these ops into the surrounding work. # ----------------------------------------------------------------------------- @@ -704,7 +710,7 @@ def _sort_activations(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: assert ( inputs.shape[0] == sort_indices.shape[0] ), f"inputs.shape[0]={inputs.shape[0]} must match sort_indices.shape[0]={sort_indices.shape[0]}" - with jax.named_scope("unfused_sort_activations"): + with jax.named_scope("pure_jax_sort_activations"): return inputs[sort_indices, ...] @@ -730,7 +736,7 @@ def routing_map_to_selected_experts( ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Convert ``(sparse_probs, routing_map)`` from TE's fused router to the ``(selected_experts, weights)`` format consumed by - :func:`unfused_token_dispatch`. + :func:`pure_jax_token_dispatch`. ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` with exactly ``topk`` ``True`` positions per row. @@ -746,14 +752,14 @@ def routing_map_to_selected_experts( # Permutation state carried from dispatch to combine. -class UnfusedPermState(NamedTuple): - """Opaque state produced by :func:`unfused_token_dispatch`. +class PureJaxPermState(NamedTuple): + """Opaque state produced by :func:`pure_jax_token_dispatch`. Attributes ---------- sorted_indices : jnp.ndarray The argsort indices used in the forward sort. Needed to reverse the - permutation in :func:`unfused_token_combine`. Shape + permutation in :func:`pure_jax_token_combine`. Shape ``[num_real_tokens + padding_size]``. num_real_tokens : int Number of real (non-padding) permuted tokens, i.e. @@ -774,14 +780,14 @@ class UnfusedPermState(NamedTuple): # Dispatch (permute) -def unfused_token_dispatch( +def pure_jax_token_dispatch( inputs: jnp.ndarray, selected_experts: jnp.ndarray, num_experts: int, num_experts_per_tok: int, align_size: int = 0, roll_to_expert_id: Optional[int] = None, -) -> Tuple[jnp.ndarray, UnfusedPermState, jnp.ndarray]: +) -> Tuple[jnp.ndarray, PureJaxPermState, jnp.ndarray]: """Pure-JAX ``argsort``-based token dispatch. Parameters @@ -811,8 +817,8 @@ def unfused_token_dispatch( sorted_inputs : jnp.ndarray Permuted tokens grouped by expert, shape ``[num_real_tokens + padding_size, hidden_size]``. - perm_state : UnfusedPermState - State needed by :func:`unfused_token_combine`. + perm_state : PureJaxPermState + State needed by :func:`pure_jax_token_combine`. group_sizes : jnp.ndarray Token count per expert, shape ``[num_experts]``. Each entry is a multiple of ``align_size`` when ``align_size > 0``. @@ -907,7 +913,7 @@ def unfused_token_dispatch( padding_size = 0 - perm_state = UnfusedPermState( + perm_state = PureJaxPermState( sorted_indices=sorted_selected_experts, num_real_tokens=num_real_tokens, padding_size=padding_size, @@ -919,9 +925,9 @@ def unfused_token_dispatch( # Combine (unpermute + weighted sum) -def unfused_token_combine( +def pure_jax_token_combine( expert_outputs: jnp.ndarray, - perm_state: UnfusedPermState, + perm_state: PureJaxPermState, routing_weights: jnp.ndarray, num_experts_per_tok: int, batch_size: int, @@ -929,7 +935,7 @@ def unfused_token_combine( ) -> jnp.ndarray: """Pure-JAX ``argsort``-based token combine. - Reverses the permutation performed by :func:`unfused_token_dispatch`, + Reverses the permutation performed by :func:`pure_jax_token_dispatch`, strips any alignment-padding rows appended during dispatch, and applies a per-token weighted sum across the top-k experts. @@ -938,8 +944,8 @@ def unfused_token_combine( expert_outputs : jnp.ndarray Output of the expert FFN, shape ``[num_real_tokens + padding_size, hidden_size]``. - perm_state : UnfusedPermState - State returned by :func:`unfused_token_dispatch`. + perm_state : PureJaxPermState + State returned by :func:`pure_jax_token_dispatch`. routing_weights : jnp.ndarray Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` (or broadcastable to it after a ``reshape``). @@ -979,7 +985,7 @@ def unfused_token_combine( # intermediate dtype; callers can upcast before calling if higher # precision weight-sum is desired). reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) - with jax.named_scope("unfused_weight_sum"): + with jax.named_scope("pure_jax_weight_sum"): output = jnp.einsum( "BKE,BK -> BE", reshaped_intermediate, @@ -1206,7 +1212,7 @@ def compute_reverse_ragged_all_to_all_params( # Implementation uses :func:`sort_chunks_by_index`, which is Triton-backed # (see ``transformer_engine.jax.triton_extensions.permutation``) and has a # paired custom-VJP backward. There is no pure-JAX alternative here -- the -# global :func:`unfused_token_dispatch` / :func:`token_dispatch` choice is +# global :func:`pure_jax_token_dispatch` / :func:`token_dispatch` choice is # unaffected by this; only the (small) post-A2A chunk reorder uses Triton # unconditionally. @@ -1221,7 +1227,7 @@ def local_permute_after_a2a( tokens are contiguous. This is the EP-side complement to the global :func:`token_dispatch` / - :func:`unfused_token_dispatch`. Internally uses + :func:`pure_jax_token_dispatch`. Internally uses :func:`sort_chunks_by_index` (Triton-backed) for both the forward sort and -- via :func:`local_unpermute_before_a2a` -- the inverse. From dafaad4b9be3138fe6859b7dd86fe5944a4256aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 22:19:51 +0000 Subject: [PATCH 10/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_moe_block.py | 6 +++--- tests/jax/test_moe_block.py | 7 +++---- transformer_engine/jax/flax/moe.py | 16 ++++------------ 3 files changed, 10 insertions(+), 19 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index bb15ed8c95..8f08889953 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -153,9 +153,9 @@ def loss_fn(variables, x): sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)( init_key, inputs ) - (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = ( - _make_loss_and_grad(sharded_block)(sharded_variables, inputs) - ) + (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = _make_loss_and_grad( + sharded_block + )(sharded_variables, inputs) wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index ed5e0529c5..e87593c9d4 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -334,10 +334,9 @@ def test_aux_loss_uses_real_routing_under_group_topk(self): " divergence; pick a config where they differ" ) - assert jnp.allclose(block_aux_loss, corrected_ref, atol=1e-5, rtol=1e-5), ( - f"Block aux_loss {block_aux_loss} does not match" - f" real-routing reference {corrected_ref}" - ) + assert jnp.allclose( + block_aux_loss, corrected_ref, atol=1e-5, rtol=1e-5 + ), f"Block aux_loss {block_aux_loss} does not match real-routing reference {corrected_ref}" # The corrected and buggy refs can be numerically close # (only the mis-routed tokens contribute to the difference), # so assert that the block is *strictly closer* to the diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index a882ddfce6..712499c2cd 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -1042,15 +1042,9 @@ def _a2a_body( # exact same counts as ``psum(local_tokens_per_expert)`` # without an extra collective. The duplicate topk # compute is small relative to the FFNs. - _, global_routing_map = self._route_topk( - global_logits_2d, full_expert_bias - ) - global_tokens_per_expert = jnp.sum( - global_routing_map.astype(jnp.int32), axis=0 - ) - aux_loss = self._compute_aux_loss( - global_logits_2d, global_tokens_per_expert - ) + _, global_routing_map = self._route_topk(global_logits_2d, full_expert_bias) + global_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) + aux_loss = self._compute_aux_loss(global_logits_2d, global_tokens_per_expert) else: aux_loss = None @@ -1123,9 +1117,7 @@ def _a2a_body( ) # -- Stage 8: invert global permute, weighted sum over top-k -- - output = self._global_combine( - y_back, perm, batch_size=local_b, sequence_length=local_s - ) + output = self._global_combine(y_back, perm, batch_size=local_b, sequence_length=local_s) if self.tensor_parallelism_axis is not None: output = jax.lax.psum_scatter( From 27c18fe582ce97c5bd24f7346d96b8f46ff45923 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 11 May 2026 17:28:08 -0700 Subject: [PATCH 11/18] revert C++ changes and will put in a new branch, tighten distributed grad tol to 5e-2, move arch/align_size docs into MoEBlock class Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 12 +- tests/jax/test_moe_block.py | 13 +-- .../common/util/multi_stream.cpp | 66 +++-------- .../jax/csrc/extensions/gemm.cpp | 14 +-- .../jax/csrc/extensions/quantization.cpp | 14 +-- transformer_engine/jax/flax/moe.py | 108 +++++++++--------- 6 files changed, 92 insertions(+), 135 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index 8f08889953..0761c79aaa 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -168,6 +168,14 @@ def loss_fn(variables, x): assert_allclose(sharded_loss, single_loss, dtype=jnp.float32, atol=5e-2, rtol=5e-2) assert_allclose(sharded_aux, single_aux, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + # The sharded path runs the same math on each ep-shard but + # accumulates gradients via psum across (ep, fsdp), which changes + # floating-point reduction order vs the single-device run. Under + # bf16 with these toy shapes the observed max-abs grad diff is on + # the order of a few units of bf16 eps (~1e-2). 5e-2 / 5e-2 + # leaves headroom for accumulation jitter without masking real + # divergence; matches the cross-backend bf16 grad tolerance in + # ``tests/jax/test_moe_block.py::test_pure_jax_matches_triton``. for name in ("gate_kernel", "wi_0", "wi_1", "wo"): grad_single = _unwrap_partitioned(single_grads["params"][name]) grad_sharded = _unwrap_partitioned(sharded_grads["params"][name]) @@ -175,7 +183,7 @@ def loss_fn(variables, x): grad_sharded, grad_single, dtype=DTYPE, - atol=1e-1, - rtol=1e-1, + atol=5e-2, + rtol=5e-2, err_msg=f"Distributed gradient mismatch for {name}", ) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index e87593c9d4..a901a73b66 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -228,7 +228,7 @@ def test_aux_loss_returned(self, permutation_backend): assert jnp.abs(aux_loss) < 1e2 def test_aux_loss_uses_real_routing_under_group_topk(self): - """Regression test for PR #2912 review (greptile P1). + """Aux loss must reflect the real (post-group) routing decisions. Under DeepSeek-style ``num_groups`` / ``group_topk`` routing, the auxiliary load-balancing loss must be computed using the @@ -385,12 +385,11 @@ def test_align_size_equivalence_pure_jax(self, monkeypatch): to every expert GEMM output (their input rows are zeros) and are stripped before the weighted sum. - Why the env knob: the V1 TE grouped GEMM FFI asserts - ``sum(group_sizes) == M`` at - ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``. With - ``align_size > 0`` the pure-JAX backend produces a buffer where - ``M >= sum(group_sizes)`` (the slack is structural padding for - JIT). The V2 grouped GEMM relaxes that assertion to + Why the env knob: the V1 TE grouped GEMM FFI asserts strict + equality ``sum(group_sizes) == M``. With ``align_size > 0`` the + pure-JAX backend produces a buffer where ``M >= sum(group_sizes)`` + (the slack is structural padding for JIT), so V1 is incompatible. + The V2 cuBLASLt-backed grouped GEMM relaxes the assertion to ``M >= sum(group_sizes)`` and is selected when ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1``. If V2 isn't supported on this hardware / for this dtype, the dispatch raises a diff --git a/transformer_engine/common/util/multi_stream.cpp b/transformer_engine/common/util/multi_stream.cpp index ec341abc68..6b19f36741 100644 --- a/transformer_engine/common/util/multi_stream.cpp +++ b/transformer_engine/common/util/multi_stream.cpp @@ -12,7 +12,6 @@ #include #include -#include #include #include "cuda_runtime.h" @@ -20,54 +19,18 @@ namespace transformer_engine::detail { -namespace { - -// CUDA streams and events are device-bound: a stream / event created -// on device A cannot be recorded into / waited on from device B -// (CUDA returns ``cudaErrorInvalidResourceHandle``). The previous -// implementation used ``std::call_once`` to lazily create one -// process-global vector of streams + one of events, which works for -// the single-device case (PyTorch eager / single-host single-device -// JAX) but breaks for single-process *multi*-device JAX: the first -// worker thread to win the ``call_once`` would create streams / -// events on its own device, and subsequent calls from other devices -// would receive those same handles and fail at ``cudaEventRecord``. -// -// We now key the cache on the active CUDA device. Each device gets -// its own ``num_compute_streams`` streams and events, created lazily -// the first time a thread on that device asks for one. -template -auto& per_device_pool(CreateFn&& create) { - static std::mutex mu; - using PoolT = decltype(std::vector{create()}); - static std::unordered_map pools; - int device; - NVTE_CHECK_CUDA(cudaGetDevice(&device)); - std::lock_guard lock(mu); - auto it = pools.find(device); - if (it == pools.end()) { - const size_t num_streams = nvte_get_num_compute_streams(); - PoolT v; - v.reserve(num_streams); - for (size_t i = 0; i < num_streams; i++) { - v.push_back(create()); - } - it = pools.emplace(device, std::move(v)).first; - } - return it->second; -} - -} // namespace - cudaStream_t get_compute_stream(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); - auto& streams = per_device_pool([] { - cudaStream_t s; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&s, cudaStreamNonBlocking, -1)); - return s; - }); + static std::vector streams(num_streams); + static std::once_flag stream_init_flag; + auto init = [&]() { + for (size_t i = 0; i < num_streams; i++) { + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1)); + } + }; + std::call_once(stream_init_flag, init); return streams[idx]; } @@ -75,11 +38,14 @@ cudaEvent_t get_compute_stream_event(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); - auto& events = per_device_pool([] { - cudaEvent_t e; - NVTE_CHECK_CUDA(cudaEventCreate(&e)); - return e; - }); + static std::vector events(num_streams); + static std::once_flag event_init_flag; + auto init = [&]() { + for (size_t i = 0; i < num_streams; i++) { + NVTE_CHECK_CUDA(cudaEventCreate(&events[i])); + } + }; + std::call_once(event_init_flag, init); return events[idx]; } diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 8a807cbdcc..6ca907032c 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -1157,18 +1157,12 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type cudaStreamSynchronize(stream); } size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - // Allow callers to pass an LHS/RHS that is at least as large as the active - // ragged region (sum_group_sizes). This supports ragged-all-to-all flows - // where the recv buffer is over-allocated to a worst-case size and only - // the first sum_group_sizes rows along the ragged dim are populated; the - // trailing slack rows are not consumed by the per-group GEMMs (which key - // off group_sizes). if (!is_rhs_ragged) { - NVTE_CHECK(sum_group_sizes <= m, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes, - " must be <= M = ", m); + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); } else { - NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes, - " must be <= K = ", k); + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); } } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 871abb5634..650139a61c 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -383,17 +383,9 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty cudaStreamSynchronize(stream); size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - // Allow callers to pass an input that is at least as large as the active - // ragged region (sum_group_sizes). This supports ragged-all-to-all flows - // where the recv buffer is over-allocated to a worst-case size and only the - // first sum_group_sizes rows are populated; the trailing slack rows are - // simply not quantized (and not consumed by the downstream grouped GEMM - // which is also keyed on group_sizes). - // For flatten_axis==1, m == input_dims[0]; for flatten_axis>1, the per-group - // tile is dim_list_host[i] * non_group_m, so the binding dim is input_dims[0]. - NVTE_CHECK(sum_group_sizes <= input_dims[0], - "Unexpected group_sizes! sum(group_sizes)=%zu must be <= input_dims[0]=%zu (M=%zu)", - sum_group_sizes, input_dims[0], m); + NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, + "Unexpected group_sizes! Got %zu (M=%zu, input_dims[0] = %zu)", sum_group_sizes, m, + input_dims[0]); if (is_delayed_scaling) { NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 712499c2cd..30f9a1bfb7 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -4,56 +4,9 @@ """Flax Linen MoEBlock for TransformerEngine JAX. -This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer -that wires together TE's fused router, a selectable token-dispatch backend -(``pure_jax`` or ``triton``), TE's ``grouped_dense``, and an -optional ragged-all-to-all (A2A / A2Av) expert-parallelism strategy. - -Architecture ------------- - -The MoEBlock is decomposed into orthogonal stages so the EP wrapper can -inject collectives between them: - -* ``_route``: gate logits -> top-k routing decisions (+ aux loss). -* ``_global_permute``: scatter tokens to experts; produces - ``[num_tokens*topk + maybe_padding, hidden]`` and - per-expert ``group_sizes`` of length ``num_experts``. -* ``_expert_ffn``: three ``grouped_dense`` calls + activation. Operates - on whatever ``(rows, group_sizes, n_groups)`` it is - handed -- agnostic to whether ``n_groups`` is the - global expert count (no-EP) or the local expert - count (A2A-EP). -* ``_global_combine``: inverse of ``_global_permute`` -- gather + weighted - sum across top-k experts. - -Two top-level forward variants compose those stages: - -* ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE - primitive's ``custom_partitioning`` rule handles - DP / FSDP / TP automatically. -* ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and inserts - ``all_gather(group_sizes)`` + forward - ``ragged_all_to_all`` + local permute around the - FFN, plus their inverses afterwards. This is the - only place ``shard_map`` is used; A2A is the - canonical EP strategy because the in-flight NCCL - EP component will require this same data layout. - -Note on ``align_size > 0`` --------------------------- - -Both permutation backends pad each expert's group to a multiple of -``align_size`` when requested, which is what CUBLASLt's grouped GEMM wants -for FP8 shape selection. The pure-JAX backend additionally appends a -zero-input padding tail to keep the buffer statically sized for JIT, so -``sum(group_sizes) <= sorted_inputs.shape[0]`` strictly. TE's -``grouped_dense`` FFI today asserts ``m == sum(group_sizes)`` at -``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``; relaxing that -check to ``m >= sum(group_sizes)`` (the kernel itself only iterates over -``sum(group_sizes)`` rows via ``nvte_multi_tensor_gemm``) is the cleanest -way to support ``align_size > 0`` end-to-end. Until that lands the -``align_size > 0`` tests stay xfail. +This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE +layer. See the class docstring for the architecture, the EP / FSDP / TP +strategies, and the ``align_size > 0`` contract. """ from functools import partial @@ -134,6 +87,51 @@ class MoEBlock(TransformerEngineBase): per-expert two-layer FFN via grouped GEMMs, activation, token combine, and optional ragged-all-to-all expert parallelism. + Architecture + ------------ + + The block is decomposed into orthogonal stages so the EP wrapper can + inject collectives between them: + + * ``_route``: gate logits -> top-k routing decisions (+ aux loss). + * ``_global_permute``: scatter tokens to experts; produces + ``[num_tokens*topk + maybe_padding, hidden]`` and per-expert + ``group_sizes`` of length ``num_experts``. + * ``_expert_ffn``: three ``grouped_dense`` calls + activation. + Operates on whatever ``(rows, group_sizes, n_groups)`` it is + handed -- agnostic to whether ``n_groups`` is the global expert + count (no-EP) or the local expert count (A2A-EP). + * ``_global_combine``: inverse of ``_global_permute`` -- gather + + weighted sum across top-k experts. + + Two top-level forward variants compose those stages: + + * ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE + primitive's ``custom_partitioning`` rule handles DP / FSDP / TP + automatically. + * ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and + inserts ``all_gather(group_sizes)`` + forward + ``ragged_all_to_all`` + local permute around the FFN, plus their + inverses afterwards. This is the only place ``shard_map`` is + used; A2A is the canonical EP strategy because the in-flight + NCCL EP component will require this same data layout. + + Note on ``align_size > 0`` + -------------------------- + + Both permutation backends pad each expert's group to a multiple of + ``align_size`` when requested, which is what cuBLASLt's grouped GEMM + wants for FP8 shape selection. The pure-JAX backend additionally + appends a zero-input padding tail to keep the buffer statically + sized for JIT, so ``sum(group_sizes) <= sorted_inputs.shape[0]`` + strictly. The V1 grouped GEMM FFI asserts strict equality + ``m == sum(group_sizes)`` and is therefore incompatible with + ``align_size > 0``; the V2 cuBLASLt-backed grouped GEMM relaxes this + to ``m >= sum(group_sizes)`` and only iterates over the populated + ragged region. The ``align_size > 0`` tests therefore force + ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1`` and ``skip`` if V2 is not + supported on the target hardware / dtype. + Two permutation backends are pluggable via ``permutation_backend``: * ``"pure_jax"`` (default) -- argsort-based @@ -230,11 +228,11 @@ class MoEBlock(TransformerEngineBase): permutation_backend : str ``"pure_jax"`` (default) or ``"triton"``. align_size : int - Alignment for per-expert group sizes after padding. ``0`` disables - padding (the only supported configuration end-to-end today). ``>0`` - is required for quantized TE grouped GEMM whose recipe-specific - alignment must divide ``align_size``; see the module docstring for - the FFI assertion that currently blocks ``>0`` for both backends. + Alignment for per-expert group sizes after padding. ``0`` + disables padding. ``>0`` is required for quantized TE grouped + GEMM whose recipe-specific alignment must divide ``align_size``, + and requires the V2 cuBLASLt-backed grouped GEMM (see the + ``align_size > 0`` note in this docstring). dtype : jnp.dtype Compute and parameter dtype. From abbb2c6ad5f3189995c3f65a235bb00270444c04 Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 12 May 2026 15:53:33 -0700 Subject: [PATCH 12/18] address more comments: ep_resource look up, perm backend enum, accepting None as group_topk, align_size rename, Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 14 +- tests/jax/test_moe_block.py | 27 +-- transformer_engine/jax/flax/moe.py | 228 +++++++++++++----------- transformer_engine/jax/sharding.py | 34 ++++ 4 files changed, 188 insertions(+), 115 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index 0761c79aaa..64a8491b6a 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -25,11 +25,13 @@ def _inject_moe(request): from transformer_engine.jax import MeshResource, autocast from transformer_engine.jax.flax import MoEBlock + from transformer_engine.jax.flax.moe import PermutationBackend mod = sys.modules[__name__] mod.MeshResource = MeshResource mod.autocast = autocast mod.MoEBlock = MoEBlock + mod.PermutationBackend = PermutationBackend yield @@ -59,6 +61,7 @@ def test_ep2_fsdp2_matches_single_device(self, permutation_backend): if not is_devices_enough(4): pytest.skip("MoE distributed test requires 4 devices for EP=2 x FSDP=2.") + permutation_backend = PermutationBackend(permutation_backend) key = jax.random.PRNGKey(11) init_key, data_key = jax.random.split(key) inputs = _make_inputs(data_key) @@ -120,14 +123,19 @@ def loss_fn(variables, x): # device owns ``B/(ep*fsdp)`` unique tokens (no redundant compute # across fsdp peers within an ep group). sharded_block = MoEBlock( - expert_parallelism_axis="ep", data_parallelism_axes=("fsdp",), - mesh=mesh, input_axes=("batch", None, None), **base_kwargs, ) - with mesh, autocast(enabled=False, mesh_resource=MeshResource(fsdp_resource="fsdp")): + # ``MoEBlock`` resolves the EP axis from + # ``global_mesh_resource().ep_resource`` (set via ``autocast``), + # so the ``ep`` axis on the mesh is wired in by passing + # ``ep_resource="ep"`` here -- no per-instance config needed. + with mesh, autocast( + enabled=False, + mesh_resource=MeshResource(fsdp_resource="fsdp", ep_resource="ep"), + ): with nn.logical_axis_rules(logical_axis_rules): # ``MoEBlock`` registers params via ``with_logical_partitioning`` # which only attaches LogicallyPartitioned metadata; the diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index a901a73b66..0d89e6dab7 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -14,7 +14,7 @@ decisions. * Auxiliary load-balancing loss is returned when ``aux_loss_coeff > 0``. * DeepSeek-style grouped top-k (``num_groups`` / ``group_topk``) runs. -* ``align_size > 0`` produces numerically-equivalent outputs to ``align_size = 0`` +* ``_align_size > 0`` produces numerically-equivalent outputs to ``_align_size = 0`` for the pure-JAX backend (padding must not change the result). """ @@ -40,9 +40,11 @@ def _inject_moe(request): return from transformer_engine.jax.flax import MoEBlock + from transformer_engine.jax.flax.moe import PermutationBackend mod = sys.modules[__name__] mod.MoEBlock = MoEBlock + mod.PermutationBackend = PermutationBackend yield @@ -93,6 +95,7 @@ class TestMoEBlockSingleDevice: @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) def test_forward_shape_and_finite(self, permutation_backend): + permutation_backend = PermutationBackend(permutation_backend) key = jax.random.PRNGKey(0) init_key, data_key = jax.random.split(key) @@ -115,6 +118,7 @@ def test_forward_shape_and_finite(self, permutation_backend): @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) def test_backward_grad_is_finite_and_nonzero(self, permutation_backend): + permutation_backend = PermutationBackend(permutation_backend) key = jax.random.PRNGKey(1) init_key, data_key = jax.random.split(key) @@ -157,8 +161,8 @@ def test_pure_jax_triton_equivalence(self): intermediate_size=INTERMEDIATE_SIZE, dtype=DTYPE, ) - pure_block = MoEBlock(permutation_backend="pure_jax", **base_kwargs) - triton_block = MoEBlock(permutation_backend="triton", **base_kwargs) + pure_block = MoEBlock(permutation_backend=PermutationBackend.PURE_JAX, **base_kwargs) + triton_block = MoEBlock(permutation_backend=PermutationBackend.TRITON, **base_kwargs) inputs = _make_inputs(data_key) # Share a single parameter tree so routing decisions and expert @@ -206,6 +210,7 @@ def loss_fn(block, variables, inputs): @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) def test_aux_loss_returned(self, permutation_backend): + permutation_backend = PermutationBackend(permutation_backend) key = jax.random.PRNGKey(3) init_key, data_key = jax.random.split(key) @@ -272,7 +277,7 @@ def test_aux_loss_uses_real_routing_under_group_topk(self): num_experts=NUM_EXPERTS, num_experts_per_tok=NUM_EXPERTS_PER_TOK, intermediate_size=INTERMEDIATE_SIZE, - permutation_backend="pure_jax", + permutation_backend=PermutationBackend.PURE_JAX, score_function="sigmoid", num_groups=num_groups, group_topk=group_topk, @@ -357,6 +362,7 @@ def test_aux_loss_uses_real_routing_under_group_topk(self): @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) def test_group_topk_deepseek(self, permutation_backend): """Exercise DeepSeek-style grouped top-k routing.""" + permutation_backend = PermutationBackend(permutation_backend) key = jax.random.PRNGKey(4) init_key, data_key = jax.random.split(key) @@ -380,13 +386,13 @@ def test_group_topk_deepseek(self, permutation_backend): assert jnp.all(jnp.isfinite(output)) def test_align_size_equivalence_pure_jax(self, monkeypatch): - """For the pure-JAX backend, ``align_size > 0`` must not change the + """For the pure-JAX backend, ``_align_size > 0`` must not change the numerical output of the forward pass: padding tokens contribute zero to every expert GEMM output (their input rows are zeros) and are stripped before the weighted sum. Why the env knob: the V1 TE grouped GEMM FFI asserts strict - equality ``sum(group_sizes) == M``. With ``align_size > 0`` the + equality ``sum(group_sizes) == M``. With ``_align_size > 0`` the pure-JAX backend produces a buffer where ``M >= sum(group_sizes)`` (the slack is structural padding for JIT), so V1 is incompatible. The V2 cuBLASLt-backed grouped GEMM relaxes the assertion to @@ -405,11 +411,11 @@ def test_align_size_equivalence_pure_jax(self, monkeypatch): num_experts=NUM_EXPERTS, num_experts_per_tok=NUM_EXPERTS_PER_TOK, intermediate_size=INTERMEDIATE_SIZE, - permutation_backend="pure_jax", + permutation_backend=PermutationBackend.PURE_JAX, dtype=DTYPE, ) - block_no_pad = MoEBlock(align_size=0, **base_kwargs) - block_pad = MoEBlock(align_size=16, **base_kwargs) + block_no_pad = MoEBlock(_align_size=0, **base_kwargs) + block_pad = MoEBlock(_align_size=16, **base_kwargs) inputs = _make_inputs(data_key) try: @@ -422,7 +428,7 @@ def test_align_size_equivalence_pure_jax(self, monkeypatch): raise assert jnp.allclose(out_no_pad, out_pad, atol=5e-2, rtol=5e-2), ( - "align_size > 0 must not change pure_jax forward output; max diff" + "_align_size > 0 must not change pure_jax forward output; max diff" f" {jnp.max(jnp.abs(out_no_pad - out_pad))}" ) @@ -430,6 +436,7 @@ def test_align_size_equivalence_pure_jax(self, monkeypatch): def test_jit_and_determinism(self, permutation_backend): """The block must be JIT-compilable and produce a deterministic forward pass across repeat calls with the same params.""" + permutation_backend = PermutationBackend(permutation_backend) key = jax.random.PRNGKey(6) init_key, data_key = jax.random.split(key) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 30f9a1bfb7..288347a6d1 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -5,10 +5,11 @@ """Flax Linen MoEBlock for TransformerEngine JAX. This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE -layer. See the class docstring for the architecture, the EP / FSDP / TP -strategies, and the ``align_size > 0`` contract. +layer. See the class docstring for the architecture, the EP / FSDP +strategies, and the ``_align_size > 0`` contract. """ +from enum import Enum from functools import partial from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -32,7 +33,11 @@ ) from ..quantize import noop_quantizer_set from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function -from ..sharding import with_sharding_constraint_by_logical_axes +from ..sharding import ( + _get_mesh, + get_active_resource_axis, + with_sharding_constraint_by_logical_axes, +) from .module import TransformerEngineBase, _convert_to_activation_function PRNGKey = Any @@ -42,7 +47,25 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] -__all__ = ["GlobalPermuteResult", "MoEBlock"] +__all__ = ["GlobalPermuteResult", "MoEBlock", "PermutationBackend"] + + +# ============================================================================= +# PermutationBackend +# ============================================================================= + + +class PermutationBackend(Enum): + """Token-dispatch / combine backend used by :class:`MoEBlock`. + + * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain XLA; + typically faster than ``TRITON`` in current testing because XLA can + fuse the ops with surrounding work. + * ``TRITON``: TE's fused Triton kernels. + """ + + PURE_JAX = "pure_jax" + TRITON = "triton" # ============================================================================= @@ -71,7 +94,9 @@ class GlobalPermuteResult: row_id_map: Optional[jnp.ndarray] = None pad_offsets: Optional[jnp.ndarray] = None merging_probs: Optional[jnp.ndarray] = None - backend: str = flax_struct.field(pytree_node=False, default="pure_jax") + backend: PermutationBackend = flax_struct.field( + pytree_node=False, default=PermutationBackend.PURE_JAX + ) # ============================================================================= @@ -107,7 +132,7 @@ class MoEBlock(TransformerEngineBase): Two top-level forward variants compose those stages: * ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE - primitive's ``custom_partitioning`` rule handles DP / FSDP / TP + primitive's ``custom_partitioning`` rule handles DP / FSDP automatically. * ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and inserts ``all_gather(group_sizes)`` + forward @@ -116,40 +141,44 @@ class MoEBlock(TransformerEngineBase): used; A2A is the canonical EP strategy because the in-flight NCCL EP component will require this same data layout. - Note on ``align_size > 0`` - -------------------------- + Note on ``_align_size > 0`` + --------------------------- Both permutation backends pad each expert's group to a multiple of - ``align_size`` when requested, which is what cuBLASLt's grouped GEMM - wants for FP8 shape selection. The pure-JAX backend additionally - appends a zero-input padding tail to keep the buffer statically - sized for JIT, so ``sum(group_sizes) <= sorted_inputs.shape[0]`` - strictly. The V1 grouped GEMM FFI asserts strict equality - ``m == sum(group_sizes)`` and is therefore incompatible with - ``align_size > 0``; the V2 cuBLASLt-backed grouped GEMM relaxes this - to ``m >= sum(group_sizes)`` and only iterates over the populated - ragged region. The ``align_size > 0`` tests therefore force - ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1`` and ``skip`` if V2 is not - supported on the target hardware / dtype. + ``_align_size`` when requested, which is what cuBLASLt's grouped + GEMM wants for FP8 shape selection. The pure-JAX backend + additionally appends a zero-input padding tail to keep the buffer + statically sized for JIT, so ``sum(group_sizes) <= + sorted_inputs.shape[0]`` strictly. The V1 grouped GEMM FFI asserts + strict equality ``m == sum(group_sizes)`` and is therefore + incompatible with ``_align_size > 0``; the V2 cuBLASLt-backed + grouped GEMM relaxes this to ``m >= sum(group_sizes)`` and only + iterates over the populated ragged region. The ``_align_size > 0`` + tests therefore force ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1`` and + ``skip`` if V2 is not supported on the target hardware / dtype. Two permutation backends are pluggable via ``permutation_backend``: - * ``"pure_jax"`` (default) -- argsort-based + * :attr:`PermutationBackend.PURE_JAX` (default) -- argsort-based :func:`~transformer_engine.jax.permutation.pure_jax_token_dispatch` / :func:`~transformer_engine.jax.permutation.pure_jax_token_combine`. Faster than Triton in profiling for DeepSeek-style configs. - * ``"triton"`` -- TE's fused + * :attr:`PermutationBackend.TRITON` -- TE's fused :func:`~transformer_engine.jax.permutation.token_dispatch` / :func:`~transformer_engine.jax.permutation.token_combine` Triton kernels. - Expert parallelism (``expert_parallelism_axis is not None``) uses the - **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes its - own tokens globally over all experts, then a forward - ``ragged_all_to_all`` exchanges per-expert chunks so each shard ends up - holding only the tokens for its local experts; after the FFN a reverse - ``ragged_all_to_all`` returns each shard's outputs to it. This matches - the layout the in-flight NCCL EP component expects. + Expert parallelism is configured via :class:`MeshResource`'s + ``ep_resource`` axis. When that axis is set on the active + :func:`~transformer_engine.jax.global_mesh_resource` and has more + than one device, ``MoEBlock`` dispatches to the + **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes + its own tokens globally over all experts, then a forward + ``ragged_all_to_all`` exchanges per-expert chunks so each shard + ends up holding only the tokens for its local experts; after the + FFN a reverse ``ragged_all_to_all`` returns each shard's outputs + to it. This matches the layout the in-flight NCCL EP component + expects. Parameters ---------- @@ -171,11 +200,11 @@ class MoEBlock(TransformerEngineBase): :func:`fused_topk_with_score_function`. use_pre_softmax : bool Apply softmax before top-k when ``score_function="softmax"``. - num_groups : int - Number of routing groups for grouped top-k (DeepSeek). ``<=0`` - disables. - group_topk : int - Top-k at the group level. ``<=0`` disables. + num_groups : Optional[int] + Number of routing groups for grouped top-k (DeepSeek). ``None`` + (default) disables. + group_topk : Optional[int] + Top-k at the group level. ``None`` (default) disables. scaling_factor : float Scaling factor applied to output probs. use_expert_bias : bool @@ -202,37 +231,22 @@ class MoEBlock(TransformerEngineBase): Logical axes used to constrain the input activation sharding at the block boundary. ``()`` (default) means no constraint. - expert_parallelism_axis : Optional[str] - Mesh axis along which experts are split. When set, the forward - pass is wrapped in :func:`jax.shard_map` that implements the - ragged-all-to-all EP strategy. When ``None`` (default), no - ``shard_map`` wrapper is used; each TE primitive's - ``custom_partitioning`` rule handles DP / FSDP / TP automatically. data_parallelism_axes : tuple[str, ...] Additional mesh axes that the input *batch* dim is sharded over - IN ADDITION to ``expert_parallelism_axis``. Setting this to e.g. - ``("fsdp",)`` makes the ``shard_map`` ``in_specs`` for the batch - dim become ``P(("ep", "fsdp"), None, None)`` -- giving each - device a unique slice of the batch (true FSDP) instead of + IN ADDITION to ``MeshResource.ep_resource``. Setting this to + e.g. ``("fsdp",)`` makes the ``shard_map`` ``in_specs`` for the + batch dim become ``P(("ep", "fsdp"), None, None)`` -- giving + each device a unique slice of the batch (true FSDP) instead of replicating the per-ep-shard batch across fsdp peers. Routing is unaffected: ``axis_index("ep")`` still controls the ragged-all-to-all; the extra fsdp peers within an ep group send and receive their own batch slices in lockstep. Default ``()`` preserves legacy ZeRO-1-style behavior (activations replicated on fsdp within an ep group). - tensor_parallelism_axis : Optional[str] - Mesh axis for tensor parallelism on the FFN intermediate dim. When - set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed - along this axis. - - permutation_backend : str - ``"pure_jax"`` (default) or ``"triton"``. - align_size : int - Alignment for per-expert group sizes after padding. ``0`` - disables padding. ``>0`` is required for quantized TE grouped - GEMM whose recipe-specific alignment must divide ``align_size``, - and requires the V2 cuBLASLt-backed grouped GEMM (see the - ``align_size > 0`` note in this docstring). + + permutation_backend : PermutationBackend + :attr:`PermutationBackend.PURE_JAX` (default) or + :attr:`PermutationBackend.TRITON`. dtype : jnp.dtype Compute and parameter dtype. @@ -243,6 +257,15 @@ class MoEBlock(TransformerEngineBase): use_bias : bool If ``True``, registers per-expert FFN biases ``wi_0_bias``, ``wi_1_bias``, ``wo_bias``. + + TODO: + ----- + ``_align_size`` is an internal, non-public knob (alignment for + per-expert group sizes after padding). A follow-up PR will infer it + from the active quantization recipe, after which it will become a + fully-internal implementation detail. Until then it stays + intentionally underscored to discourage callers from depending on + it. """ # Architecture @@ -254,8 +277,8 @@ class MoEBlock(TransformerEngineBase): # Routing score_function: Union[str, ScoreFunction] = "softmax" use_pre_softmax: bool = False - num_groups: int = -1 - group_topk: int = -1 + num_groups: Optional[int] = None + group_topk: Optional[int] = None scaling_factor: float = 1.0 use_expert_bias: bool = False aux_loss_coeff: float = 0.0 @@ -267,16 +290,18 @@ class MoEBlock(TransformerEngineBase): input_axes: Tuple[Optional[str], ...] = () # Parallelism - expert_parallelism_axis: Optional[str] = None + # + # The EP axis is resolved from ``global_mesh_resource().ep_resource`` + # and the active mesh, not configured per-instance. ``MoEBlock`` + # uses ``_forward_a2a_ep`` when that axis exists on the mesh and + # has > 1 device; otherwise it uses ``_forward_no_ep``. data_parallelism_axes: Tuple[str, ...] = () - tensor_parallelism_axis: Optional[str] = None - # ``jax.sharding.Mesh`` to use when ``expert_parallelism_axis`` is set. - # Required for the ``shard_map`` wrapper; ignored otherwise. - mesh: Optional[Any] = None # Permutation - permutation_backend: str = "pure_jax" - align_size: int = 0 + permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX + # See class docstring "Notes": internal, will be inferred from the + # quantization recipe in a follow-up PR. + _align_size: int = 0 # Dtypes / init / misc dtype: DType = jnp.float32 @@ -294,9 +319,9 @@ def __post_init__(self): 1.0, "fan_in", "truncated_normal", dtype=self.dtype ), ) - if self.permutation_backend not in ("pure_jax", "triton"): - raise ValueError( - "permutation_backend must be 'pure_jax' or 'triton'," + if not isinstance(self.permutation_backend, PermutationBackend): + raise TypeError( + "permutation_backend must be a PermutationBackend," f" got {self.permutation_backend!r}" ) super().__post_init__() @@ -389,7 +414,8 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: self.dtype, ) - if self.expert_parallelism_axis is None: + ep_axis = get_active_resource_axis("ep_resource") + if ep_axis is None: output, aux_loss = self._forward_no_ep( inputs, gate_logits, @@ -405,6 +431,7 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: output, aux_loss = self._forward_a2a_ep( inputs, gate_logits, + ep_axis=ep_axis, wi_0=wi_0, wi_1=wi_1, wo=wo, @@ -470,12 +497,15 @@ def _route_topk( expert_bias: Optional[jnp.ndarray], ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Run the fused router top-k selection.""" + # ``fused_topk_with_score_function`` uses ``-1`` as the + # "disabled" sentinel for the grouped-routing knobs; translate + # our ``None`` user-facing default to that sentinel here. sparse_probs, routing_map = fused_topk_with_score_function( logits_2d, topk=self.num_experts_per_tok, use_pre_softmax=self.use_pre_softmax, - num_groups=self.num_groups, - group_topk=self.group_topk, + num_groups=-1 if self.num_groups is None else self.num_groups, + group_topk=-1 if self.group_topk is None else self.group_topk, scaling_factor=self.scaling_factor, score_function=self.score_function, expert_bias=expert_bias, @@ -554,7 +584,7 @@ def _global_permute( num_tokens = inputs_2d.shape[0] topk = self.num_experts_per_tok - if self.permutation_backend == "pure_jax": + if self.permutation_backend is PermutationBackend.PURE_JAX: selected_experts, routing_weights = routing_map_to_selected_experts( sparse_probs, routing_map, topk ) @@ -563,10 +593,10 @@ def _global_permute( selected_experts, num_experts=self.num_experts, num_experts_per_tok=topk, - align_size=self.align_size, + align_size=self._align_size, ) return GlobalPermuteResult( - backend="pure_jax", + backend=PermutationBackend.PURE_JAX, sorted_inputs=sorted_inputs, group_sizes=group_sizes, perm_state=perm_state, @@ -575,7 +605,7 @@ def _global_permute( # triton num_out_tokens = num_tokens * topk - align_size_arg = self.align_size if self.align_size > 0 else None + align_size_arg = self._align_size if self._align_size > 0 else None ( sorted_inputs, _permuted_probs, @@ -590,7 +620,7 @@ def _global_permute( align_size=align_size_arg, ) return GlobalPermuteResult( - backend="triton", + backend=PermutationBackend.TRITON, sorted_inputs=sorted_inputs, group_sizes=group_sizes, row_id_map=row_id_map, @@ -713,7 +743,7 @@ def _global_combine( Gathers per-expert outputs back into ``[batch, sequence, hidden]`` and applies the per-token weighted sum across the top-k experts. """ - if perm_result.backend == "pure_jax": + if perm_result.backend is PermutationBackend.PURE_JAX: return pure_jax_token_combine( expert_outputs, perm_result.perm_state, @@ -749,9 +779,9 @@ def _forward_no_ep( wo_bias: Optional[jnp.ndarray] = None, expert_bias: Optional[jnp.ndarray] = None, ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Single-shard or DP/FSDP/TP forward (no shard_map wrapper). + """Single-shard or DP/FSDP forward (no shard_map wrapper). - DP / FSDP / TP all flow through each TE primitive's + DP / FSDP both flow through each TE primitive's ``custom_partitioning`` rule -- there is no cross-primitive collective that the rules cannot express on their own, so a ``shard_map`` is unnecessary here. @@ -780,8 +810,8 @@ def _forward_no_ep( replicated-everywhere semantics (legal but defeats FSDP/DP). Tested in ``tests/jax/test_distributed_moe_block.py`` for the EP=2 + FSDP=2 case; the no-EP + FSDP-only case shares the same - infra and is covered when ``expert_parallelism_axis`` is left - ``None`` in that test. + infra and is covered when ``ep_resource`` is unset on the + active ``MeshResource``. """ batch_size, sequence_length, hidden_size = inputs.shape inputs_2d = inputs.reshape(-1, hidden_size) @@ -806,14 +836,6 @@ def _forward_no_ep( wo_bias=wo_bias, ) output = self._global_combine(expert_outputs, perm, batch_size, sequence_length) - - if self.tensor_parallelism_axis is not None: - output = jax.lax.psum_scatter( - output, - self.tensor_parallelism_axis, - scatter_dimension=2, - tiled=True, - ) return output, aux_loss # ------------------------------------------------------------------ @@ -825,6 +847,7 @@ def _forward_a2a_ep( inputs: jnp.ndarray, gate_logits: jnp.ndarray, *, + ep_axis: str, wi_0: jnp.ndarray, wi_1: jnp.ndarray, wo: jnp.ndarray, @@ -859,13 +882,13 @@ def _forward_a2a_ep( """ from jax.experimental.shard_map import shard_map - ep_axis = self.expert_parallelism_axis - if self.mesh is None: + mesh = _get_mesh() + if mesh is None or mesh.empty: raise ValueError( - "MoEBlock.expert_parallelism_axis is set; `mesh` must also" - " be provided so the EP shard_map can be built." + "MoEBlock requires an active jax.sharding.Mesh (either via" + " `with mesh:` or `jax.set_mesh`) when EP is configured on" + " the active MeshResource." ) - mesh = self.mesh num_ep = mesh.shape[ep_axis] assert ( self.num_experts % num_ep == 0 @@ -911,7 +934,16 @@ def _forward_a2a_ep( raise ValueError( f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}" ) + # Worst-case A2A receive count per shard: every peer can send its + # full per-expert-aligned local buffer. With ``_align_size > 0`` + # each per-expert group can be padded by up to ``_align_size - 1`` + # rows, so per shard the receive can overshoot the unpadded count + # by up to ``num_experts * (_align_size - 1)``. Skipping this + # extra slack would let ``ragged_all_to_all`` write past + # ``recv_buf`` when EP and padding are combined. recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + if self._align_size > 0: + recv_buffer_rows += self.num_experts * (self._align_size - 1) # Pack everything that crosses the shard_map boundary into a dict # pytree. shard_map fully supports pytrees: ``in_specs`` must @@ -1117,14 +1149,6 @@ def _a2a_body( # -- Stage 8: invert global permute, weighted sum over top-k -- output = self._global_combine(y_back, perm, batch_size=local_b, sequence_length=local_s) - if self.tensor_parallelism_axis is not None: - output = jax.lax.psum_scatter( - output, - self.tensor_parallelism_axis, - scatter_dimension=2, - tiled=True, - ) - # ``out_specs`` must match the returned pytree structurally, # so always emit a real scalar for aux_loss; the outer # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c14..182a4a2e00 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -332,6 +332,7 @@ class MeshResource: fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None + ep_resource: Axis name for expert parallelism (MoE expert sharding), default is None """ dp_resource: str = None @@ -340,6 +341,7 @@ class MeshResource: fsdp_resource: str = None pp_resource: str = None cp_resource: str = None + ep_resource: str = None _GLOBAL_MESH_RESOURCE = None @@ -379,6 +381,38 @@ def global_mesh_resource() -> MeshResource: return _GLOBAL_MESH_RESOURCE +def get_active_resource_axis(resource_name: str) -> Optional[str]: + """Resolve a :class:`MeshResource` attribute to its mesh axis name, + or return ``None`` if that resource is not active. + + "Active" means all three are true: + + * a physical mesh is set (``is_mesh_available()``), + * the ``MeshResource`` attribute is non-``None``, + * the corresponding mesh axis has more than 1 device. + + Mirrors the three-step ``is_X_enabled`` idiom in + :func:`get_sharding_map_logic_axis_to_mesh_axis` but returns the + axis name itself (or ``None``) so callers can use it directly in + collectives / ``shard_map`` specs. + + Args: + resource_name: Attribute name on :class:`MeshResource`, e.g. + ``"fsdp_resource"`` or ``"ep_resource"``. + + Returns: + The mesh axis name when active, else ``None``. + """ + if not is_mesh_available(): + return None + if _GLOBAL_MESH_RESOURCE is None: + return None + axis = getattr(_GLOBAL_MESH_RESOURCE, resource_name) + if axis is None or get_mesh_axis_size(axis) <= 1: + return None + return axis + + def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): """Perform all-reduce sum operation along data parallelism and FSDP axes. From b375db7b3f0963571ba91966cfcd2c1b36d84cf3 Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 12 May 2026 16:56:12 -0700 Subject: [PATCH 13/18] tests/jax/test_distributed_moe_block.py Signed-off-by: tdophung --- tests/jax/test_moe_block.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py index 0d89e6dab7..fbe8c083e9 100644 --- a/tests/jax/test_moe_block.py +++ b/tests/jax/test_moe_block.py @@ -2,10 +2,10 @@ # # See LICENSE for license information. -"""Basic tests for ``transformer_engine.jax.flax.MoEBlock``. +"""Basic tests for ``transformer_engine.jax.flax._MoEBlock``. -These tests exercise the MoEBlock on a single device (no expert parallelism) -and verify: +These tests exercise the (experimental) ``_MoEBlock`` on a single device +(no expert parallelism) and verify: * Forward pass runs end-to-end and produces the expected output shape. * Backward pass yields finite, non-trivial parameter gradients. @@ -26,20 +26,23 @@ import pytest -# The MoEBlock pulls in both the fused-router CUDA kernel and the Triton -# permutation kernels, so it can only run in the environment where those are -# available. We gate the test on the ``triton`` marker (the Triton permutation -# backend is stricter than the CUDA router). See ``conftest.py``. +# The ``_MoEBlock`` class pulls in both the fused-router CUDA kernel and +# the Triton permutation kernels, so it can only run in the environment +# where those are available. We gate the test on the ``triton`` marker (the +# Triton permutation backend is stricter than the CUDA router). See +# ``conftest.py``. @pytest.fixture(autouse=True, scope="function") def _inject_moe(request): - """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + """Lazy-load ``_MoEBlock`` only for tests marked ``triton``.""" if not request.node.get_closest_marker("triton"): yield return - from transformer_engine.jax.flax import MoEBlock + # The class is intentionally exposed as ``_MoEBlock`` (experimental); + # aliasing to ``MoEBlock`` here keeps the test bodies readable. + from transformer_engine.jax.flax import _MoEBlock as MoEBlock from transformer_engine.jax.flax.moe import PermutationBackend mod = sys.modules[__name__] @@ -91,7 +94,7 @@ def _unwrap_partitioned(x): @pytest.mark.triton class TestMoEBlockSingleDevice: - """Single-device smoke tests for :class:`MoEBlock`.""" + """Single-device smoke tests for :class:`_MoEBlock`.""" @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) def test_forward_shape_and_finite(self, permutation_backend): From 37c871c48cb83f7af55af839cc845b310cbc604e Mon Sep 17 00:00:00 2001 From: tdophung Date: Tue, 12 May 2026 16:57:16 -0700 Subject: [PATCH 14/18] change naming and add message for experimental feature Signed-off-by: tdophung --- tests/jax/test_distributed_moe_block.py | 9 +++-- transformer_engine/jax/flax/__init__.py | 4 +- transformer_engine/jax/flax/moe.py | 53 ++++++++++++++++--------- transformer_engine/jax/permutation.py | 2 +- 4 files changed, 44 insertions(+), 24 deletions(-) diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py index 64a8491b6a..98fd6a7212 100644 --- a/tests/jax/test_distributed_moe_block.py +++ b/tests/jax/test_distributed_moe_block.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Distributed tests for ``transformer_engine.jax.flax.MoEBlock``.""" +"""Distributed tests for the experimental ``transformer_engine.jax.flax._MoEBlock``.""" import sys @@ -18,13 +18,16 @@ @pytest.fixture(autouse=True, scope="function") def _inject_moe(request): - """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + """Lazy-load ``_MoEBlock`` only for tests marked ``triton``.""" if not request.node.get_closest_marker("triton"): yield return from transformer_engine.jax import MeshResource, autocast - from transformer_engine.jax.flax import MoEBlock + + # The class is intentionally exposed as ``_MoEBlock`` (experimental); + # aliasing to ``MoEBlock`` here keeps the test bodies readable. + from transformer_engine.jax.flax import _MoEBlock as MoEBlock from transformer_engine.jax.flax.moe import PermutationBackend mod = sys.modules[__name__] diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 0cd7835bcf..adf9c8911b 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -9,7 +9,7 @@ make_dot_general_cls, make_grouped_dense_cls, ) -from .moe import MoEBlock +from .moe import _MoEBlock from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -19,7 +19,7 @@ "LayerNorm", "LayerNormDenseGeneral", "LayerNormMLP", - "MoEBlock", + "_MoEBlock", "wrap_function_in_te_state_module", "make_dot_general_cls", "make_grouped_dense_cls", diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index 288347a6d1..f4ef323e24 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -2,11 +2,18 @@ # # See LICENSE for license information. -"""Flax Linen MoEBlock for TransformerEngine JAX. - -This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE -layer. See the class docstring for the architecture, the EP / FSDP -strategies, and the ``_align_size > 0`` contract. +"""Flax Linen MoE block for TransformerEngine JAX. + +This module exposes :class:`_MoEBlock`, an **experimental** self-contained +Flax Linen MoE layer. It is intentionally prefixed with an underscore +while TE's NCCL-backed EP component (and the recipe-driven alignment +follow-up) stabilises; the public ``MoEBlock`` alias will be introduced +once those dependencies are ready (target: the TE release following the +2.16 code freeze). Until then please treat the class, its parameters, +and :class:`GlobalPermuteResult` as unstable. + +See the class docstring for the architecture, the EP / FSDP strategies, +and the ``_align_size > 0`` contract. """ from enum import Enum @@ -47,7 +54,7 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] -__all__ = ["GlobalPermuteResult", "MoEBlock", "PermutationBackend"] +__all__ = ["GlobalPermuteResult", "PermutationBackend", "_MoEBlock"] # ============================================================================= @@ -56,7 +63,7 @@ class PermutationBackend(Enum): - """Token-dispatch / combine backend used by :class:`MoEBlock`. + """Token-dispatch / combine backend used by :class:`_MoEBlock`. * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain XLA; typically faster than ``TRITON`` in current testing because XLA can @@ -72,9 +79,9 @@ class PermutationBackend(Enum): # GlobalPermuteResult # ============================================================================= # -# Output of :meth:`MoEBlock._global_permute`. Carried as a pytree (so it +# Output of :meth:`_MoEBlock._global_permute`. Carried as a pytree (so it # crosses ``jax.shard_map`` / ``jax.value_and_grad`` boundaries -# transparently) and consumed by :meth:`MoEBlock._global_combine`. The +# transparently) and consumed by :meth:`_MoEBlock._global_combine`. The # fields populated depend on the permutation backend; the unused fields # stay ``None``. # @@ -85,7 +92,7 @@ class PermutationBackend(Enum): @flax_struct.dataclass class GlobalPermuteResult: - """Result of :meth:`MoEBlock._global_permute`.""" + """Result of :meth:`_MoEBlock._global_permute`.""" sorted_inputs: jnp.ndarray group_sizes: jnp.ndarray @@ -100,12 +107,22 @@ class GlobalPermuteResult: # ============================================================================= -# MoEBlock +# _MoEBlock # ============================================================================= -class MoEBlock(TransformerEngineBase): - """Mixture-of-Experts Flax Linen block. +class _MoEBlock(TransformerEngineBase): + """Mixture-of-Experts Flax Linen block (**experimental**). + + .. warning:: + + This class is exposed as ``_MoEBlock`` (leading underscore) on + purpose: it is not part of the stable public API yet. The TE + NCCL-backed EP component and the recipe-driven ``_align_size`` + follow-up both need to land before this is promoted to a public + ``MoEBlock``. Until then, expect signature changes, including + to :class:`GlobalPermuteResult` and :class:`PermutationBackend`. + Target promotion: the TE release after the 2.16 code freeze. Encapsulates the full MoE forward pass: gate projection, fused top-k routing, optional auxiliary load-balancing loss, token dispatch, @@ -171,7 +188,7 @@ class MoEBlock(TransformerEngineBase): Expert parallelism is configured via :class:`MeshResource`'s ``ep_resource`` axis. When that axis is set on the active :func:`~transformer_engine.jax.global_mesh_resource` and has more - than one device, ``MoEBlock`` dispatches to the + than one device, ``_MoEBlock`` dispatches to the **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes its own tokens globally over all experts, then a forward ``ragged_all_to_all`` exchanges per-expert chunks so each shard @@ -292,7 +309,7 @@ class MoEBlock(TransformerEngineBase): # Parallelism # # The EP axis is resolved from ``global_mesh_resource().ep_resource`` - # and the active mesh, not configured per-instance. ``MoEBlock`` + # and the active mesh, not configured per-instance. ``_MoEBlock`` # uses ``_forward_a2a_ep`` when that axis exists on the mesh and # has > 1 device; otherwise it uses ``_forward_no_ep``. data_parallelism_axes: Tuple[str, ...] = () @@ -349,7 +366,7 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: """ assert ( inputs.ndim == 3 - ), f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" + ), f"_MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) _, _, hidden_size = inputs.shape @@ -796,7 +813,7 @@ def _forward_no_ep( Concretely: * ``inputs`` should be FSDP/DP-sharded on the batch dim - (``input_axes`` in :class:`MoEBlock` enforces this via a + (``input_axes`` in :class:`_MoEBlock` enforces this via a logical ``with_sharding_constraint``). * ``wi_*`` / ``wo`` weights should carry the logical axes ``wi_kernel_axes`` / ``wo_kernel_axes`` so FSDP shards a @@ -885,7 +902,7 @@ def _forward_a2a_ep( mesh = _get_mesh() if mesh is None or mesh.empty: raise ValueError( - "MoEBlock requires an active jax.sharding.Mesh (either via" + "_MoEBlock requires an active jax.sharding.Mesh (either via" " `with mesh:` or `jax.set_mesh`) when EP is configured on" " the active MeshResource." ) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 9fbaf64736..157575a441 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -999,7 +999,7 @@ def pure_jax_token_combine( # ============================================================================= # # These helpers support the ragged-all-to-all (A2A / A2Av) EP strategy used by -# :class:`transformer_engine.jax.flax.MoEBlock`. The forward EP path looks +# :class:`transformer_engine.jax.flax._MoEBlock`. The forward EP path looks # like:: # # route -> global_permute -> AG(group_sizes, ep) From ddf5d90bb8742093aae9fa6622c9b8d402d65c52 Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 21 May 2026 14:43:15 -0700 Subject: [PATCH 15/18] [JAX] Refactor MoEBlock into a unified MoE custom_vjp, add tests Replace the per-primitive custom_vjp boundaries in MoEBlock with a single jax.custom_vjp covering routing, dispatch, expert FFN, and combine. Helper functions group permute -> ragged_all_to_all -> local-permute into a single dispatch / combine pair, with a hand- derived bwd that mirrors the forward and runs entirely inside the EP shard_map body. Add a multi-process (one-GPU-per-process) test suite for the new unified VJP under a 2x2 (ep, fsdp) mesh: * tests/jax/test_multiprocess_moe_vjp.py -- fwd/bwd + aux_loss + PURE_JAX vs TRITON parity at Mixtral-ish shapes (batch=16, seq=2048, hidden=1024, intermediate=4096, num_experts=8, topk=2). * tests/jax/run_multiprocess_moe_vjp.sh -- launcher; forks one pytest process per visible GPU (mirrors examples/jax/encoder/run_test_multiprocessing_encoder.sh). * tests/jax/conftest.py -- pytest --num-process / --process-id options for the launcher. * qa/L0_jax_distributed_unittest/test.sh -- CI hook for the multiprocess smoke. Signed-off-by: tdophung --- qa/L0_jax_distributed_unittest/test.sh | 8 + tests/jax/conftest.py | 14 + tests/jax/run_multiprocess_moe_vjp.sh | 130 ++ tests/jax/test_distributed_moe_block.py | 200 -- tests/jax/test_moe_block.py | 462 ---- tests/jax/test_moe_vjp.py | 449 ++++ tests/jax/test_multiprocess_moe_vjp.py | 335 +++ .../common/triton/permutation.py | 50 +- transformer_engine/jax/cpp_extensions/gemm.py | 9 +- transformer_engine/jax/flax/moe.py | 1122 +-------- transformer_engine/jax/moe.py | 2019 +++++++++++++++++ 11 files changed, 3086 insertions(+), 1712 deletions(-) create mode 100755 tests/jax/run_multiprocess_moe_vjp.sh delete mode 100644 tests/jax/test_distributed_moe_block.py delete mode 100644 tests/jax/test_moe_block.py create mode 100644 tests/jax/test_moe_vjp.py create mode 100644 tests/jax/test_multiprocess_moe_vjp.py create mode 100644 transformer_engine/jax/moe.py diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index 3f25816600..561a02797b 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -37,6 +37,14 @@ wait TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh" wait +# MoE custom_vjp distributed suite. Runs one Python process per GPU +# via tests/jax/run_multiprocess_moe_vjp.sh (mirrors the pattern in +# examples/jax/encoder/run_test_multiprocessing_encoder.sh). Requires +# >=4 visible GPUs. +TE_PATH=$TE_PATH bash $TE_PATH/tests/jax/run_multiprocess_moe_vjp.sh \ + || test_fail "test_multiprocess_moe_vjp.py" +wait + if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" exit 1 diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index db30f0ed39..74cb91202c 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -86,6 +86,20 @@ def pytest_sessionfinish(self, session, exitstatus): print("=" * 80) +def pytest_addoption(parser): + """CLI options used by multiprocess JAX tests. + + ``--num-process`` and ``--process-id`` let a multiprocess launcher + (see ``tests/jax/run_multiprocess_moe_vjp.sh``) fork one pytest + process per GPU and tell each child its rank, so the test module + can call ``jax.distributed.initialize(...)`` with the right + ``local_device_ids``. Both default to 0; non-multiprocess tests + ignore them. + """ + parser.addoption("--num-process", action="store", default=0) + parser.addoption("--process-id", action="store", default=0) + + def pytest_configure(config): config.addinivalue_line( "markers", diff --git a/tests/jax/run_multiprocess_moe_vjp.sh b/tests/jax/run_multiprocess_moe_vjp.sh new file mode 100755 index 0000000000..e03f6b77b1 --- /dev/null +++ b/tests/jax/run_multiprocess_moe_vjp.sh @@ -0,0 +1,130 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +# +# Multiprocess (one-GPU-per-process) launcher for the unified MoE VJP +# test suite. Forks one pytest invocation per visible GPU, passing each +# its own --num-process=N --process-id=i, and waits for all of them. +# Each child calls jax.distributed.initialize(..., local_device_ids= +# process_id) so each Python process only sees its one GPU as a local +# device and the participating processes form a global mesh. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TE_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" +TEST_FILE="$TE_ROOT/tests/jax/test_multiprocess_moe_vjp.py" +PYTEST_INI="$TE_ROOT/tests/jax/pytest.ini" + +NUM_GPUS="${NUM_GPUS:-$(nvidia-smi -L | wc -l)}" +if [ "$NUM_GPUS" -lt 4 ]; then + echo "[run_multiprocess_moe_vjp.sh] need >=4 GPUs (got $NUM_GPUS); aborting" >&2 + exit 1 +fi + +export XLA_PYTHON_CLIENT_PREALLOCATE="${XLA_PYTHON_CLIENT_PREALLOCATE:-false}" +export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.5}" +export MOE_VJP_COORDINATOR_ADDRESS="${MOE_VJP_COORDINATOR_ADDRESS:-127.0.0.1:13456}" + +echo "============================================================" +echo "MoE VJP MULTIPROCESS test (one process per GPU, ${NUM_GPUS} GPUs)" +echo " test file : $TEST_FILE" +echo " coordinator : $MOE_VJP_COORDINATOR_ADDRESS" +echo " XLA_PYTHON_CLIENT_PREALLOCATE: $XLA_PYTHON_CLIENT_PREALLOCATE" +echo " XLA_PYTHON_CLIENT_MEM_FRACTION: $XLA_PYTHON_CLIENT_MEM_FRACTION" +echo "============================================================" + +# Per-process logs. MOE_VJP_MP_LOG_DIR can be set to a host-mounted dir +# (e.g. when running inside a container that throws away /tmp on exit) +# so logs survive for postmortem inspection. Defaults to a fresh /tmp. +if [ -n "${MOE_VJP_MP_LOG_DIR:-}" ]; then + LOG_DIR="$MOE_VJP_MP_LOG_DIR" + mkdir -p "$LOG_DIR" +else + LOG_DIR=$(mktemp -d -t moe_vjp_mp_XXXXXX) +fi +echo "Per-process logs: $LOG_DIR" + +PIDS=() + +cleanup() { + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -TERM "$pid" 2>/dev/null || true + fi + done + sleep 1 + for pid in "${PIDS[@]:-}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -KILL "$pid" 2>/dev/null || true + fi + done +} +trap cleanup EXIT INT TERM + +# Launch one pytest per GPU. Process 0 streams to stdout; others log +# only to file so the live output isn't a mosaic. +for i in $(seq 0 $((NUM_GPUS - 1))); do + LOG_FILE="$LOG_DIR/proc_${i}.log" + PYTEST_CMD=( + python3 -m pytest -c "$PYTEST_INI" + "$TEST_FILE" + -p no:typeguard + -v -s + --num-process="$NUM_GPUS" + --process-id="$i" + ) + if [ "$i" -eq 0 ]; then + echo "=== Live output from process 0 ===" + "${PYTEST_CMD[@]}" 2>&1 | tee "$LOG_FILE" & + else + "${PYTEST_CMD[@]}" > "$LOG_FILE" 2>&1 & + fi + PIDS+=("$!") +done + +# Wait for all and collect exit codes. +EXITS=() +for pid in "${PIDS[@]}"; do + if wait "$pid"; then + EXITS+=("0") + else + EXITS+=("$?") + fi +done + +# Summary. +echo +echo "============================================================" +echo "Per-process exit codes:" +for i in "${!EXITS[@]}"; do + echo " proc $i -> ${EXITS[$i]}" +done + +# Final pass/fail. Any non-zero in any process fails the suite, but +# we tolerate non-zero on the non-zero processes only if proc 0 +# reports PASS (this matches the encoder launcher's logic). Simplest +# strict rule: any non-zero is a failure. +FAILED=0 +for e in "${EXITS[@]}"; do + if [ "$e" != "0" ]; then + FAILED=1 + break + fi +done + +echo +if [ "$FAILED" -eq 0 ]; then + echo "[run_multiprocess_moe_vjp.sh] all processes PASSED" + if [ -z "${MOE_VJP_MP_LOG_DIR:-}" ]; then + rm -rf "$LOG_DIR" + fi + exit 0 +fi + +echo "[run_multiprocess_moe_vjp.sh] at least one process FAILED" +echo " retaining logs at $LOG_DIR for diagnosis" +echo " process 0 tail:" +tail -20 "$LOG_DIR/proc_0.log" 2>/dev/null || true +exit 1 diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py deleted file mode 100644 index 98fd6a7212..0000000000 --- a/tests/jax/test_distributed_moe_block.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Distributed tests for the experimental ``transformer_engine.jax.flax._MoEBlock``.""" - -import sys - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpy as np -import pytest -from jax.sharding import Mesh, PartitionSpec - -from utils import assert_allclose, is_devices_enough - - -@pytest.fixture(autouse=True, scope="function") -def _inject_moe(request): - """Lazy-load ``_MoEBlock`` only for tests marked ``triton``.""" - if not request.node.get_closest_marker("triton"): - yield - return - - from transformer_engine.jax import MeshResource, autocast - - # The class is intentionally exposed as ``_MoEBlock`` (experimental); - # aliasing to ``MoEBlock`` here keeps the test bodies readable. - from transformer_engine.jax.flax import _MoEBlock as MoEBlock - from transformer_engine.jax.flax.moe import PermutationBackend - - mod = sys.modules[__name__] - mod.MeshResource = MeshResource - mod.autocast = autocast - mod.MoEBlock = MoEBlock - mod.PermutationBackend = PermutationBackend - yield - - -DTYPE = jnp.bfloat16 -# Must be divisible by ep*fsdp = 4 so the batch dim can be sharded over -# the full ('ep','fsdp') axis tuple under Experiment 3. -BATCH_SIZE = 4 -SEQUENCE_LENGTH = 16 -HIDDEN_SIZE = 64 -INTERMEDIATE_SIZE = 128 -NUM_EXPERTS = 8 -NUM_EXPERTS_PER_TOK = 2 - - -def _make_inputs(key: jax.Array) -> jax.Array: - return jax.random.normal(key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE) - - -def _unwrap_partitioned(x): - return x.value if hasattr(x, "value") else x - - -@pytest.mark.triton -class TestDistributedMoEBlock: - @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) - def test_ep2_fsdp2_matches_single_device(self, permutation_backend): - if not is_devices_enough(4): - pytest.skip("MoE distributed test requires 4 devices for EP=2 x FSDP=2.") - - permutation_backend = PermutationBackend(permutation_backend) - key = jax.random.PRNGKey(11) - init_key, data_key = jax.random.split(key) - inputs = _make_inputs(data_key) - - base_kwargs = dict( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=permutation_backend, - aux_loss_coeff=1e-2, - dtype=DTYPE, - ) - - single_block = MoEBlock(**base_kwargs) - - def _make_loss_and_grad(block): - """Build a jitted ``value_and_grad`` over ``(variables, x)``. - - Capturing ``block`` in a closure (so it isn't a jit input) - sidesteps having to mark it as static -- Flax modules are - registered pytrees but they carry Python-level config that - jit treats as part of the trace. - """ - - def loss_fn(variables, x): - output, aux_loss = block.apply(variables, x) - loss = jnp.mean(output.astype(jnp.float32) ** 2) - if aux_loss is not None: - loss = loss + aux_loss.astype(jnp.float32) - return loss, (output, aux_loss) - - return jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) - - with autocast(enabled=False, mesh_resource=MeshResource()): - single_variables = single_block.init(init_key, inputs) - (single_loss, (single_output, single_aux)), single_grads = _make_loss_and_grad( - single_block - )(single_variables, inputs) - - devices = np.asarray(jax.devices()[:4]).reshape(2, 2) - mesh = Mesh(devices, ("ep", "fsdp")) - # FSDP-style sharding: weights are sharded on a *non-contracting* - # weight axis (gathered before the GEMM); activations stay sharded on - # the *batch* axis throughout - the same fsdp mesh axis is reused for - # both. The TE primitives' custom_partitioning rules expect activations - # FSDP-sharded on batch, so we declare ("batch", "fsdp") AND pass - # ``input_axes=("batch", None, None)`` to enforce it on the inputs to - # the block. ("embed", "fsdp") shards the weight's hidden dim, which - # is gathered inside grouped_dense's custom_partitioning before GEMM - # (no reshard of activations needed because their layout is unchanged). - logical_axis_rules = ( - ("exp", "ep"), - ("batch", "fsdp"), - ("embed", "fsdp"), - ) - # ``data_parallelism_axes=("fsdp",)`` opts in to the true-FSDP - # behavior: the ``shard_map``'s in_specs/out_specs become - # ``P(("ep","fsdp"), None, None)`` for the batch dim, so each - # device owns ``B/(ep*fsdp)`` unique tokens (no redundant compute - # across fsdp peers within an ep group). - sharded_block = MoEBlock( - data_parallelism_axes=("fsdp",), - input_axes=("batch", None, None), - **base_kwargs, - ) - - # ``MoEBlock`` resolves the EP axis from - # ``global_mesh_resource().ep_resource`` (set via ``autocast``), - # so the ``ep`` axis on the mesh is wired in by passing - # ``ep_resource="ep"`` here -- no per-instance config needed. - with mesh, autocast( - enabled=False, - mesh_resource=MeshResource(fsdp_resource="fsdp", ep_resource="ep"), - ): - with nn.logical_axis_rules(logical_axis_rules): - # ``MoEBlock`` registers params via ``with_logical_partitioning`` - # which only attaches LogicallyPartitioned metadata; the - # underlying jax.Array stays single-device unless ``init`` - # is run inside ``jax.jit`` with ``out_shardings``. Use the - # canonical Flax-Linen pattern (mirrors - # ``examples/jax/encoder/test_model_parallel_encoder.py``): - # 1. ``jax.eval_shape`` to trace abstract variables (keeps - # the LogicallyPartitioned wrappers; only the inner - # arrays become ShapeDtypeStruct); - # 2. ``nn.get_partition_spec`` to extract a tree of logical - # PartitionSpecs from those wrappers (treats - # LogicallyPartitioned as a leaf); - # 3. ``nn.logical_to_mesh_sharding`` to resolve those - # logical specs to NamedShardings via the active rules; - # 4. ``jax.jit(init, out_shardings=...)`` to actually - # place the params on-device with those shardings. - abstract_variables = jax.eval_shape(sharded_block.init, init_key, inputs) - logical_partition_spec = nn.get_partition_spec(abstract_variables) - out_shardings = nn.logical_to_mesh_sharding( - logical_partition_spec, mesh, logical_axis_rules - ) - sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)( - init_key, inputs - ) - (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = _make_loss_and_grad( - sharded_block - )(sharded_variables, inputs) - - wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) - wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) - wo = _unwrap_partitioned(sharded_variables["params"]["wo"]) - assert wi_0.sharding.spec == PartitionSpec("ep", "fsdp", None) - assert wi_1.sharding.spec == PartitionSpec("ep", "fsdp", None) - assert wo.sharding.spec == PartitionSpec("ep", None, "fsdp") - - assert_allclose(sharded_output, single_output, dtype=DTYPE, atol=5e-2, rtol=5e-2) - assert_allclose(sharded_loss, single_loss, dtype=jnp.float32, atol=5e-2, rtol=5e-2) - assert_allclose(sharded_aux, single_aux, dtype=jnp.float32, atol=5e-2, rtol=5e-2) - - # The sharded path runs the same math on each ep-shard but - # accumulates gradients via psum across (ep, fsdp), which changes - # floating-point reduction order vs the single-device run. Under - # bf16 with these toy shapes the observed max-abs grad diff is on - # the order of a few units of bf16 eps (~1e-2). 5e-2 / 5e-2 - # leaves headroom for accumulation jitter without masking real - # divergence; matches the cross-backend bf16 grad tolerance in - # ``tests/jax/test_moe_block.py::test_pure_jax_matches_triton``. - for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - grad_single = _unwrap_partitioned(single_grads["params"][name]) - grad_sharded = _unwrap_partitioned(sharded_grads["params"][name]) - assert_allclose( - grad_sharded, - grad_single, - dtype=DTYPE, - atol=5e-2, - rtol=5e-2, - err_msg=f"Distributed gradient mismatch for {name}", - ) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py deleted file mode 100644 index fbe8c083e9..0000000000 --- a/tests/jax/test_moe_block.py +++ /dev/null @@ -1,462 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Basic tests for ``transformer_engine.jax.flax._MoEBlock``. - -These tests exercise the (experimental) ``_MoEBlock`` on a single device -(no expert parallelism) and verify: - -* Forward pass runs end-to-end and produces the expected output shape. -* Backward pass yields finite, non-trivial parameter gradients. -* The two permutation backends (``"pure_jax"`` and ``"triton"``) produce - numerically equivalent outputs and gradients when given the same routing - decisions. -* Auxiliary load-balancing loss is returned when ``aux_loss_coeff > 0``. -* DeepSeek-style grouped top-k (``num_groups`` / ``group_topk``) runs. -* ``_align_size > 0`` produces numerically-equivalent outputs to ``_align_size = 0`` - for the pure-JAX backend (padding must not change the result). -""" - -import sys -from typing import Tuple - -import jax -import jax.numpy as jnp -import pytest - - -# The ``_MoEBlock`` class pulls in both the fused-router CUDA kernel and -# the Triton permutation kernels, so it can only run in the environment -# where those are available. We gate the test on the ``triton`` marker (the -# Triton permutation backend is stricter than the CUDA router). See -# ``conftest.py``. - - -@pytest.fixture(autouse=True, scope="function") -def _inject_moe(request): - """Lazy-load ``_MoEBlock`` only for tests marked ``triton``.""" - if not request.node.get_closest_marker("triton"): - yield - return - - # The class is intentionally exposed as ``_MoEBlock`` (experimental); - # aliasing to ``MoEBlock`` here keeps the test bodies readable. - from transformer_engine.jax.flax import _MoEBlock as MoEBlock - from transformer_engine.jax.flax.moe import PermutationBackend - - mod = sys.modules[__name__] - mod.MoEBlock = MoEBlock - mod.PermutationBackend = PermutationBackend - yield - - -# ----------------------------------------------------------------------------- -# Configurations -# ----------------------------------------------------------------------------- -# -# Keep shapes small so the tests are cheap but still exercise every code path. - -DTYPE = jnp.bfloat16 -BATCH_SIZE = 2 -SEQUENCE_LENGTH = 16 -HIDDEN_SIZE = 64 -INTERMEDIATE_SIZE = 128 -NUM_EXPERTS = 8 -NUM_EXPERTS_PER_TOK = 2 - - -def _make_inputs( - key: jax.Array, batch_size: int = BATCH_SIZE, sequence_length: int = SEQUENCE_LENGTH -) -> jax.Array: - return jax.random.normal(key, (batch_size, sequence_length, HIDDEN_SIZE), dtype=DTYPE) - - -def _init_and_apply( - block, - inputs: jax.Array, - init_key: jax.Array, -) -> Tuple[dict, jax.Array, jax.Array]: - variables = block.init(init_key, inputs) - output, aux_loss = block.apply(variables, inputs) - return variables, output, aux_loss - - -def _unwrap_partitioned(x): - """Strip Flax logical-partition wrappers for numeric assertions.""" - return x.value if hasattr(x, "value") else x - - -# ----------------------------------------------------------------------------- -# Tests -# ----------------------------------------------------------------------------- - - -@pytest.mark.triton -class TestMoEBlockSingleDevice: - """Single-device smoke tests for :class:`_MoEBlock`.""" - - @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) - def test_forward_shape_and_finite(self, permutation_backend): - permutation_backend = PermutationBackend(permutation_backend) - key = jax.random.PRNGKey(0) - init_key, data_key = jax.random.split(key) - - block = MoEBlock( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=permutation_backend, - dtype=DTYPE, - ) - inputs = _make_inputs(data_key) - _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) - - assert ( - output.shape == inputs.shape - ), f"Unexpected output shape {output.shape} for backend {permutation_backend}" - assert output.dtype == inputs.dtype - assert jnp.all(jnp.isfinite(output)), "Output contains NaN/Inf" - assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0" - - @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) - def test_backward_grad_is_finite_and_nonzero(self, permutation_backend): - permutation_backend = PermutationBackend(permutation_backend) - key = jax.random.PRNGKey(1) - init_key, data_key = jax.random.split(key) - - block = MoEBlock( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=permutation_backend, - dtype=DTYPE, - ) - inputs = _make_inputs(data_key) - variables = block.init(init_key, inputs) - - def loss_fn(variables, inputs): - output, _ = block.apply(variables, inputs) - return jnp.mean(output.astype(jnp.float32) ** 2) - - grads = jax.grad(loss_fn)(variables, inputs) - # All trainable kernels should receive a non-trivial gradient. - for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g = _unwrap_partitioned(grads["params"][name]) - assert jnp.all(jnp.isfinite(g)), f"{name} gradient has NaN/Inf" - assert jnp.any(g != 0.0), f"{name} gradient is identically zero" - - def test_pure_jax_triton_equivalence(self): - """Both permutation backends must produce the same forward + grads - under identical routing decisions. - - Since the two backends share the same routing path (TE's fused - top-k), fixing the gate kernel gives both the same routing decisions - and the remainder of the network is identical modulo the permutation - implementation, whose semantics are equivalent. - """ - key = jax.random.PRNGKey(2) - init_key, data_key = jax.random.split(key) - - base_kwargs = dict( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - dtype=DTYPE, - ) - pure_block = MoEBlock(permutation_backend=PermutationBackend.PURE_JAX, **base_kwargs) - triton_block = MoEBlock(permutation_backend=PermutationBackend.TRITON, **base_kwargs) - inputs = _make_inputs(data_key) - - # Share a single parameter tree so routing decisions and expert - # weights are identical for both backends. - variables = pure_block.init(init_key, inputs) - - def loss_fn(block, variables, inputs): - output, _ = block.apply(variables, inputs) - return jnp.mean(output.astype(jnp.float32) ** 2), output - - (loss_pj, out_pj), grads_pj = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( - pure_block, variables, inputs - ) - (loss_tr, out_tr), grads_tr = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( - triton_block, variables, inputs - ) - - # BF16 tolerances: outputs come out of the grouped-GEMM + weighted - # sum so they accumulate error; we use ~2 ULPs worth of slack. - atol_out, rtol_out = 5e-2, 5e-2 - assert jnp.allclose( - out_pj, out_tr, atol=atol_out, rtol=rtol_out - ), f"Forward outputs differ across backends: max diff {jnp.max(jnp.abs(out_pj - out_tr))}" - assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) - - # The two backends share the routing path (same fused top-k) and - # the same expert FFN; the only difference is the order of the - # gather + scatter ops in dispatch/combine. Under bf16 with these - # small shapes, observed grad max-abs-diff is on the order of a - # few-units-of-bf16-eps (~1e-2). 5e-2 / 5e-2 leaves headroom for - # accumulation jitter without masking real divergence. If this - # tightens too far on a particular GPU, print - # ``jnp.max(jnp.abs(g_pj - g_tr))`` from the failing assertion - # and bump to the next safe value with a comment recording the - # measured gap. - atol_grad, rtol_grad = 5e-2, 5e-2 - for name in ("gate_kernel", "wi_0", "wi_1", "wo"): - g_pj = _unwrap_partitioned(grads_pj["params"][name]) - g_tr = _unwrap_partitioned(grads_tr["params"][name]) - assert jnp.allclose(g_pj, g_tr, atol=atol_grad, rtol=rtol_grad), ( - f"Gradient for {name} differs across backends: max diff" - f" {jnp.max(jnp.abs(g_pj - g_tr))} (atol={atol_grad}," - f" rtol={rtol_grad})" - ) - - @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) - def test_aux_loss_returned(self, permutation_backend): - permutation_backend = PermutationBackend(permutation_backend) - key = jax.random.PRNGKey(3) - init_key, data_key = jax.random.split(key) - - block = MoEBlock( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=permutation_backend, - aux_loss_coeff=1e-2, - dtype=DTYPE, - ) - inputs = _make_inputs(data_key) - _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) - - assert output.shape == inputs.shape - assert aux_loss is not None, "aux_loss should be returned when coeff > 0" - assert aux_loss.shape == (), "aux_loss should be a scalar" - assert jnp.isfinite(aux_loss) - # With uniform-ish routing the loss should be small-positive, not huge. - assert jnp.abs(aux_loss) < 1e2 - - def test_aux_loss_uses_real_routing_under_group_topk(self): - """Aux loss must reflect the real (post-group) routing decisions. - - Under DeepSeek-style ``num_groups`` / ``group_topk`` routing, - the auxiliary load-balancing loss must be computed using the - per-expert token counts from the *real* routing_map (post - grouping), not from the clean top-k that the - ``compute_aux_scores=True`` kernel returns. Otherwise the aux - objective trains against the wrong distribution. - - We compute three values: - * ``corrected_ref`` -- ``fused_moe_aux_loss(aux_scores, - tokens_from_real_routing_map, ...)`` (what the block - should produce after the fix). - * ``buggy_ref`` -- ``fused_moe_aux_loss(aux_scores, - tokens_from_aux_routing_map, ...)`` (what the block used - to produce before the fix). - * ``block_aux_loss`` -- what the block actually produces. - - Block must match the corrected reference. We also assert that - the corrected and buggy references differ for this config so - the test is not vacuously satisfied by them coinciding. - """ - from transformer_engine.jax.router import ( - fused_moe_aux_loss, - fused_topk_with_score_function, - ) - - key = jax.random.PRNGKey(7) - init_key, data_key = jax.random.split(key) - - # Pick a config that *reliably* exercises grouped-vs-clean - # divergence: with ``group_topk=1`` only ONE group's experts - # can be selected by grouped routing, so the routing diverges - # from a plain top-k whenever the global top-K experts are - # spread across multiple groups (which is almost always the - # case for random init + ``num_experts_per_tok > 1``). - num_groups = 2 - group_topk = 1 - aux_loss_coeff = 1e-2 - - block = MoEBlock( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=PermutationBackend.PURE_JAX, - score_function="sigmoid", - num_groups=num_groups, - group_topk=group_topk, - aux_loss_coeff=aux_loss_coeff, - dtype=DTYPE, - ) - inputs = _make_inputs(data_key) - variables = block.init(init_key, inputs) - _output, block_aux_loss = block.apply(variables, inputs) - - assert block_aux_loss is not None - - # Reproduce the gating GEMM and routing externally so we can - # build the references against the same logits the block sees. - gate_kernel = _unwrap_partitioned(variables["params"]["gate_kernel"]) - gate_kernel = gate_kernel.astype(inputs.dtype) - logits = jnp.einsum("bsh,he->bse", inputs, gate_kernel) - logits_2d = logits.reshape(-1, NUM_EXPERTS) - - # Real routing (with grouping). This is what _route_topk - # would produce inside the block. - _, real_routing_map = fused_topk_with_score_function( - logits_2d, - topk=NUM_EXPERTS_PER_TOK, - score_function="sigmoid", - num_groups=num_groups, - group_topk=group_topk, - ) - real_tokens = jnp.sum(real_routing_map.astype(jnp.int32), axis=0) - - # Aux scores + the (clean topk) aux_routing_map that the old - # buggy code used for tokens_per_expert. - aux_scores, aux_routing_map = fused_topk_with_score_function( - logits_2d.astype(jnp.float32), - topk=NUM_EXPERTS_PER_TOK, - score_function="sigmoid", - compute_aux_scores=True, - ) - buggy_tokens = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0) - - corrected_ref = fused_moe_aux_loss( - aux_scores.astype(jnp.float32), - real_tokens, - topk=NUM_EXPERTS_PER_TOK, - coeff=aux_loss_coeff, - ) - buggy_ref = fused_moe_aux_loss( - aux_scores.astype(jnp.float32), - buggy_tokens, - topk=NUM_EXPERTS_PER_TOK, - coeff=aux_loss_coeff, - ) - - # Sanity: the test config must actually exercise the bug - # (otherwise both references coincide and the assertion below - # would silently pass even with the old code). - assert not jnp.allclose(real_tokens, buggy_tokens), ( - "Test config does not exercise grouped-topk vs clean-topk" - " divergence; pick a config where they differ" - ) - - assert jnp.allclose( - block_aux_loss, corrected_ref, atol=1e-5, rtol=1e-5 - ), f"Block aux_loss {block_aux_loss} does not match real-routing reference {corrected_ref}" - # The corrected and buggy refs can be numerically close - # (only the mis-routed tokens contribute to the difference), - # so assert that the block is *strictly closer* to the - # corrected ref than to the buggy one. This catches the - # regression robustly even when the absolute gap between - # corrected_ref and buggy_ref is sub-tolerance. - diff_to_corrected = jnp.abs(block_aux_loss - corrected_ref) - diff_to_buggy = jnp.abs(block_aux_loss - buggy_ref) - gap = jnp.abs(corrected_ref - buggy_ref) - assert diff_to_corrected < diff_to_buggy, ( - f"Block aux_loss {block_aux_loss} is closer to the *old" - f" buggy* reference ({buggy_ref}, diff={diff_to_buggy})" - f" than to the corrected reference ({corrected_ref}," - f" diff={diff_to_corrected}); the regression has" - f" reappeared. corrected-buggy gap = {gap}" - ) - - @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) - def test_group_topk_deepseek(self, permutation_backend): - """Exercise DeepSeek-style grouped top-k routing.""" - permutation_backend = PermutationBackend(permutation_backend) - key = jax.random.PRNGKey(4) - init_key, data_key = jax.random.split(key) - - # num_groups must divide num_experts. - num_groups = 4 - group_topk = 2 - block = MoEBlock( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=permutation_backend, - score_function="sigmoid", - num_groups=num_groups, - group_topk=group_topk, - dtype=DTYPE, - ) - inputs = _make_inputs(data_key) - _variables, output, _aux_loss = _init_and_apply(block, inputs, init_key) - - assert output.shape == inputs.shape - assert jnp.all(jnp.isfinite(output)) - - def test_align_size_equivalence_pure_jax(self, monkeypatch): - """For the pure-JAX backend, ``_align_size > 0`` must not change the - numerical output of the forward pass: padding tokens contribute zero - to every expert GEMM output (their input rows are zeros) and are - stripped before the weighted sum. - - Why the env knob: the V1 TE grouped GEMM FFI asserts strict - equality ``sum(group_sizes) == M``. With ``_align_size > 0`` the - pure-JAX backend produces a buffer where ``M >= sum(group_sizes)`` - (the slack is structural padding for JIT), so V1 is incompatible. - The V2 cuBLASLt-backed grouped GEMM relaxes the assertion to - ``M >= sum(group_sizes)`` and is selected when - ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1``. If V2 isn't supported on - this hardware / for this dtype, the dispatch raises a - ``RuntimeError`` whose message is matched here so the test - ``skip``-s instead of failing. - """ - monkeypatch.setenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "1") - - key = jax.random.PRNGKey(5) - init_key, data_key = jax.random.split(key) - - base_kwargs = dict( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=PermutationBackend.PURE_JAX, - dtype=DTYPE, - ) - block_no_pad = MoEBlock(_align_size=0, **base_kwargs) - block_pad = MoEBlock(_align_size=16, **base_kwargs) - inputs = _make_inputs(data_key) - - try: - variables = block_no_pad.init(init_key, inputs) - out_no_pad, _ = block_no_pad.apply(variables, inputs) - out_pad, _ = block_pad.apply(variables, inputs) - except RuntimeError as exc: - if "V2 grouped GEMM is not supported" in str(exc): - pytest.skip(f"V2 grouped GEMM unavailable on this hardware: {exc}") - raise - - assert jnp.allclose(out_no_pad, out_pad, atol=5e-2, rtol=5e-2), ( - "_align_size > 0 must not change pure_jax forward output; max diff" - f" {jnp.max(jnp.abs(out_no_pad - out_pad))}" - ) - - @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) - def test_jit_and_determinism(self, permutation_backend): - """The block must be JIT-compilable and produce a deterministic - forward pass across repeat calls with the same params.""" - permutation_backend = PermutationBackend(permutation_backend) - key = jax.random.PRNGKey(6) - init_key, data_key = jax.random.split(key) - - block = MoEBlock( - num_experts=NUM_EXPERTS, - num_experts_per_tok=NUM_EXPERTS_PER_TOK, - intermediate_size=INTERMEDIATE_SIZE, - permutation_backend=permutation_backend, - dtype=DTYPE, - ) - inputs = _make_inputs(data_key) - variables = block.init(init_key, inputs) - - @jax.jit - def forward(variables, inputs): - return block.apply(variables, inputs)[0] - - out_a = forward(variables, inputs) - out_b = forward(variables, inputs) - assert jnp.array_equal(out_a, out_b), "JITted forward is non-deterministic" diff --git a/tests/jax/test_moe_vjp.py b/tests/jax/test_moe_vjp.py new file mode 100644 index 0000000000..92d95bc896 --- /dev/null +++ b/tests/jax/test_moe_vjp.py @@ -0,0 +1,449 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Single-device tests for the unified MoE custom_vjp at +``transformer_engine.jax.moe.moe`` (and its Flax wrapper +``transformer_engine.jax.flax._MoEBlock``). + +Strategy +-------- + +Rather than reproducing every internal kernel residual, we rely on a +single end-to-end pure-JAX *reference* implementation of the whole +MoE block (``_pure_jax_moe_reference`` below) and compare the TE +``moe(...)`` forward output AND parameter gradients against it. This +gives us coverage of: + +* the gate GEMM, +* the fused top-k routing primitive (and its bwd), +* the dispatch / per-expert FFN / combine pipeline (and their bwds + threaded through the absorbed primitives), +* the optional aux-loss path (and its bwd). + +The reference uses only ``jnp`` ops + ``jax.vjp``, so we get a +"definitive" pullback to compare against without needing the TE +primitive bwd kernels. + +Distributed (EP + FSDP) testing is intentionally NOT in this file -- +that needs a multi-device setup and lives in +``tests/jax/test_distributed_moe_vjp.py`` (follow-up). +""" + +from functools import partial +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + + +# Lazy import (mirrors the gating in the old test file): the underlying +# kernels require triton + the fused-router CUDA kernel. +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + if not request.node.get_closest_marker("triton"): + yield + return + import sys + from transformer_engine.jax.flax import _MoEBlock as MoEBlock + from transformer_engine.jax.moe import PermutationBackend, moe + + mod = sys.modules[__name__] + mod.MoEBlock = MoEBlock + mod.PermutationBackend = PermutationBackend + mod.moe = moe + yield + + +# ----------------------------------------------------------------------------- +# Test config +# ----------------------------------------------------------------------------- + +DTYPE = jnp.float32 # use fp32 for tighter parity assertions +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 32 +INTERMEDIATE_SIZE = 64 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs(key: jax.Array, *, batch=BATCH_SIZE, seq=SEQUENCE_LENGTH) -> jax.Array: + return jax.random.normal(key, (batch, seq, HIDDEN_SIZE), dtype=DTYPE) + + +# ----------------------------------------------------------------------------- +# Pure-JAX reference MoE +# ----------------------------------------------------------------------------- +# +# Implements EXACTLY the same math as ``moe(...)`` for the no-EP, +# softmax-routing, no-bias, silu activation, no-quantization path. +# Returns ``(output, aux_loss_or_zero)``. Used as ground truth for both +# fwd and bwd parity. + + +@partial( + jax.jit, + static_argnames=("num_experts", "num_experts_per_tok", "aux_loss_coeff"), +) +def _pure_jax_moe_reference( + x: jnp.ndarray, + gate_kernel: jnp.ndarray, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + *, + num_experts: int, + num_experts_per_tok: int, + aux_loss_coeff: float = 0.0, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Reference no-EP MoE forward (pure JAX, no TE primitives). + + Mirrors :func:`transformer_engine.jax.moe._body_fwd` for the + PURE_JAX backend, no biases, softmax routing, silu activation, + no quantization. Linear ops only -- ``jax.vjp`` over this gives + the canonical bwd to compare against. + """ + B, S, H = x.shape + T = B * S + x_2d = x.reshape(T, H) + + # Gate + logits = x_2d @ gate_kernel # [T, E] + + # Softmax + topk (no expert_bias, no grouping, scale=1.0) + probs_full = jax.nn.softmax(logits, axis=-1) # [T, E] + # top-k by probability: + sorted_idx = jnp.argsort(probs_full, axis=-1) # ascending + selected = sorted_idx[:, -num_experts_per_tok:] # [T, K] + weights = jnp.take_along_axis(probs_full, selected, axis=-1) # [T, K] + # Normalize topk weights to sum to 1 (matches softmax->topk semantics + # of fused_topk_with_score_function with use_pre_softmax=False): + weights = weights / jnp.sum(weights, axis=-1, keepdims=True) + + # Build a sparse routing_map [T, E] with weights at selected positions + routing_weights_full = jnp.zeros_like(probs_full) + routing_weights_full = routing_weights_full.at[jnp.arange(T)[:, None], selected].set(weights) + + # Per-expert FFN: replicate each token K times, gather by expert, + # run through wi_0 / wi_1 / wo, gather back, weighted-sum. + # + # Vectorize the gather without sorting: for each (token, slot k), + # multiply the corresponding expert's FFN by routing_weights[t, k] + # and sum over experts. + # x_2d: [T, H], wi_0: [E, H, M], wi_1: [E, H, M], wo: [E, M, H] + # For each expert e: layer_w0_e = x_2d @ wi_0[e]; layer_w1_e = x_2d @ wi_1[e] + # intermediate_e = silu(layer_w0_e) * layer_w1_e + # expert_out_e = intermediate_e @ wo[e] + # output[t, h] = sum_e routing_weights_full[t, e] * expert_out_e[t, h] + layer_w0 = jnp.einsum("th,ehm->tem", x_2d, wi_0) # [T, E, M] + layer_w1 = jnp.einsum("th,ehm->tem", x_2d, wi_1) # [T, E, M] + intermediate = jax.nn.silu(layer_w0) * layer_w1 # [T, E, M] + expert_out = jnp.einsum("tem,emh->teh", intermediate, wo) # [T, E, H] + output_2d = jnp.einsum("te,teh->th", routing_weights_full, expert_out) # [T, H] + output = output_2d.reshape(B, S, H) + + if aux_loss_coeff > 0.0: + # aux scores: clean per-expert softmax (compute_aux_scores=True + # kernel uses a clean softmax, no bias, scale=1, no grouping). + aux_probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1) + # tokens_per_expert from REAL routing_map (post-grouping); here + # there's no grouping so == count of non-zero positions per expert. + routing_map = (routing_weights_full > 0).astype(jnp.int32) + tokens_per_expert = jnp.sum(routing_map, axis=0) # [E] + # aux_loss formula: (E * coeff / (k * T^2)) * sum_e + # (sum_t aux_probs[t, e]) * tokens_per_expert[e] + sum_probs_per_expert = jnp.sum(aux_probs, axis=0) # [E] + aux_loss = (num_experts * aux_loss_coeff / (num_experts_per_tok * (T**2))) * jnp.sum( + sum_probs_per_expert * tokens_per_expert.astype(jnp.float32) + ) + else: + aux_loss = jnp.zeros((), dtype=DTYPE) + + return output, aux_loss + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _init_params(key: jax.Array) -> dict: + k_g, k_w0, k_w1, k_wo = jax.random.split(key, 4) + init = jax.nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + return dict( + gate_kernel=init(k_g, (HIDDEN_SIZE, NUM_EXPERTS), DTYPE), + wi_0=init(k_w0, (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), DTYPE), + wi_1=init(k_w1, (NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE), DTYPE), + wo=init(k_wo, (NUM_EXPERTS, INTERMEDIATE_SIZE, HIDDEN_SIZE), DTYPE), + ) + + +@partial(jax.jit, static_argnames=("permutation_backend", "aux_loss_coeff")) +def _run_te_moe( + x: jnp.ndarray, + params: dict, + *, + permutation_backend, + aux_loss_coeff: float = 0.0, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + return moe( # noqa: F821 -- injected by fixture + x, + params["gate_kernel"], + params["wi_0"], + params["wi_1"], + params["wo"], + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + activation_type="silu", + score_function="softmax", + use_pre_softmax=False, + scaling_factor=1.0, + aux_loss_coeff=aux_loss_coeff, + permutation_backend=permutation_backend, + align_size=0, + dtype=DTYPE, + ) + + +@partial(jax.jit, static_argnames=("permutation_backend", "aux_loss_coeff")) +def _grads_te_main_loss(params, x, *, permutation_backend, aux_loss_coeff: float = 0.0): + """jit'd grad of ``mean(out**2)`` w.r.t. params (no aux contribution).""" + + def loss(params, x): + out, _ = _run_te_moe( + x, params, permutation_backend=permutation_backend, aux_loss_coeff=aux_loss_coeff + ) + return jnp.mean(out**2) + + return jax.grad(loss)(params, x) + + +@partial(jax.jit, static_argnames=("num_experts", "num_experts_per_tok", "aux_loss_coeff")) +def _grads_ref_main_loss(params, x, *, num_experts, num_experts_per_tok, aux_loss_coeff=0.0): + """jit'd grad of ``mean(out**2)`` w.r.t. params on the pure-JAX ref.""" + + def loss(params, x): + out, _ = _pure_jax_moe_reference( + x, + **params, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + aux_loss_coeff=aux_loss_coeff, + ) + return jnp.mean(out**2) + + return jax.grad(loss)(params, x) + + +@partial(jax.jit, static_argnames=("permutation_backend",)) +def _grad_te_aux_only(params, x, *, permutation_backend): + """jit'd grad of just the aux loss scalar (no main contribution).""" + + def aux_only(params, x): + _, aux = _run_te_moe( + x, params, permutation_backend=permutation_backend, aux_loss_coeff=1e-2 + ) + return aux.astype(jnp.float32) + + return jax.grad(aux_only)(params, x) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.triton +class TestMoeVjpForward: + """Forward shape / finiteness / parity vs pure-JAX reference.""" + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_forward_shape_and_finite(self, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + key = jax.random.PRNGKey(0) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + out, aux = _run_te_moe(x, params, permutation_backend=backend) + assert out.shape == x.shape + assert out.dtype == x.dtype + assert jnp.all(jnp.isfinite(out)) + assert aux is None + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_forward_parity_vs_pure_jax_reference(self, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + key = jax.random.PRNGKey(1) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + out_te, _ = _run_te_moe(x, params, permutation_backend=backend) + out_ref, _ = _pure_jax_moe_reference( + x, + **params, + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + ) + # FP32, small shapes -> tight tolerance + np.testing.assert_allclose(np.array(out_te), np.array(out_ref), atol=2e-5, rtol=2e-5) + + def test_pure_jax_triton_equivalence(self): + key = jax.random.PRNGKey(2) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + out_pj, _ = _run_te_moe( + x, params, permutation_backend=PermutationBackend.PURE_JAX # noqa: F821 + ) + out_tr, _ = _run_te_moe( + x, params, permutation_backend=PermutationBackend.TRITON # noqa: F821 + ) + np.testing.assert_allclose(np.array(out_pj), np.array(out_tr), atol=2e-5, rtol=2e-5) + + +@pytest.mark.triton +class TestMoeVjpBackward: + """Backward parity vs pure-JAX reference (which uses ``jax.vjp`` over + plain JAX ops, giving us the canonical pullback).""" + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_grads_finite_and_nonzero(self, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + key = jax.random.PRNGKey(3) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + grads = _grads_te_main_loss(params, x, permutation_backend=backend) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g = grads[name] + assert jnp.all(jnp.isfinite(g)), f"{name} grad has NaN/Inf" + assert jnp.any(g != 0.0), f"{name} grad is identically zero" + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_grads_match_pure_jax_reference(self, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + key = jax.random.PRNGKey(4) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + grads_te = _grads_te_main_loss(params, x, permutation_backend=backend) + grads_ref = _grads_ref_main_loss( + params, + x, + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + ) + # Loose-ish tol on grads: routing path has discrete topk so the + # softmax cotangent paths through the non-topk experts diverge + # slightly between TE (which uses the fused topk bwd) and the + # reference (which uses argsort-based take_along_axis). + # Tighter than the bf16 tests. + for name in ("wi_0", "wi_1", "wo"): + np.testing.assert_allclose( + np.array(grads_te[name]), + np.array(grads_ref[name]), + atol=5e-5, + rtol=5e-5, + err_msg=f"grad mismatch on {name}", + ) + # Gate grad has more error budget because it propagates through + # the topk derivative kernel (which differs in zero-pattern + # treatment from a plain take_along_axis). + np.testing.assert_allclose( + np.array(grads_te["gate_kernel"]), + np.array(grads_ref["gate_kernel"]), + atol=5e-4, + rtol=5e-4, + err_msg="grad mismatch on gate_kernel", + ) + + +@pytest.mark.triton +class TestMoeVjpAuxLoss: + """Aux-loss path: forward + grad parity.""" + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_aux_loss_returned_and_finite(self, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + key = jax.random.PRNGKey(5) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + _, aux = _run_te_moe(x, params, permutation_backend=backend, aux_loss_coeff=1e-2) + assert aux is not None + assert aux.shape == () + assert jnp.isfinite(aux) + assert jnp.abs(aux) < 1e2 + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_aux_loss_parity_vs_reference(self, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + key = jax.random.PRNGKey(6) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + _, aux_te = _run_te_moe(x, params, permutation_backend=backend, aux_loss_coeff=1e-2) + _, aux_ref = _pure_jax_moe_reference( + x, + **params, + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + aux_loss_coeff=1e-2, + ) + np.testing.assert_allclose(float(aux_te), float(aux_ref), atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_aux_loss_grads_propagate_to_logits(self, backend_name): + """The aux-loss bwd path must produce non-zero gate-kernel grads + when only the aux-loss scalar is differentiated (no main-output + contribution).""" + backend = PermutationBackend(backend_name) # noqa: F821 + key = jax.random.PRNGKey(7) + kp, kx = jax.random.split(key) + params = _init_params(kp) + x = _make_inputs(kx) + g_gate = _grad_te_aux_only(params, x, permutation_backend=backend)["gate_kernel"] + assert jnp.all(jnp.isfinite(g_gate)) + assert jnp.any( + g_gate != 0.0 + ), "aux_loss bwd should propagate to gate_kernel via fused_topk bwd" + + +# ----------------------------------------------------------------------------- +# Flax wrapper smoke test +# ----------------------------------------------------------------------------- + + +@pytest.mark.triton +class TestMoEBlockFlaxWrapper: + """Sanity-check the thin Flax wrapper: forward + grad on init.""" + + def test_init_and_apply(self): + block = MoEBlock( # noqa: F821 + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=PermutationBackend.PURE_JAX, # noqa: F821 + dtype=DTYPE, + ) + key = jax.random.PRNGKey(8) + ki, kx = jax.random.split(key) + x = _make_inputs(kx) + variables = jax.jit(block.init)(ki, x) + out, aux = jax.jit(block.apply)(variables, x) + assert out.shape == x.shape + assert aux is None + + @jax.jit + def grad_fn(variables, x): + return jax.grad(lambda v, x: jnp.mean(block.apply(v, x)[0] ** 2))(variables, x) + + grads = grad_fn(variables, x) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g = grads["params"][name] + g = g.value if hasattr(g, "value") else g + assert jnp.all(jnp.isfinite(g)), f"{name} grad NaN/Inf" + assert jnp.any(g != 0.0), f"{name} grad zero" diff --git a/tests/jax/test_multiprocess_moe_vjp.py b/tests/jax/test_multiprocess_moe_vjp.py new file mode 100644 index 0000000000..ddf04f0ea4 --- /dev/null +++ b/tests/jax/test_multiprocess_moe_vjp.py @@ -0,0 +1,335 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Multi-process (one-GPU-per-process) tests for the unified MoE custom_vjp. + +The launcher ``tests/jax/run_multiprocess_moe_vjp.sh`` forks one pytest +process per visible GPU (mirroring +``examples/jax/encoder/run_test_multiprocessing_encoder.sh``). Each +process binds to exactly one device via +``jax.distributed.initialize(..., local_device_ids=process_id)``; the +participating processes form a global mesh through JAX's distributed +runtime. + +How to run +---------- + +You typically do NOT invoke pytest on this file directly -- use the +launcher, which passes ``--num-process=N --process-id=i`` to each +forked process. Driving it directly with only one process will skip +every test because :func:`jax.distributed.initialize` requires +multiple participants. + + bash tests/jax/run_multiprocess_moe_vjp.sh + +CI invocation lives in ``qa/L0_jax_distributed_unittest/test.sh``. +""" + +import os + +# NCCL needs HBM headroom that JAX's default 90% preallocation does +# not leave. Set before any jax import below. +os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") +os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5") + +import sys + +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from jax.experimental import mesh_utils +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from flax.linen import partitioning as nn_partitioning + + +# Per-process distributed bootstrap. Each pytest invocation initializes +# JAX with exactly one local device (its assigned GPU). Once +# initialized, the four processes form one global mesh of 4 devices. +def _init_distributed(num_process: int, process_id: int) -> bool: + """Initialize jax.distributed for this pytest process. + + Returns True if initialization succeeded (i.e. this is a real + multi-process launch), False if num_process == 0 / 1 meaning the + file is being collected without a launcher and tests should be + skipped at module level. + """ + if num_process <= 1: + return False + coord = os.environ.get("MOE_VJP_COORDINATOR_ADDRESS", "127.0.0.1:1234") + jax.distributed.initialize( + coordinator_address=coord, + num_processes=num_process, + process_id=process_id, + local_device_ids=process_id, + ) + assert jax.local_device_count() == 1, "one GPU per process is the whole point" + assert ( + jax.device_count() == num_process + ), f"global device_count {jax.device_count()} != num_process {num_process}" + return True + + +# Read --num-process / --process-id BEFORE pytest collects any tests so +# we can fast-skip the whole module when not in a multiprocess launch. +def _read_mp_options(): + # Use pytest's option lookup via the request fixture isn't available + # at module top-level; parse argv ourselves the same way encoder + # test does. CLI form is e.g. "pytest ... --num-process=4 --process-id=0". + num = int(os.environ.get("MP_NUM_PROCESS", "0") or "0") + pid = int(os.environ.get("MP_PROCESS_ID", "0") or "0") + for i, a in enumerate(sys.argv): + if a.startswith("--num-process="): + num = int(a.split("=", 1)[1]) + elif a == "--num-process" and i + 1 < len(sys.argv): + num = int(sys.argv[i + 1]) + elif a.startswith("--process-id="): + pid = int(a.split("=", 1)[1]) + elif a == "--process-id" and i + 1 < len(sys.argv): + pid = int(sys.argv[i + 1]) + return num, pid + + +_MP_NUM_PROCESS, _MP_PROCESS_ID = _read_mp_options() +_MP_ACTIVE = _init_distributed(_MP_NUM_PROCESS, _MP_PROCESS_ID) + +if not _MP_ACTIVE: + # Skip the entire module if not launched via the multiprocess + # runner. Lets `pytest tests/jax/` collect this file harmlessly. + pytest.skip( + "test_multiprocess_moe_vjp.py requires the multiprocess launcher " + "(run_multiprocess_moe_vjp.sh). Skipping.", + allow_module_level=True, + ) + + +NUM_DEVICES_REQUIRED = 4 +EP_AXIS = "ep" +FSDP_AXIS = "fsdp" +EP_SIZE = 2 +FSDP_SIZE = 2 + +LOGICAL_AXIS_RULES = ( + ("exp", EP_AXIS), + ("embed", FSDP_AXIS), + ("mlp", None), + ("batch", (EP_AXIS, FSDP_AXIS)), +) + + +@pytest.fixture(scope="module") +def mesh(): + if jax.device_count() < NUM_DEVICES_REQUIRED: + pytest.skip( + f"Need >={NUM_DEVICES_REQUIRED} devices for ep={EP_SIZE} x fsdp={FSDP_SIZE};" + f" have {jax.device_count()}" + ) + devices = mesh_utils.create_device_mesh((EP_SIZE, FSDP_SIZE)) + return Mesh(devices, axis_names=(EP_AXIS, FSDP_AXIS)) + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + if not request.node.get_closest_marker("triton"): + yield + return + from transformer_engine.jax.flax import _MoEBlock as MoEBlock + from transformer_engine.jax.moe import PermutationBackend + from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + mod = sys.modules[__name__] + mod.MoEBlock = MoEBlock + mod.PermutationBackend = PermutationBackend + mod.MeshResource = MeshResource + mod.global_shard_guard = global_shard_guard + yield + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _make_block( + *, + num_experts, + num_experts_per_tok, + intermediate_size, + permutation_backend, + aux_loss_coeff=0.0, + dtype=jnp.bfloat16, + align_size=0, +): + return MoEBlock( # noqa: F821 + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + intermediate_size=intermediate_size, + permutation_backend=permutation_backend, + data_parallelism_axes=(FSDP_AXIS,), + aux_loss_coeff=aux_loss_coeff, + dtype=dtype, + _align_size=align_size, + ) + + +def _shard_inputs(x, mesh): + return jax.lax.with_sharding_constraint( + x, NamedSharding(mesh, P((EP_AXIS, FSDP_AXIS), None, None)) + ) + + +def _init_apply(block, mesh, x, key): + with mesh, global_shard_guard( # noqa: F821 + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) # noqa: F821 + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + x = _shard_inputs(x, mesh) + variables = jax.jit(block.init)(key, x) + jax.block_until_ready(jax.tree_util.tree_leaves(variables)[0]) + output, aux = jax.jit(block.apply)(variables, x) + jax.block_until_ready(output) + return variables, output, aux + + +def _grad_step(block, variables, mesh, x): + with mesh, global_shard_guard( # noqa: F821 + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) # noqa: F821 + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + x = _shard_inputs(x, mesh) + + def loss_fn(variables, x): + output, aux = block.apply(variables, x) + main = jnp.mean(output.astype(jnp.float32) ** 2) + return main + (aux.astype(jnp.float32) if aux is not None else 0.0) + + grads = jax.jit(jax.grad(loss_fn))(variables, x) + jax.block_until_ready(jax.tree_util.tree_leaves(grads)[0]) + return grads + + +def _unwrap(x): + return x.value if hasattr(x, "value") else x + + +def _local_shard(x): + """Return the local (this-process) shard of a global JAX Array as numpy. + + Every assertion in this file is structural (finite-ness, non-zero, + parity within tolerance). For all of these, checking the local + shard on each process is sufficient and avoids any cross-process + collective in the test machinery. ``arr.addressable_data(0)`` + returns the local-device view of the sharded array -- with one + GPU per process there is exactly one addressable shard. + """ + return np.asarray(jax.device_get(x.addressable_data(0))) + + +# ----------------------------------------------------------------------------- +# Mixtral-style shapes, sized to fit on a single 4-GPU bf16 box (a +# 4-way data-parallel shard of a Mixtral-8 block). +# ----------------------------------------------------------------------------- + +BATCH = EP_SIZE * FSDP_SIZE * 4 # 16 +SEQ = 2048 +HIDDEN = 1024 +INTER = 4096 +NUM_EXPERTS = 8 +TOPK = 2 + + +@pytest.mark.triton +class TestMoeVjpMultiprocess: + """Multiprocess (one-GPU-per-process) correctness checks for the + unified MoE custom_vjp. + """ + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_fwd_and_bwd(self, mesh, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + block = _make_block( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + permutation_backend=backend, + ) + x = jax.random.normal( + jax.random.PRNGKey(0), + (BATCH, SEQ, HIDDEN), + dtype=jnp.bfloat16, + ) + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(1)) + # Local-shard checks (see _local_shard docstring for why). + out_local = _local_shard(output) + assert output.dtype == x.dtype + assert np.all(np.isfinite(out_local)), "output has NaN/Inf" + assert aux is None + grads = _grad_step(block, variables, mesh, x) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_local = _local_shard(_unwrap(grads["params"][name])) + assert np.all(np.isfinite(g_local)), f"{name} grad has NaN/Inf" + assert np.any(g_local != 0.0), f"{name} grad is identically zero" + + @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + def test_aux_loss(self, mesh, backend_name): + backend = PermutationBackend(backend_name) # noqa: F821 + block = _make_block( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + permutation_backend=backend, + aux_loss_coeff=1e-2, + ) + x = jax.random.normal( + jax.random.PRNGKey(4), + (BATCH, SEQ, HIDDEN), + dtype=jnp.bfloat16, + ) + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(5)) + out_local = _local_shard(output) + assert np.all(np.isfinite(out_local)), "output has NaN/Inf under aux" + assert aux is not None + assert aux.shape == () + aux_local = _local_shard(aux) + assert np.isfinite(aux_local), "aux is NaN/Inf" + grads = _grad_step(block, variables, mesh, x) + g_gate_local = _local_shard(_unwrap(grads["params"]["gate_kernel"])) + assert np.all(np.isfinite(g_gate_local)), "gate grad NaN/Inf under aux" + + def test_pure_jax_triton_parity(self, mesh): + block_pj = _make_block( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + permutation_backend=PermutationBackend.PURE_JAX, # noqa: F821 + ) + block_tr = _make_block( + num_experts=NUM_EXPERTS, + num_experts_per_tok=TOPK, + intermediate_size=INTER, + permutation_backend=PermutationBackend.TRITON, # noqa: F821 + ) + x = jax.random.normal( + jax.random.PRNGKey(6), + (BATCH, SEQ, HIDDEN), + dtype=jnp.bfloat16, + ) + variables, out_pj, _ = _init_apply(block_pj, mesh, x, jax.random.PRNGKey(7)) + with mesh, global_shard_guard( # noqa: F821 + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) # noqa: F821 + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + x_sh = _shard_inputs(x, mesh) + out_tr, _ = jax.jit(block_tr.apply)(variables, x_sh) + + out_pj_local = _local_shard(out_pj) + out_tr_local = _local_shard(out_tr) + diff = float(np.max(np.abs(out_pj_local - out_tr_local))) + assert diff < 5e-2, f"forward parity breach: max_abs_diff={diff}" + + grads_pj = _grad_step(block_pj, variables, mesh, x) + grads_tr = _grad_step(block_tr, variables, mesh, x) + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_pj = _local_shard(_unwrap(grads_pj["params"][name])) + g_tr = _local_shard(_unwrap(grads_tr["params"][name])) + d = float(np.max(np.abs(g_pj - g_tr))) + assert d < 5e-2, f"grad parity breach on {name}: max_abs_diff={d}" diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 75bb85f5ec..b3893843af 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -12,6 +12,16 @@ from packaging import version +_PERMUTATION_AUTOTUNE_BLOCK_SIZES = (64, 128, 256, 512, 1024, 2048, 4096) + + +def _permutation_autotune_configs(): + """Autotune ``configs`` list shared by every permutation Triton + kernel below. + """ + return [triton.Config({"BLOCK_SIZE": bs}) for bs in _PERMUTATION_AUTOTUNE_BLOCK_SIZES] + + # The following three argsort related kernels are adapted from # the issue https://github.com/triton-lang/triton/issues/3698 @@ -295,15 +305,7 @@ def _permute_kernel( try: _permute_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], + configs=_permutation_autotune_configs(), key=["hidden_size"], )(_permute_kernel) except RuntimeError: @@ -416,15 +418,7 @@ def _unpermute_kernel( try: _unpermute_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], + configs=_permutation_autotune_configs(), key=["hidden_size"], )(_unpermute_kernel) except RuntimeError: @@ -525,15 +519,7 @@ def _unpermute_bwd_with_merging_probs_kernel( try: _unpermute_bwd_with_merging_probs_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], + configs=_permutation_autotune_configs(), key=["hidden_size"], )(_unpermute_bwd_with_merging_probs_kernel) except RuntimeError: @@ -643,15 +629,7 @@ def _sort_chunks_by_map_kernel( try: _sort_chunks_by_map_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], + configs=_permutation_autotune_configs(), key=["hidden_size"], )(_sort_chunks_by_map_kernel) except RuntimeError: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 94b2de9573..4ff6d07986 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2024,14 +2024,9 @@ def grouped_gemm_copy_group_sizes( return out +@cache def _should_enforce_v2_grouped_gemm() -> bool: - """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM. - - Not cached so tests can flip the env var with ``monkeypatch.setenv`` - and have it picked up on the next call. This is called only on - grouped-GEMM dispatch (not in any tight loop), so the per-call - ``getenv`` cost is negligible. - """ + """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" val = os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") try: return bool(int(val)) diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py index f4ef323e24..f02d6650a0 100644 --- a/transformer_engine/jax/flax/moe.py +++ b/transformer_engine/jax/flax/moe.py @@ -4,48 +4,42 @@ """Flax Linen MoE block for TransformerEngine JAX. -This module exposes :class:`_MoEBlock`, an **experimental** self-contained -Flax Linen MoE layer. It is intentionally prefixed with an underscore -while TE's NCCL-backed EP component (and the recipe-driven alignment -follow-up) stabilises; the public ``MoEBlock`` alias will be introduced -once those dependencies are ready (target: the TE release following the -2.16 code freeze). Until then please treat the class, its parameters, -and :class:`GlobalPermuteResult` as unstable. - -See the class docstring for the architecture, the EP / FSDP strategies, -and the ``_align_size > 0`` contract. +This module exposes :class:`_MoEBlock`, an experimental Flax Linen layer +that is a thin wrapper around the framework-agnostic functional MoE entry +point :func:`transformer_engine.jax.moe.moe`. The wrapper's only job is +to: + +1. Register the gate kernel, per-expert FFN kernels, and optional biases + as ``self.param`` slots (with the right + :func:`flax.linen.with_logical_partitioning` annotations so JAX's + sharding layer FSDPs the params correctly). +2. Resolve the EP axis name from the active + :class:`transformer_engine.jax.sharding.MeshResource`. +3. Forward all knobs to :func:`moe`. + +All routing, dispatch, FFN, combine, and aux-loss logic lives in +``moe.py`` under a *single* ``jax.custom_vjp`` so future fusions +(FP8-on-the-wire EP, fused ``ragged_all_to_all + grouped_gemm``, gate + +route + dispatch fusion) can land without touching this wrapper. + +The class is intentionally underscore-prefixed; the public ``MoEBlock`` +alias will be introduced once TE's NCCL-backed EP component (and the +recipe-driven alignment follow-up) stabilises (target: the TE release +following the 2.16 code freeze). """ -from enum import Enum -from functools import partial from typing import Any, Callable, NewType, Optional, Tuple, Union import jax import jax.numpy as jnp -from flax import linen as nn, struct as flax_struct -from jax.sharding import PartitionSpec as P +from flax import linen as nn +from jax.sharding import PartitionSpec as P # noqa: F401 (re-exported for convenience) -from ..dense import grouped_dense -from ..permutation import ( - routing_map_to_selected_experts, - compute_ragged_all_to_all_params, - compute_reverse_ragged_all_to_all_params, - local_permute_after_a2a, - local_unpermute_before_a2a, - PureJaxPermState, - pure_jax_token_combine, - pure_jax_token_dispatch, - token_combine, - token_dispatch, -) +from ..moe import PermutationBackend, moe from ..quantize import noop_quantizer_set -from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function -from ..sharding import ( - _get_mesh, - get_active_resource_axis, - with_sharding_constraint_by_logical_axes, -) -from .module import TransformerEngineBase, _convert_to_activation_function +from ..router import ScoreFunction +from ..sharding import get_active_resource_axis +from .module import TransformerEngineBase PRNGKey = Any Shape = Tuple[int, ...] @@ -54,235 +48,73 @@ Initializer = Callable[[PRNGKey, Shape, DType], Array] -__all__ = ["GlobalPermuteResult", "PermutationBackend", "_MoEBlock"] - - -# ============================================================================= -# PermutationBackend -# ============================================================================= - - -class PermutationBackend(Enum): - """Token-dispatch / combine backend used by :class:`_MoEBlock`. - - * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain XLA; - typically faster than ``TRITON`` in current testing because XLA can - fuse the ops with surrounding work. - * ``TRITON``: TE's fused Triton kernels. - """ - - PURE_JAX = "pure_jax" - TRITON = "triton" - - -# ============================================================================= -# GlobalPermuteResult -# ============================================================================= -# -# Output of :meth:`_MoEBlock._global_permute`. Carried as a pytree (so it -# crosses ``jax.shard_map`` / ``jax.value_and_grad`` boundaries -# transparently) and consumed by :meth:`_MoEBlock._global_combine`. The -# fields populated depend on the permutation backend; the unused fields -# stay ``None``. -# -# Per-backend payloads (anything else is ``None``): -# pure_jax: ``perm_state``, ``routing_weights`` -# triton: ``row_id_map``, ``pad_offsets``, ``merging_probs`` - - -@flax_struct.dataclass -class GlobalPermuteResult: - """Result of :meth:`_MoEBlock._global_permute`.""" - - sorted_inputs: jnp.ndarray - group_sizes: jnp.ndarray - perm_state: Optional[PureJaxPermState] = None - routing_weights: Optional[jnp.ndarray] = None - row_id_map: Optional[jnp.ndarray] = None - pad_offsets: Optional[jnp.ndarray] = None - merging_probs: Optional[jnp.ndarray] = None - backend: PermutationBackend = flax_struct.field( - pytree_node=False, default=PermutationBackend.PURE_JAX - ) - - -# ============================================================================= -# _MoEBlock -# ============================================================================= +__all__ = ["PermutationBackend", "_MoEBlock"] class _MoEBlock(TransformerEngineBase): - """Mixture-of-Experts Flax Linen block (**experimental**). - - .. warning:: - - This class is exposed as ``_MoEBlock`` (leading underscore) on - purpose: it is not part of the stable public API yet. The TE - NCCL-backed EP component and the recipe-driven ``_align_size`` - follow-up both need to land before this is promoted to a public - ``MoEBlock``. Until then, expect signature changes, including - to :class:`GlobalPermuteResult` and :class:`PermutationBackend`. - Target promotion: the TE release after the 2.16 code freeze. - - Encapsulates the full MoE forward pass: gate projection, fused top-k - routing, optional auxiliary load-balancing loss, token dispatch, - per-expert two-layer FFN via grouped GEMMs, activation, token combine, - and optional ragged-all-to-all expert parallelism. + """Experimental Flax MoE layer over TransformerEngine. - Architecture - ------------ - - The block is decomposed into orthogonal stages so the EP wrapper can - inject collectives between them: - - * ``_route``: gate logits -> top-k routing decisions (+ aux loss). - * ``_global_permute``: scatter tokens to experts; produces - ``[num_tokens*topk + maybe_padding, hidden]`` and per-expert - ``group_sizes`` of length ``num_experts``. - * ``_expert_ffn``: three ``grouped_dense`` calls + activation. - Operates on whatever ``(rows, group_sizes, n_groups)`` it is - handed -- agnostic to whether ``n_groups`` is the global expert - count (no-EP) or the local expert count (A2A-EP). - * ``_global_combine``: inverse of ``_global_permute`` -- gather + - weighted sum across top-k experts. - - Two top-level forward variants compose those stages: - - * ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE - primitive's ``custom_partitioning`` rule handles DP / FSDP - automatically. - * ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and - inserts ``all_gather(group_sizes)`` + forward - ``ragged_all_to_all`` + local permute around the FFN, plus their - inverses afterwards. This is the only place ``shard_map`` is - used; A2A is the canonical EP strategy because the in-flight - NCCL EP component will require this same data layout. - - Note on ``_align_size > 0`` - --------------------------- - - Both permutation backends pad each expert's group to a multiple of - ``_align_size`` when requested, which is what cuBLASLt's grouped - GEMM wants for FP8 shape selection. The pure-JAX backend - additionally appends a zero-input padding tail to keep the buffer - statically sized for JIT, so ``sum(group_sizes) <= - sorted_inputs.shape[0]`` strictly. The V1 grouped GEMM FFI asserts - strict equality ``m == sum(group_sizes)`` and is therefore - incompatible with ``_align_size > 0``; the V2 cuBLASLt-backed - grouped GEMM relaxes this to ``m >= sum(group_sizes)`` and only - iterates over the populated ragged region. The ``_align_size > 0`` - tests therefore force ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1`` and - ``skip`` if V2 is not supported on the target hardware / dtype. - - Two permutation backends are pluggable via ``permutation_backend``: - - * :attr:`PermutationBackend.PURE_JAX` (default) -- argsort-based - :func:`~transformer_engine.jax.permutation.pure_jax_token_dispatch` / - :func:`~transformer_engine.jax.permutation.pure_jax_token_combine`. - Faster than Triton in profiling for DeepSeek-style configs. - * :attr:`PermutationBackend.TRITON` -- TE's fused - :func:`~transformer_engine.jax.permutation.token_dispatch` / - :func:`~transformer_engine.jax.permutation.token_combine` Triton - kernels. - - Expert parallelism is configured via :class:`MeshResource`'s - ``ep_resource`` axis. When that axis is set on the active - :func:`~transformer_engine.jax.global_mesh_resource` and has more - than one device, ``_MoEBlock`` dispatches to the - **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes - its own tokens globally over all experts, then a forward - ``ragged_all_to_all`` exchanges per-expert chunks so each shard - ends up holding only the tokens for its local experts; after the - FFN a reverse ``ragged_all_to_all`` returns each shard's outputs - to it. This matches the layout the in-flight NCCL EP component - expects. + See module docstring for the design (this class is a thin Flax + wrapper around :func:`transformer_engine.jax.moe.moe`). Constructor + knob set kept compatible with the previous bespoke implementation so + existing call sites need no changes. Parameters ---------- num_experts : int - Total number of experts. + Total number of experts. Under EP this must be divisible by the + EP mesh axis size. num_experts_per_tok : int - Top-k value (number of experts each token is routed to). + Top-k value for routing. intermediate_size : int - Per-expert FFN hidden dim. - + Hidden dim of the per-expert FFN (the inner ``mlp`` axis). activation_type : str - FFN activation applied to the gate projection. Paired with the up - projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. - Resolved via :func:`flax.linen.` (``"silu"``, ``"gelu"``, - ``"relu"``, ``"swish"``, ...) plus ``"linear"`` for identity. + Activation between ``layer_w0 @ wi_0`` and the elementwise + product with ``layer_w0 @ wi_1``. Default ``"silu"``. - score_function : str or ScoreFunction - ``"softmax"`` (default) or ``"sigmoid"`` for - :func:`fused_topk_with_score_function`. + score_function : Union[str, ScoreFunction] + ``"softmax"`` (default) or ``"sigmoid"`` for the routing scores. use_pre_softmax : bool - Apply softmax before top-k when ``score_function="softmax"``. - num_groups : Optional[int] - Number of routing groups for grouped top-k (DeepSeek). ``None`` - (default) disables. - group_topk : Optional[int] - Top-k at the group level. ``None`` (default) disables. + Apply softmax before topk (vs. after). + num_groups, group_topk : Optional[int] + Grouped top-k knobs (DeepSeek-style). ``None`` disables grouping. scaling_factor : float - Scaling factor applied to output probs. + Multiplier on the routing weights. use_expert_bias : bool - If ``True``, registers a learnable ``expert_bias`` parameter of - shape ``[num_experts]`` and passes it to the fused router. The - router primitive validates that this is paired with - ``score_function="sigmoid"``. + If ``True``, registers a per-expert routing bias (shape ``[E]``). + Only meaningful with ``score_function="sigmoid"``; the underlying + primitive validates the pairing. aux_loss_coeff : float - If ``> 0``, compute and return the MoE auxiliary load-balancing - loss scalar via :func:`fused_moe_aux_loss`. ``0`` disables. - - gate_kernel_axes : tuple[str, ...] - Logical partitioning axes for the gate kernel of shape - ``[hidden, num_experts]``. - wi_kernel_axes : tuple[str, ...] - Logical partitioning axes for the ``wi_0`` and ``wi_1`` kernels of - shape ``[num_experts, hidden, intermediate]``. Default - ``("exp", "embed", "mlp")``. - wo_kernel_axes : tuple[str, ...] - Logical partitioning axes for the ``wo`` kernel of shape - ``[num_experts, intermediate, hidden]``. Default - ``("exp", "mlp", "embed")``. - input_axes : tuple[str, ...] - Logical axes used to constrain the input activation sharding at the - block boundary. ``()`` (default) means no constraint. + If ``> 0``, return the MoE auxiliary load-balancing loss scalar + in addition to the main output. + gate_kernel_axes, wi_kernel_axes, wo_kernel_axes, input_axes : + Logical sharding axis tuples (consumed by Flax's + :func:`with_logical_partitioning` and our internal + :func:`with_sharding_constraint_by_logical_axes`). data_parallelism_axes : tuple[str, ...] - Additional mesh axes that the input *batch* dim is sharded over - IN ADDITION to ``MeshResource.ep_resource``. Setting this to - e.g. ``("fsdp",)`` makes the ``shard_map`` ``in_specs`` for the - batch dim become ``P(("ep", "fsdp"), None, None)`` -- giving - each device a unique slice of the batch (true FSDP) instead of - replicating the per-ep-shard batch across fsdp peers. - Routing is unaffected: ``axis_index("ep")`` still controls the - ragged-all-to-all; the extra fsdp peers within an ep group send - and receive their own batch slices in lockstep. Default ``()`` - preserves legacy ZeRO-1-style behavior (activations replicated - on fsdp within an ep group). - + FSDP axes over which the input *batch* dim is sharded IN + ADDITION to the EP axis. Empty (default) means activations are + replicated across non-EP axes within an EP group; set e.g. + ``("fsdp",)`` for true FSDP-of-batch where each device owns a + unique slice of the batch. permutation_backend : PermutationBackend - :attr:`PermutationBackend.PURE_JAX` (default) or - :attr:`PermutationBackend.TRITON`. + ``PURE_JAX`` (default) or ``TRITON``. + _align_size : int + Per-expert group-size alignment (``0`` disables; required > 0 + for quantized grouped GEMM). Internal knob; will be inferred + from the active quantization recipe in a follow-up PR. dtype : jnp.dtype - Compute and parameter dtype. - kernel_init : Initializer - Initializer for all kernels (gate + per-expert FFN). Defaults to - ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax - convention). + Compute / parameter dtype. + kernel_init, bias_init, expert_bias_init : Initializers. use_bias : bool - If ``True``, registers per-expert FFN biases ``wi_0_bias``, - ``wi_1_bias``, ``wo_bias``. + Register per-expert FFN biases. - TODO: - ----- - ``_align_size`` is an internal, non-public knob (alignment for - per-expert group sizes after padding). A follow-up PR will infer it - from the active quantization recipe, after which it will become a - fully-internal implementation detail. Until then it stays - intentionally underscored to discourage callers from depending on - it. + Quantization is currently configured via the standard TE autocast + context (``fp8_autocast``/``with_quantizer_set``); per-call + quantizer sets can also be passed through ``__call__``'s + ``quantizer_sets`` keyword once we stabilise the recipe pipeline. """ # Architecture @@ -300,24 +132,17 @@ class _MoEBlock(TransformerEngineBase): use_expert_bias: bool = False aux_loss_coeff: float = 0.0 - # Sharding + # Sharding (logical axes) gate_kernel_axes: Tuple[Optional[str], ...] = () wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp") wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed") input_axes: Tuple[Optional[str], ...] = () # Parallelism - # - # The EP axis is resolved from ``global_mesh_resource().ep_resource`` - # and the active mesh, not configured per-instance. ``_MoEBlock`` - # uses ``_forward_a2a_ep`` when that axis exists on the mesh and - # has > 1 device; otherwise it uses ``_forward_no_ep``. data_parallelism_axes: Tuple[str, ...] = () # Permutation permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX - # See class docstring "Notes": internal, will be inferred from the - # quantization recipe in a follow-up PR. _align_size: int = 0 # Dtypes / init / misc @@ -338,15 +163,11 @@ def __post_init__(self): ) if not isinstance(self.permutation_backend, PermutationBackend): raise TypeError( - "permutation_backend must be a PermutationBackend," - f" got {self.permutation_backend!r}" + "permutation_backend must be a PermutationBackend, got" + f" {self.permutation_backend!r}" ) super().__post_init__() - # ------------------------------------------------------------------ - # Entry point - # ------------------------------------------------------------------ - @nn.compact def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: """Run the MoE forward pass. @@ -354,36 +175,31 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: Parameters ---------- inputs : jnp.ndarray - Input tensor of shape ``[batch, sequence, hidden]``. + ``[batch, sequence, hidden]``. Returns ------- output : jnp.ndarray - Output tensor of shape ``[batch, sequence, hidden]``. + ``[batch, sequence, hidden]``. aux_loss : Optional[jnp.ndarray] - Scalar auxiliary load-balancing loss when - ``aux_loss_coeff > 0``, else ``None``. + Scalar load-balancing loss when ``aux_loss_coeff > 0``, + else ``None``. """ assert ( inputs.ndim == 3 ), f"_MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" - inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) - _, _, hidden_size = inputs.shape - # Param registrations are inlined here (not in a helper) so each - # ``self.param`` lives close to the rest of the entry point. - # Note: under EP the FFN weights and ``expert_bias`` are - # consumed *inside* a ``shard_map`` body. Flax's ``self.param`` - # must run OUTSIDE any JAX transform that would alter the - # variable scope (``shard_map`` does), so the registrations stay - # here in ``__call__`` and the values are passed down explicitly - # via ``in_specs``. ``_gate`` is called outside ``shard_map`` in - # both paths, so its kernel is registered inline inside - # ``_gate`` itself rather than here. - - gate_logits = self._gate(inputs) - + # Param registrations -- must run OUTSIDE any JAX transform that + # alters the variable scope (e.g. shard_map). The functional + # ``moe(...)`` opens its own shard_map internally for the EP + # path, so registering params here is correct. + gate_kernel = self.param( + "gate_kernel", + nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), + (hidden_size, self.num_experts), + self.dtype, + ) wi_0 = self.param( "wi_0", nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), @@ -432,743 +248,35 @@ def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: ) ep_axis = get_active_resource_axis("ep_resource") - if ep_axis is None: - output, aux_loss = self._forward_no_ep( - inputs, - gate_logits, - wi_0=wi_0, - wi_1=wi_1, - wo=wo, - wi_0_bias=wi_0_bias, - wi_1_bias=wi_1_bias, - wo_bias=wo_bias, - expert_bias=expert_bias, - ) - else: - output, aux_loss = self._forward_a2a_ep( - inputs, - gate_logits, - ep_axis=ep_axis, - wi_0=wi_0, - wi_1=wi_1, - wo=wo, - wi_0_bias=wi_0_bias, - wi_1_bias=wi_1_bias, - wo_bias=wo_bias, - expert_bias=expert_bias, - ) - - if self.aux_loss_coeff <= 0.0: - aux_loss = None - return output, aux_loss - - # ------------------------------------------------------------------ - # Gate - # ------------------------------------------------------------------ - - def _gate(self, inputs: jnp.ndarray) -> jnp.ndarray: - """Linear gate projection ``inputs @ gate_kernel``. - - Kept as a plain ``einsum`` (not ``DenseGeneral``) so it composes - cleanly with the EP shard_map: the gate runs in the outer - (pre-shard_map) scope and its output passes through the - ``shard_map`` boundary unchanged. Because the gate runs outside - any ``shard_map`` body in both EP and no-EP forwards, the - ``gate_kernel`` parameter is registered inline here. - - The gating GEMM is intentionally kept in ``self.dtype`` (typically - ``bfloat16``) and is **not** autocast to FP8 even when the caller - wraps the block in :func:`transformer_engine.jax.autocast`. Two - reasons: (1) the GEMM is tiny (``H * E`` with ``E`` small) and - contributes well under 1% of the block's compute, so quantization - savings are marginal; (2) the resulting logits feed a top-k + - softmax (or sigmoid) routing decision that is sensitive to - quantization noise -- routing flips at low-confidence tokens - could materially hurt model quality. To override, wrap the call - site in your own ``autocast`` and manually replace this method. - """ - hidden_size = inputs.shape[-1] - gate_kernel = self.param( - "gate_kernel", - nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), - (hidden_size, self.num_experts), - self.dtype, - ) - kernel = gate_kernel.astype(inputs.dtype) - return jnp.einsum("bsh,he->bse", inputs, kernel) - - # ------------------------------------------------------------------ - # Route - # ------------------------------------------------------------------ - # - # The router is split into two pieces so the EP path can compute - # aux_loss over global (cross-shard) statistics without re-running - # the main top-k path. ``_route_topk`` returns the per-token routing - # decisions (used by ``_global_permute``) and ``_compute_aux_loss`` - # returns the scalar load-balancing loss given the (possibly - # gathered) logits. - - def _route_topk( - self, - logits_2d: jnp.ndarray, - expert_bias: Optional[jnp.ndarray], - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - """Run the fused router top-k selection.""" - # ``fused_topk_with_score_function`` uses ``-1`` as the - # "disabled" sentinel for the grouped-routing knobs; translate - # our ``None`` user-facing default to that sentinel here. - sparse_probs, routing_map = fused_topk_with_score_function( - logits_2d, - topk=self.num_experts_per_tok, - use_pre_softmax=self.use_pre_softmax, - num_groups=-1 if self.num_groups is None else self.num_groups, - group_topk=-1 if self.group_topk is None else self.group_topk, - scaling_factor=self.scaling_factor, - score_function=self.score_function, - expert_bias=expert_bias, - ) - sparse_probs = sparse_probs.astype(self.dtype) - return sparse_probs, routing_map - - def _compute_aux_loss( - self, - logits_2d: jnp.ndarray, - tokens_per_expert: jnp.ndarray, - ) -> Optional[jnp.ndarray]: - """Compute the MoE auxiliary load-balancing loss. - The score-for-aux kernel reads only ``logits_2d`` and the final - reduction reads only the (already-computed) ``tokens_per_expert``, - so the aux scores can run concurrently with the main routing - path on the GPU. - - ``logits_2d`` should be the *full* logits tensor over the global - token batch -- under EP the caller is responsible for - :func:`jax.lax.all_gather` ing the logits before calling this so - the aux_loss formula - ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])`` - sees the global ``T``. - - ``tokens_per_expert`` must be the per-expert token-assignment - count from the *actual* routing decision -- i.e. derived from - ``_route_topk``'s ``routing_map``, not recomputed from a clean - top-k. This matters under DeepSeek-style routing - (``num_groups > 0`` / ``group_topk > 0``) where the - post-grouping routing differs from a plain top-k. Under EP the - caller is responsible for summing over all (ep + dp) shards - first so the count is global. - """ - if self.aux_loss_coeff <= 0.0: - return None - # The "compute_aux_scores=True" kernel intentionally ignores - # num_groups/group_topk/expert_bias and returns the dense - # post-score-function scores over all experts. Those scores are - # what the aux-loss formula expects (raw scoring, no grouping - # bias); the routing decisions used for ``tokens_per_expert`` - # come from the caller-supplied real ``routing_map``. - aux_scores, _ = fused_topk_with_score_function( - logits_2d.astype(jnp.float32), - topk=self.num_experts_per_tok, - score_function=self.score_function, - compute_aux_scores=True, - ) - return fused_moe_aux_loss( - aux_scores.astype(jnp.float32), - tokens_per_expert.astype(jnp.int32), - topk=self.num_experts_per_tok, - coeff=self.aux_loss_coeff, - ) - - # ------------------------------------------------------------------ - # Global permute (route -> token dispatch) - # ------------------------------------------------------------------ - - def _global_permute( - self, - inputs_2d: jnp.ndarray, - sparse_probs: jnp.ndarray, - routing_map: jnp.ndarray, - ) -> GlobalPermuteResult: - """Dispatch tokens to the global expert axis. - - Returns a :class:`GlobalPermuteResult` suitable both for the - no-EP forward (where the same buffer feeds ``_expert_ffn`` - directly) and for the A2A-EP path (where the buffer is sliced + - sent over the EP axis before the FFN). The result carries the - per-backend opaque state needed to invert the dispatch in - :meth:`_global_combine`. - """ - num_tokens = inputs_2d.shape[0] - topk = self.num_experts_per_tok - - if self.permutation_backend is PermutationBackend.PURE_JAX: - selected_experts, routing_weights = routing_map_to_selected_experts( - sparse_probs, routing_map, topk - ) - sorted_inputs, perm_state, group_sizes = pure_jax_token_dispatch( - inputs_2d, - selected_experts, - num_experts=self.num_experts, - num_experts_per_tok=topk, - align_size=self._align_size, - ) - return GlobalPermuteResult( - backend=PermutationBackend.PURE_JAX, - sorted_inputs=sorted_inputs, - group_sizes=group_sizes, - perm_state=perm_state, - routing_weights=routing_weights, - ) - - # triton - num_out_tokens = num_tokens * topk - align_size_arg = self._align_size if self._align_size > 0 else None - ( - sorted_inputs, - _permuted_probs, - row_id_map, - pad_offsets, - group_sizes, - ) = token_dispatch( - inputs_2d, - routing_map, - num_out_tokens=num_out_tokens, - probs=sparse_probs, - align_size=align_size_arg, - ) - return GlobalPermuteResult( - backend=PermutationBackend.TRITON, - sorted_inputs=sorted_inputs, - group_sizes=group_sizes, - row_id_map=row_id_map, - pad_offsets=pad_offsets, - merging_probs=sparse_probs, - ) - - # ------------------------------------------------------------------ - # Expert FFN (three grouped_dense calls + activation) - # ------------------------------------------------------------------ - - def _expert_ffn( - self, - sorted_inputs: jnp.ndarray, - group_sizes: jnp.ndarray, - n_groups: int, - wi_0: jnp.ndarray, - wi_1: jnp.ndarray, - wo: jnp.ndarray, - wi_0_bias: Optional[jnp.ndarray] = None, - wi_1_bias: Optional[jnp.ndarray] = None, - wo_bias: Optional[jnp.ndarray] = None, - ) -> jnp.ndarray: - """Run the per-expert SwiGLU-style FFN over a permuted buffer. - - All ``wi_*`` / ``wo`` weights and the optional biases are passed - in as explicit args (rather than registered inline here) because - in the EP path this method runs *inside* a ``shard_map`` body - and Flax param registration must happen outside that scope. - - Parameters - ---------- - sorted_inputs : jnp.ndarray - Permuted tokens of shape ``[buffer_size, hidden]`` (rows - grouped by expert). - group_sizes : jnp.ndarray - Per-group token counts of shape ``[n_groups]``. - ``sum(group_sizes)`` must equal ``buffer_size`` (TE - ``grouped_dense`` FFI assertion at - ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``). - n_groups : int - Number of expert groups. Equals ``self.num_experts`` for the - no-EP path and ``num_experts // num_ep`` for the A2A-EP path. - Used to size the per-call quantizer set so the FP8 metadata - tensors match ``group_sizes``. - wi_0, wi_1, wo : jnp.ndarray - Expert weight tensors. Shapes (no-EP): - ``(num_experts, hidden, intermediate)`` for wi_*, - ``(num_experts, intermediate, hidden)`` for wo. Under EP - the leading expert dim is sliced to ``num_experts // num_ep``. - wi_0_bias, wi_1_bias, wo_bias : Optional[jnp.ndarray] - Optional per-expert biases (shape ``(num_experts, N)``); - ``grouped_dense`` adds ``bias[i]`` to the rows belonging to - expert ``i`` in the permuted layout. - - Returns - ------- - expert_outputs : jnp.ndarray - ``[buffer_size, hidden]``. - """ - # Each grouped_dense call gets its own quantizer_set with - # n_groups matching ``group_sizes``; this keeps the FP8 meta - # tensors correctly sized in both no-EP and A2A-EP cases. - q_set_w0 = self.generate_quantizer_set(postfix="_w0", n_groups=n_groups) - q_set_w1 = self.generate_quantizer_set(postfix="_w1", n_groups=n_groups) - q_set_wo = self.generate_quantizer_set(postfix="_wo", n_groups=n_groups) - - # Cast kernels to the activation dtype when no FP8 quantization - # is active (mirrors DenseGeneral). - if q_set_w0 == noop_quantizer_set: - wi_0 = wi_0.astype(sorted_inputs.dtype) - if q_set_w1 == noop_quantizer_set: - wi_1 = wi_1.astype(sorted_inputs.dtype) - if q_set_wo == noop_quantizer_set: - wo = wo.astype(sorted_inputs.dtype) - - layer_w0 = grouped_dense( - sorted_inputs, + return moe( + inputs, + gate_kernel, wi_0, - group_sizes, - contracting_dims=((1,), (1,)), - bias=wi_0_bias, - quantizer_set=q_set_w0, - ) - layer_w1 = grouped_dense( - sorted_inputs, wi_1, - group_sizes, - contracting_dims=((1,), (1,)), - bias=wi_1_bias, - quantizer_set=q_set_w1, - ) - - act_fn = _convert_to_activation_function(self.activation_type) - intermediate = act_fn(layer_w0) * layer_w1 - - expert_outputs = grouped_dense( - intermediate, wo, - group_sizes, - contracting_dims=((1,), (1,)), - bias=wo_bias, - quantizer_set=q_set_wo, - ) - return expert_outputs - - # ------------------------------------------------------------------ - # Global combine (token combine -> back to [B, S, H]) - # ------------------------------------------------------------------ - - def _global_combine( - self, - expert_outputs: jnp.ndarray, - perm_result: GlobalPermuteResult, - batch_size: int, - sequence_length: int, - ) -> jnp.ndarray: - """Inverse of :meth:`_global_permute`. - - Gathers per-expert outputs back into ``[batch, sequence, hidden]`` - and applies the per-token weighted sum across the top-k experts. - """ - if perm_result.backend is PermutationBackend.PURE_JAX: - return pure_jax_token_combine( - expert_outputs, - perm_result.perm_state, - perm_result.routing_weights, - num_experts_per_tok=self.num_experts_per_tok, - batch_size=batch_size, - sequence_length=sequence_length, - ) - # triton - out_2d = token_combine( - expert_outputs, - perm_result.row_id_map, - merging_probs=perm_result.merging_probs, - pad_offsets=perm_result.pad_offsets, - ) - hidden_size = out_2d.shape[-1] - return out_2d.reshape(batch_size, sequence_length, hidden_size).astype(self.dtype) - - # ------------------------------------------------------------------ - # No-EP forward - # ------------------------------------------------------------------ - - def _forward_no_ep( - self, - inputs: jnp.ndarray, - gate_logits: jnp.ndarray, - *, - wi_0: jnp.ndarray, - wi_1: jnp.ndarray, - wo: jnp.ndarray, - wi_0_bias: Optional[jnp.ndarray] = None, - wi_1_bias: Optional[jnp.ndarray] = None, - wo_bias: Optional[jnp.ndarray] = None, - expert_bias: Optional[jnp.ndarray] = None, - ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Single-shard or DP/FSDP forward (no shard_map wrapper). - - DP / FSDP both flow through each TE primitive's - ``custom_partitioning`` rule -- there is no cross-primitive - collective that the rules cannot express on their own, so a - ``shard_map`` is unnecessary here. - - Sharding contract for callers - ----------------------------- - - On this no-EP path the grouped quantize and grouped GEMMs run - in the caller's outer SPMD context (no ``shard_map`` boundary). - Their custom_partitioning rules read sharding from each input's - ``NamedSharding`` and propagate consistent shardings on outputs. - Concretely: - - * ``inputs`` should be FSDP/DP-sharded on the batch dim - (``input_axes`` in :class:`_MoEBlock` enforces this via a - logical ``with_sharding_constraint``). - * ``wi_*`` / ``wo`` weights should carry the logical axes - ``wi_kernel_axes`` / ``wo_kernel_axes`` so FSDP shards a - weight non-contracting dim, gathered inside ``grouped_dense`` - before the GEMM. - * The wgrad reduce-scatter (when FSDP is active) is emitted by - ``grouped_dense_bwd``'s partitioning rule; no explicit - collective is needed here. - - Without those shardings the grouped GEMM falls back to - replicated-everywhere semantics (legal but defeats FSDP/DP). - Tested in ``tests/jax/test_distributed_moe_block.py`` for the - EP=2 + FSDP=2 case; the no-EP + FSDP-only case shares the same - infra and is covered when ``ep_resource`` is unset on the - active ``MeshResource``. - """ - batch_size, sequence_length, hidden_size = inputs.shape - inputs_2d = inputs.reshape(-1, hidden_size) - logits_2d = gate_logits.reshape(-1, self.num_experts) - - sparse_probs, routing_map = self._route_topk(logits_2d, expert_bias) - # ``tokens_per_expert`` MUST come from the real routing_map so the - # aux-loss objective matches actual routing decisions under - # DeepSeek-style num_groups/group_topk routing. - tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0) - aux_loss = self._compute_aux_loss(logits_2d, tokens_per_expert) - perm = self._global_permute(inputs_2d, sparse_probs, routing_map) - expert_outputs = self._expert_ffn( - perm.sorted_inputs, - perm.group_sizes, - n_groups=self.num_experts, - wi_0=wi_0, - wi_1=wi_1, - wo=wo, - wi_0_bias=wi_0_bias, - wi_1_bias=wi_1_bias, - wo_bias=wo_bias, - ) - output = self._global_combine(expert_outputs, perm, batch_size, sequence_length) - return output, aux_loss - - # ------------------------------------------------------------------ - # A2A (ragged-all-to-all) EP forward - # ------------------------------------------------------------------ - - def _forward_a2a_ep( - self, - inputs: jnp.ndarray, - gate_logits: jnp.ndarray, - *, - ep_axis: str, - wi_0: jnp.ndarray, - wi_1: jnp.ndarray, - wo: jnp.ndarray, - wi_0_bias: Optional[jnp.ndarray] = None, - wi_1_bias: Optional[jnp.ndarray] = None, - wo_bias: Optional[jnp.ndarray] = None, - expert_bias: Optional[jnp.ndarray] = None, - ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: - """Wrap the body in a ``shard_map`` that runs a forward - ``ragged_all_to_all`` (A2A / A2Av) around the FFN. - - For each EP shard the wrapper: - - 1. Routes the shard's local tokens **globally** over all - ``num_experts`` experts (no roll, no local-mask -- every shard - sees the full expert axis). - 2. ``all_gather`` s its per-expert ``group_sizes`` so all shards - know the complete ``[num_ep, num_experts]`` token-count matrix. - 3. Forward ``ragged_all_to_all`` over the EP axis: each shard - sends per-expert chunks to the shard that owns those experts, - and receives chunks for its own ``num_experts // num_ep`` - local experts from every other shard. - 4. Reorders the received buffer from ``(source_shard, expert)`` - to ``(expert, source_shard)`` ordering so each local expert's - tokens are contiguous. - 5. Runs the three ``grouped_dense`` calls + activation over the - ``E_local``-group buffer. - 6. Reverses the local reorder. - 7. Reverse ``ragged_all_to_all`` over EP returns each shard's - token outputs to it. - 8. Inverts the global permute and applies the top-k weighted sum. - """ - from jax.experimental.shard_map import shard_map - - mesh = _get_mesh() - if mesh is None or mesh.empty: - raise ValueError( - "_MoEBlock requires an active jax.sharding.Mesh (either via" - " `with mesh:` or `jax.set_mesh`) when EP is configured on" - " the active MeshResource." - ) - num_ep = mesh.shape[ep_axis] - assert ( - self.num_experts % num_ep == 0 - ), f"num_experts={self.num_experts} must be divisible by EP size={num_ep}" - num_experts_local = self.num_experts // num_ep - - # Compose the BATCH sharding axis tuple. ``ep`` is always part of - # the batch axis (so ragged_all_to_all has data to route); any - # ``data_parallelism_axes`` are added on top so the per-device - # batch slice is genuinely unique (true FSDP / DP). - # Examples: - # data_parallelism_axes=() -> P('ep', None, None) - # data_parallelism_axes=('fsdp',) -> P(('ep','fsdp'), None, None) - # data_parallelism_axes=('fsdp','data') -> P(('ep','fsdp','data'), ...) - for ax in self.data_parallelism_axes: - if ax not in mesh.shape: - raise ValueError( - f"data_parallelism_axes contains {ax!r} but mesh has" - f" axes {tuple(mesh.shape.keys())}" - ) - if len(self.data_parallelism_axes) == 0: - batch_pspec_axis: Any = ep_axis - else: - batch_pspec_axis = (ep_axis, *self.data_parallelism_axes) - # The size by which the per-device batch is divided BEYOND ep. - # Used to tighten the worst-case ragged_all_to_all recv buffer: - # at most ``num_ep`` peers each send their entire local - # ``B/(num_ep*dp_size)*S*topk`` token-expert pairs, so the worst - # recv per device is ``num_ep * B/(num_ep*dp_size)*S*topk - # = B/dp_size * S * topk``. - dp_size = 1 - for ax in self.data_parallelism_axes: - dp_size *= mesh.shape[ax] - - global_batch_size, sequence_length, _hidden = inputs.shape - topk = self.num_experts_per_tok - # The shard_map's ``in_specs=P((ep, *dp_axes), ...)`` requires the - # batch dim to be divisible by ``num_ep * dp_size``; check upfront - # here for a clearer error than the one shard_map would raise at - # trace time. - batch_divisor = num_ep * dp_size - if global_batch_size % batch_divisor != 0: - raise ValueError( - f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}" - ) - # Worst-case A2A receive count per shard: every peer can send its - # full per-expert-aligned local buffer. With ``_align_size > 0`` - # each per-expert group can be padded by up to ``_align_size - 1`` - # rows, so per shard the receive can overshoot the unpadded count - # by up to ``num_experts * (_align_size - 1)``. Skipping this - # extra slack would let ``ragged_all_to_all`` write past - # ``recv_buf`` when EP and padding are combined. - recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk - if self._align_size > 0: - recv_buffer_rows += self.num_experts * (self._align_size - 1) - - # Pack everything that crosses the shard_map boundary into a dict - # pytree. shard_map fully supports pytrees: ``in_specs`` must - # structurally match ``captured`` and we build them in lockstep - # so adding/removing an optional bias is one ``dict[name] = ...``. - # Params must be packed here (rather than passed inline by - # ``self.param`` inside the body) because Flax variable scopes - # must not be entered from inside a JAX transform's body. - captured: dict = { - "inputs": inputs, - "gate_logits": gate_logits, - "wi_0": wi_0, - "wi_1": wi_1, - "wo": wo, - } - in_specs: dict = { - "inputs": P(batch_pspec_axis, None, None), - "gate_logits": P(batch_pspec_axis, None, None), - "wi_0": P(ep_axis, None, None), - "wi_1": P(ep_axis, None, None), - "wo": P(ep_axis, None, None), - } - if expert_bias is not None: - captured["expert_bias"] = expert_bias - in_specs["expert_bias"] = P(ep_axis) - if wi_0_bias is not None: - captured["wi_0_bias"] = wi_0_bias - captured["wi_1_bias"] = wi_1_bias - captured["wo_bias"] = wo_bias - for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): - in_specs[name] = P(ep_axis, None) - - a2a_body = partial( - self._a2a_body, + wi_0_bias, + wi_1_bias, + wo_bias, + expert_bias, + num_experts=self.num_experts, + num_experts_per_tok=self.num_experts_per_tok, + activation_type=self.activation_type, + score_function=self.score_function, + use_pre_softmax=self.use_pre_softmax, + num_groups=self.num_groups, + group_topk=self.group_topk, + scaling_factor=self.scaling_factor, + aux_loss_coeff=self.aux_loss_coeff, + permutation_backend=self.permutation_backend, + align_size=self._align_size, + gate_inside_vjp=True, ep_axis=ep_axis, - num_ep=num_ep, - num_experts_local=num_experts_local, - recv_buffer_rows=recv_buffer_rows, - ) - - # ``check_rep=False`` disables shard_map's invariant that any - # output declared as ``P()`` is replicated across ``ep_axis``. - # We use ``axis_index(ep_axis)`` inside ``_a2a_body`` so the - # body is genuinely non-replicated, which would otherwise - # (correctly) fail the check. ``ragged_all_to_all`` already - # produces the right cross-shard semantics; this is the standard - # JAX escape hatch when collectives + per-shard logic coexist. - return shard_map( - a2a_body, - mesh=mesh, - in_specs=(in_specs,), - out_specs=(P(batch_pspec_axis, None, None), P()), - check_rep=False, - )(captured) - - # ------------------------------------------------------------------ - # Body of the per-shard A2A-EP forward (extracted from - # :meth:`_forward_a2a_ep` for readability). Runs *inside* the - # ``shard_map`` and is therefore in EP-manual mode: collectives over - # ``ep_axis`` are explicit, the rest of the mesh stays in auto mode. - # ------------------------------------------------------------------ - - def _a2a_body( - self, - local: dict, - *, - ep_axis: str, - num_ep: int, - num_experts_local: int, - recv_buffer_rows: int, - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - shard_id = jax.lax.axis_index(ep_axis) - - # -- Stage 1: per-shard route + global permute over all E -- - # Inside the shard_map body each input has its EP axis already - # consumed, so ``local_inputs.shape == [B/num_ep, S, H]``. - local_inputs = local["inputs"] - local_logits = local["gate_logits"] - local_b, local_s, local_h = local_inputs.shape - inputs_2d = local_inputs.reshape(-1, local_h) - logits_2d = local_logits.reshape(-1, self.num_experts) - - # The router operates over the full expert axis, so the - # EP-sharded ``expert_bias`` (in_spec ``P(ep_axis)``) must be - # all-gathered before being passed in. - if "expert_bias" in local: - full_expert_bias = jax.lax.all_gather( - local["expert_bias"], axis_name=ep_axis, tiled=True - ) - else: - full_expert_bias = None - sparse_probs, routing_map = self._route_topk(logits_2d, full_expert_bias) - - # aux_loss must see the global token batch and the global - # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( - # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable - # (the sum_t * tokens product is data-dependent across - # shards). We need a *single* collective: - # * ``all_gather`` logits over (ep + any DP axes) so both - # (a) the score-for-aux kernel and (b) a re-run of - # ``_route_topk`` see the full token batch. The re-run - # gives us the global per-expert token count directly, - # avoiding a separate ``psum``. Two consecutive global - # collectives over the same replica group at the very - # start of the program have been observed to deadlock - # under FP8 autocast on some XLA + NCCL combinations, - # so we keep this branch to one collective. - # The aux branch has no data dependency on the main routing - # path beyond what is already gathered, so XLA can overlap - # the two routings on the GPU. - if self.aux_loss_coeff > 0.0: - # ``axis_name`` accepts a tuple ⇒ a single collective - # over the cartesian product of axes; XLA may lower - # this to one multi-axis op or split it. - if len(self.data_parallelism_axes) == 0: - aux_collective_axes: Any = ep_axis - else: - aux_collective_axes = (ep_axis, *self.data_parallelism_axes) - global_logits_2d = jax.lax.all_gather( - logits_2d, axis_name=aux_collective_axes, axis=0, tiled=True - ) - # Re-run topk on the gathered logits to obtain the - # *global* routing_map post-grouping (respects - # num_groups/group_topk/expert_bias just like the local - # routing). Summing over the global token dim gives the - # exact same counts as ``psum(local_tokens_per_expert)`` - # without an extra collective. The duplicate topk - # compute is small relative to the FFNs. - _, global_routing_map = self._route_topk(global_logits_2d, full_expert_bias) - global_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) - aux_loss = self._compute_aux_loss(global_logits_2d, global_tokens_per_expert) - else: - aux_loss = None - - perm = self._global_permute(inputs_2d, sparse_probs, routing_map) - global_group_sizes = perm.group_sizes # [E] - - # -- Stage 2: gather per-expert counts across the EP axis -- - all_shards_tokens_per_expert = jax.lax.all_gather( - global_group_sizes[None, :], - axis_name=ep_axis, - axis=0, - tiled=True, - ) # [num_ep, num_experts] - - # -- Stage 3: forward ragged_all_to_all over EP -- - in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep - ) - recv_buf = jnp.zeros( - (recv_buffer_rows, local_h), - dtype=perm.sorted_inputs.dtype, + data_parallelism_axes=self.data_parallelism_axes, + input_axes=self.input_axes, + gate_kernel_axes=self.gate_kernel_axes, + wi_kernel_axes=self.wi_kernel_axes, + wo_kernel_axes=self.wo_kernel_axes, + quantizer_sets=(noop_quantizer_set, noop_quantizer_set, noop_quantizer_set), + dtype=self.dtype, ) - x_recv = jax.lax.ragged_all_to_all( - perm.sorted_inputs, - recv_buf, - in_off, - send_sz, - out_off, - recv_sz, - axis_name=ep_axis, - ) - - # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) - sorted_x, local_group_sizes, local_perm_state = local_permute_after_a2a( - x_recv, - all_shards_tokens_per_expert, - shard_id, - num_ep, - ) - - # -- Stage 5: per-expert FFN (E_local groups) -- - expert_outputs = self._expert_ffn( - sorted_x, - local_group_sizes, - n_groups=num_experts_local, - wi_0=local["wi_0"], - wi_1=local["wi_1"], - wo=local["wo"], - wi_0_bias=local.get("wi_0_bias"), - wi_1_bias=local.get("wi_1_bias"), - wo_bias=local.get("wo_bias"), - ) - - # -- Stage 6: invert local permute -- - x_send_back = local_unpermute_before_a2a(expert_outputs, local_perm_state) - - # -- Stage 7: reverse ragged_all_to_all over EP -- - in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( - all_shards_tokens_per_expert, shard_id, num_ep - ) - send_back_buf = jnp.zeros_like(perm.sorted_inputs) - y_back = jax.lax.ragged_all_to_all( - x_send_back, - send_back_buf, - in_off_r, - send_sz_r, - out_off_r, - recv_sz_r, - axis_name=ep_axis, - ) - - # -- Stage 8: invert global permute, weighted sum over top-k -- - output = self._global_combine(y_back, perm, batch_size=local_b, sequence_length=local_s) - - # ``out_specs`` must match the returned pytree structurally, - # so always emit a real scalar for aux_loss; the outer - # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. - if aux_loss is None: - aux_loss = jnp.zeros((), dtype=self.dtype) - return output, aux_loss diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py new file mode 100644 index 0000000000..f7b0880091 --- /dev/null +++ b/transformer_engine/jax/moe.py @@ -0,0 +1,2019 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Functional Mixture-of-Experts (MoE) entry point with a single fused VJP. + +This module exposes :func:`moe`, the framework-agnostic flat function that +implements an entire MoE block (gate -> top-k routing -> token dispatch -> +per-expert FFN -> token combine, plus optional expert parallelism via a +shard_map / ragged_all_to_all collective) under a *single* +``jax.custom_vjp``. It is the moral analog of +:func:`transformer_engine.jax.layernorm_mlp.layernorm_mlp` for MoE: one +custom_vjp boundary covers the whole block so future fusions (FP8 over the +EP wire, fused ``ragged_all_to_all + grouped_gemm``, gate+route+dispatch +fusion) can land without re-architecting the call site. + +Design rationale +---------------- + +The earlier MoE block (:class:`transformer_engine.jax.flax.moe._MoEBlock`) +composed many narrower custom_vjps -- one per :func:`grouped_dense`, one +per :func:`token_dispatch`, etc. Every nested custom_vjp is a place where +a quantized :class:`ScaledTensor` cannot survive (JAX requires custom_vjp +inputs / outputs to be plain ``jnp.ndarray`` ish pytrees). To enable +end-to-end FP8 flow -- in particular FP8 carried over the EP +ragged_all_to_all -- the dispatch's quantize, the a2a, the per-expert +FFN, the inverse a2a, and the combine all have to live inside the same +VJP. This file collapses them into one. + +Implementation conventions +-------------------------- + +* No nested ``custom_vjp``. Every primitive's ``_fwd`` and ``_bwd`` is + called directly (e.g. :func:`tex.fused_topk_with_score_function_fwd` / + ``_bwd``, :func:`unpermute_with_mask_map`, + :func:`unpermute_bwd_with_merging_probs`, + :func:`sort_chunks_by_map(is_forward=False)`, + forward + reverse :func:`jax.lax.ragged_all_to_all`) so the outer + ``_moe_bwd_rule`` controls the bwd graph end-to-end without invoking + ``jax.vjp`` for re-linearization. +* The fwd/bwd context (``ctx``) is a plain ``dict`` whose keys depend on + the static configuration (permutation backend, EP active or not, + presence of biases, aux loss enabled). The ``_moe_fwd_rule`` builds a + matching ``ctx_specs`` dict in lockstep when opening the EP shard_map + so ``out_specs`` structurally matches the body's return. +* :func:`_dispatch` is the helper that wraps + ``permute -> a2a -> local_permute`` (forward); :func:`_combine` is its + inverse. Their ``_bwd`` siblings drive the inverse collectives in the + bwd rule. None of these helpers form a custom_vjp boundary. +""" + +from enum import Enum +from functools import partial +from typing import Any, Callable, NewType, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P + +from . import cpp_extensions as tex +from .permutation import ( + PureJaxPermState, + compute_ragged_all_to_all_params, + compute_reverse_ragged_all_to_all_params, + pure_jax_token_combine, + pure_jax_token_dispatch, + routing_map_to_selected_experts, +) +from .quantize import ( + QuantizerSet, + ScaledTensor, + TensorUsage, + noop_quantizer_set, + with_sharding_constraint_by_logical_axes, +) +from .router import ScoreFunction, _validate_score_function +from .sharding import _get_mesh +from .triton_extensions.permutation import ( + make_chunk_sort_map, + make_row_id_map, + permute_with_mask_map, + permute_with_mask_map_and_pad, + sort_chunks_by_map, + unpermute_bwd_with_merging_probs, + unpermute_bwd_with_merging_probs_and_unpad, + unpermute_with_mask_map, + unpermute_with_mask_map_and_unpad, +) +from .flax.module import _convert_to_activation_function + +PRNGKey = Any +Shape = Tuple[int, ...] +DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) + + +__all__ = ["moe", "PermutationBackend"] + + +# ============================================================================= +# Enums +# ============================================================================= + + +class PermutationBackend(Enum): + """Token-dispatch / combine backend used by :func:`moe`. + + * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain XLA; + typically faster than ``TRITON`` in current testing because XLA can + fuse the ops with surrounding work. + * ``TRITON``: TE's fused Triton kernels. + """ + + PURE_JAX = "pure_jax" + TRITON = "triton" + + +# ============================================================================= +# ctx / dispatch-state key conventions +# ============================================================================= +# +# Both ``ctx`` (carried fwd_rule -> bwd_rule) and the dispatch state +# (carried _dispatch -> _combine / _dispatch_bwd / _combine_bwd) are plain +# python dicts. Using a dict (rather than a flax_struct.dataclass) lets us +# vary the populated keys with the static config without breaking +# ``shard_map``'s ``out_specs`` structural match: the spec dict and the +# value dict are built with the SAME keys via :func:`_build_ctx_specs`. +# +# Below is the key glossary so the rest of the file reads cleanly. +# +# DispatchState (dict): values are jnp.ndarray unless noted +# Always present: +# "group_sizes" [n_groups] per-expert token counts +# (n_groups = E for no-EP, +# E_local for EP) +# "ep_active" bool (carried as a Python flag, +# not in the dict; passed +# alongside) +# PURE_JAX backend: +# "sorted_indices" [num_real + padding] argsort indices +# "routing_weights" [num_tokens, topk] per-token-per-expert weights +# TRITON backend: +# "row_id_map" [num_tokens, 2*E + 1] +# "pad_offsets" [E] or None +# "merging_probs" [num_tokens, E] +# EP-only: +# "all_shards_tokens_per_expert" [num_ep, E] +# "local_perm_row_id_map" [recv_buffer_rows] +# "local_perm_inv_row_id_map" [recv_buffer_rows] +# +# NOTE: per-shard compile-time-constant shapes (num_real_tokens, +# padding_size, pre/post_a2a_buffer_shape) are NOT stored in this +# dict; they are recomputed in _body_fwd/_body_bwd via +# _compute_static_shape_info and passed as Python ints / int tuples to +# the dispatch/combine helpers. Storing them in the dict would cause +# JAX's pytree-flatten across the shard_map boundary to coerce them +# into JitTracer 0-d arrays, which breaks Python-level control flow +# (e.g. ``if padding > 0``) and ``jnp.zeros(shape)`` in the bwd. +# +# MoECtx (dict): values are jnp.ndarray / ScaledTensor unless noted +# Always present: +# "x" [B, S, H] +# "gate_kernel" [H, E] (only meaningful when gate_inside_vjp=True) +# "logits_2d" [T, E] T = local-batch * S +# "saved_scores" [T, E] from fused_topk fwd primitive +# "routing_map" [T, E] +# "dispatch" DispatchState dict +# "casted_sorted_x_lhs_trans" ScaledTensor or ndarray +# "casted_wi_0_rhs_trans" ScaledTensor or ndarray +# "casted_wi_1_rhs_trans" ScaledTensor or ndarray +# "layer_w0" ndarray (pre-activation) +# "layer_w1" ndarray +# "casted_intermediate_lhs_trans" ScaledTensor or ndarray +# "casted_wo_rhs_trans" ScaledTensor or ndarray +# "expert_outputs" ndarray (FFN output, needed for TRITON +# combine_bwd's +# unpermute_bwd_with_merging_probs) +# "local_group_sizes" [n_groups] -- mirrors dispatch.group_sizes +# but kept here for FFN bwd +# convenience +# Optional: +# "expert_bias" [E] only when expert_bias was provided +# "wi_0_bias_shape" tuple -- only when bias is used (carried +# non-diff via static side; here +# only if needed) +# "aux_const_buf" ndarray -- only when aux_loss_coeff > 0 +# "aux_tokens_per_expert" [E] -- ditto +# "aux_logits_for_score" [global_T, E] -- ditto, may be the +# gathered global logits +# or the local logits + + +# ============================================================================= +# Static shape helper +# ============================================================================= +# +# A set of per-shard shape/size values that the dispatch and combine +# helpers (both fwd and bwd) need. They're all derivable from existing +# static args, so we recompute them in both ``_body_fwd`` and +# ``_body_bwd`` and pass them as Python ints / int-tuples through +# explicit kwargs. We MUST NOT stash them inside the dynamic +# ``state`` / ``ctx`` dict: when the dict crosses the EP shard_map's +# out_specs/in_specs boundary, JAX's pytree-flatten coerces any Python +# int leaves into traced 0-d arrays, which then breaks dependent Python +# code in the bwd (e.g. ``if padding > 0`` and ``jnp.zeros(shape)``). + + +def _compute_static_shape_info( + *, + batch_size: int, + sequence_length: int, + hidden: int, + num_experts: int, + num_experts_per_tok: int, + align_size: int, + ep_active: bool, + num_ep: int = 1, + fsdp_sizes: Tuple[int, ...] = (), + recv_buffer_rows: int = 0, + batch_is_per_shard: bool = True, +) -> dict: + """Compute per-shard compile-time-constant shape info used by both + dispatch/combine fwd and dispatch/combine bwd. + + Returned dict has Python ints / int tuples (NOT jnp arrays) so the + caller can pass them as ordinary static keyword args. See the + module-level comment above for why this matters. + + ``batch_is_per_shard`` controls whether ``batch_size`` is already + sharded (True -- e.g. when this is called from inside a shard_map + body, where ``x.shape[0]`` reports the per-shard batch size) or + global (False -- e.g. when computing from x.shape outside the + shard_map body). + + Keys + ---- + num_real_tokens : int + Per-shard count of real (non-padding) permuted tokens, i.e. + ``per_shard_num_tokens * num_experts_per_tok``. + padding_size : int + Per-shard number of alignment-padding tokens appended to the + sort buffer (``num_experts * (align_size - 1)`` when + ``align_size > 0``, else ``0``). Matches the convention used + by ``pure_jax_token_dispatch``. + pre_a2a_buffer_shape : tuple[int, int] + ``(num_real_tokens + padding_size, hidden)`` -- the per-shard + shape of the sorted-inputs buffer that is sent over the EP + ragged_all_to_all in the fwd direction. + post_a2a_buffer_shape : Optional[tuple[int, int]] + ``(recv_buffer_rows, hidden)`` when EP is active, ``None`` + otherwise. + """ + import math + + if ep_active and not batch_is_per_shard: + dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 + per_shard_batch = batch_size // (num_ep * dp_size) + else: + per_shard_batch = batch_size + per_shard_num_tokens = per_shard_batch * sequence_length + num_real_tokens = per_shard_num_tokens * num_experts_per_tok + padding_size = num_experts * (align_size - 1) if align_size > 0 else 0 + pre_a2a_buffer_shape = (num_real_tokens + padding_size, hidden) + post_a2a_buffer_shape = (recv_buffer_rows, hidden) if ep_active else None + return dict( + num_real_tokens=num_real_tokens, + padding_size=padding_size, + pre_a2a_buffer_shape=pre_a2a_buffer_shape, + post_a2a_buffer_shape=post_a2a_buffer_shape, + ) + + +# ============================================================================= +# Dispatch / combine helpers (no VJP boundary -- pure Python) +# ============================================================================= + + +def _dispatch( + inputs_2d: jnp.ndarray, + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + *, + backend: PermutationBackend, + num_experts: int, + num_experts_per_tok: int, + align_size: int, + # EP-only: + ep_active: bool, + ep_axis: Optional[str], + num_ep: int, + recv_buffer_rows: int, + shard_id: Optional[jnp.ndarray] = None, +) -> Tuple[jnp.ndarray, dict]: + """``permute -> (a2a -> local_permute) iff ep_active``. + + Returns ``(sorted_x, state)`` where ``sorted_x`` has shape + ``[buffer_rows, hidden]`` -- ``E`` groups (no-EP) or ``E_local`` groups + (EP) -- and ``state`` is a dict carrying everything :func:`_combine` + and the bwd helpers need to reverse the operation. + + Bypasses the ``custom_vjp``-wrapped public ``token_dispatch`` / + ``pure_jax_token_dispatch`` wrappers (well, mostly: PURE_JAX still + composes through ``pure_jax_token_dispatch`` because that helper has + no ``custom_vjp`` itself -- only its inner ``_sort_activations`` does, + which is fine since we never auto-diff through it from this layer). + For TRITON we call the underlying ``permute_with_mask_map`` / + ``permute_with_mask_map_and_pad`` primitives directly. + """ + num_tokens, hidden = inputs_2d.shape + topk = num_experts_per_tok + state: dict = {} + + # ------------------------------------------------------------------ + # Step 1: global permute (every shard routes its own tokens over the + # full expert axis). Backend-specific. + # ------------------------------------------------------------------ + if backend is PermutationBackend.PURE_JAX: + selected_experts, routing_weights = routing_map_to_selected_experts( + sparse_probs, routing_map, topk + ) + sorted_inputs, perm_state, group_sizes = pure_jax_token_dispatch( + inputs_2d, + selected_experts, + num_experts=num_experts, + num_experts_per_tok=topk, + align_size=align_size, + ) + # NOTE: ``perm_state.num_real_tokens`` and ``perm_state.padding_size`` + # are compile-time Python ints; intentionally NOT stored in + # ``state`` (would be coerced to JitTracer 0-d arrays under + # the EP shard_map's pytree flatten). Recompute via + # ``_compute_static_shape_info`` in the bwd / EP-combine + # call sites that need them. + state["sorted_indices"] = perm_state.sorted_indices + state["routing_weights"] = routing_weights + else: + # TRITON backend -- inline the underlying primitive sequence + # (mirrors ``_token_dispatch_fwd_rule`` but exposes the residuals + # to our ctx instead of saving them inside another custom_vjp). + num_out_tokens = num_tokens * topk + row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) + tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32) + if align_size > 0: + target_tokens_per_expert = ( + jnp.ceil(tokens_per_expert / align_size) * align_size + ).astype(jnp.int32) + pad_lengths = target_tokens_per_expert - tokens_per_expert + cum_pad = jnp.cumsum(pad_lengths) + pad_offsets = jnp.concatenate([jnp.array([0], dtype=cum_pad.dtype), cum_pad[:-1]]) + worst_case_out_tokens = ( + (num_out_tokens + num_experts * (align_size - 1)) // align_size + ) * align_size + sorted_inputs, _ = permute_with_mask_map_and_pad( + inputs_2d, + row_id_map, + None, + pad_offsets, + num_tokens, + num_experts, + worst_case_out_tokens, + hidden, + align_size=align_size, + ) + group_sizes = target_tokens_per_expert + else: + sorted_inputs, _ = permute_with_mask_map( + inputs_2d, + row_id_map, + None, + num_tokens, + num_experts, + num_out_tokens, + hidden, + ) + pad_offsets = None + group_sizes = tokens_per_expert + state["row_id_map"] = row_id_map + state["pad_offsets"] = pad_offsets + state["merging_probs"] = sparse_probs + + state["group_sizes"] = group_sizes + + if not ep_active: + return sorted_inputs, state + + # ------------------------------------------------------------------ + # Step 2 (EP only): all_gather per-expert counts so every shard knows + # the [num_ep, num_experts] token-count matrix. + # ------------------------------------------------------------------ + all_shards_tokens_per_expert = jax.lax.all_gather( + group_sizes[None, :], + axis_name=ep_axis, + axis=0, + tiled=True, + ) + + # ------------------------------------------------------------------ + # Step 3 (EP only): forward ragged_all_to_all over the EP axis. + # ------------------------------------------------------------------ + in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + pre_a2a_buffer_shape = sorted_inputs.shape + post_a2a_buffer_shape = (recv_buffer_rows, hidden) + recv_buf = jnp.zeros(post_a2a_buffer_shape, dtype=sorted_inputs.dtype) + x_recv = jax.lax.ragged_all_to_all( + sorted_inputs, recv_buf, in_off, send_sz, out_off, recv_sz, axis_name=ep_axis + ) + + # ------------------------------------------------------------------ + # Step 4 (EP only): local permute -- (source_shard, expert) -> + # (expert, shard). Inlined ``local_permute_after_a2a`` so we control + # both the row_id_map and its inverse for the bwd. + # ------------------------------------------------------------------ + num_experts_local = num_experts // num_ep + local_expert_start = shard_id * num_experts_local + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_ep, num_experts_local), + ) + split_sizes = local_expert_columns.reshape(-1) # source-major + indices_matrix = jnp.arange(num_ep * num_experts_local, dtype=jnp.int32).reshape( + num_ep, num_experts_local + ) + sorted_chunk_indices = indices_matrix.T.reshape(-1) # source-major -> expert-major + num_chunks = num_ep * num_experts_local + # Build a SINGLE row_id_map. ``is_forward=True`` permutes + # source-major -> expert-major; ``is_forward=False`` is the exact + # inverse (this is exactly what ``_sort_chunks_by_index_bwd_rule`` + # uses on the saved residual). _MoEBlock builds two row_id_maps + # only because it calls ``sort_chunks_by_index`` twice -- once in + # ``local_permute_after_a2a`` and again in ``local_unpermute_before_a2a``; + # each of those wrappers calls ``make_chunk_sort_map`` internally. + # Here we share one map across (fwd permute, fwd inverse-permute, + # bwd permute, bwd inverse-permute). + local_perm_row_id_map = make_chunk_sort_map( + split_sizes, sorted_chunk_indices, recv_buffer_rows, num_chunks + ) + sorted_x, _ = sort_chunks_by_map( + x_recv, local_perm_row_id_map, None, recv_buffer_rows, hidden, is_forward=True + ) + local_group_sizes = jnp.sum(local_expert_columns, axis=0) + + state["all_shards_tokens_per_expert"] = all_shards_tokens_per_expert + state["local_perm_row_id_map"] = local_perm_row_id_map + # NOTE: pre_a2a_buffer_shape and post_a2a_buffer_shape are compile- + # time int tuples; intentionally NOT stored in ``state`` (would be + # coerced to JitTracer 0-d arrays under the EP shard_map's pytree + # flatten). Recompute via ``_compute_static_shape_info`` in the + # bwd call sites that need them. + # For EP, we override ``group_sizes`` to be the per-local-expert + # counts (the FFN runs over E_local groups, not E). The original + # global ``group_sizes`` lives inside ``all_shards_tokens_per_expert`` + # if anyone needs it for diagnostics. + state["group_sizes"] = local_group_sizes + + return sorted_x, state + + +def _combine( + expert_outputs: jnp.ndarray, + state: dict, + *, + backend: PermutationBackend, + ep_active: bool, + batch_size: int, + sequence_length: int, + dtype: jnp.dtype, + num_experts_per_tok: int, + # Per-shard compile-time-constant shape info (Python ints / int tuples). + # Computed by _compute_static_shape_info in the caller, passed here + # rather than stored in ``state`` to survive shard_map crossings. + num_real_tokens: int, + padding_size: int, + pre_a2a_buffer_shape: Tuple[int, int], + # EP-only: + ep_axis: Optional[str], + shard_id: Optional[jnp.ndarray] = None, + num_ep: int = 1, +) -> jnp.ndarray: + """Inverse of :func:`_dispatch`. Returns ``[B, S, H]``.""" + if ep_active: + # Step 1 (EP): inverse local permute. Reuse the SAME row_id_map + # built in _dispatch by setting is_forward=False (this is the + # exact inverse, identical to what + # ``_sort_chunks_by_index_bwd_rule`` does with the saved residual). + recv_buffer_rows, hidden = expert_outputs.shape + x_send_back, _ = sort_chunks_by_map( + expert_outputs, + state["local_perm_row_id_map"], + None, + recv_buffer_rows, + hidden, + is_forward=False, + ) + # Step 2 (EP): reverse ragged_all_to_all. + in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( + state["all_shards_tokens_per_expert"], shard_id, num_ep + ) + send_back_buf = jnp.zeros(pre_a2a_buffer_shape, dtype=expert_outputs.dtype) + expert_outputs = jax.lax.ragged_all_to_all( + x_send_back, + send_back_buf, + in_off_r, + send_sz_r, + out_off_r, + recv_sz_r, + axis_name=ep_axis, + ) + + # Step 3: global combine. + if backend is PermutationBackend.PURE_JAX: + # Reuse the reference pure-jax implementation; it has no + # custom_vjp on its outer surface so we can call it freely. + perm_state = PureJaxPermState( + sorted_indices=state["sorted_indices"], + num_real_tokens=num_real_tokens, + padding_size=padding_size, + ) + return pure_jax_token_combine( + expert_outputs, + perm_state, + state["routing_weights"], + num_experts_per_tok=num_experts_per_tok, + batch_size=batch_size, + sequence_length=sequence_length, + ) + # TRITON + num_tokens = state["row_id_map"].shape[0] + num_experts = (state["row_id_map"].shape[1] - 1) // 2 + hidden = expert_outputs.shape[-1] + if state["pad_offsets"] is not None: + out_2d, _ = unpermute_with_mask_map_and_unpad( + expert_outputs, + state["row_id_map"], + state["merging_probs"], + None, + state["pad_offsets"], + num_tokens, + num_experts, + hidden, + ) + else: + out_2d, _ = unpermute_with_mask_map( + expert_outputs, + state["row_id_map"], + state["merging_probs"], + None, + num_tokens, + num_experts, + hidden, + ) + return out_2d.reshape(batch_size, sequence_length, hidden).astype(dtype) + + +def _combine_bwd( + d_output: jnp.ndarray, + state: dict, + expert_outputs: jnp.ndarray, + *, + backend: PermutationBackend, + ep_active: bool, + batch_size: int, + sequence_length: int, + dtype: jnp.dtype, + num_experts: int, + num_experts_per_tok: int, + # Per-shard compile-time-constant shape info (Python ints / int tuples). + # See ``_compute_static_shape_info`` and the note in ``_dispatch`` + # for why these are kwargs rather than state-dict entries. + num_real_tokens: int, + padding_size: int, + post_a2a_buffer_shape: Optional[Tuple[int, int]], + # EP-only: + ep_axis: Optional[str], + shard_id: Optional[jnp.ndarray] = None, + num_ep: int = 1, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Inverse of :func:`_combine` on the cotangent. + + Returns ``(d_expert_outputs, d_routing_weights_or_merging_probs)``. + + ``expert_outputs`` is the *forward* output of the FFN (same value the + fwd handed to :func:`_combine`). It's required by the TRITON + combine_bwd kernel; for PURE_JAX we don't need it but accept it for + a symmetric signature. + """ + # Step 3 inverse: global combine bwd. + d_output_2d = d_output.reshape(-1, d_output.shape[-1]) + if backend is PermutationBackend.PURE_JAX: + # The pure-jax combine is: + # unsort = _sort_activations(expert_outputs, argsort(sorted_indices)) + # if pad: unsort = unsort[:num_real] + # reshape -> einsum BKE,BK -> BE -> reshape to BSE + # Hand-derive the bwd in plain JAX (no custom_vjp involved): + unsort_indices = jnp.argsort(state["sorted_indices"]) + topk = num_experts_per_tok + num_real = num_real_tokens + padding = padding_size + # Recover the unsorted intermediate that the fwd produced (we + # need it for the d_routing_weights pullback). Apply the same + # gather the fwd did. + unsort_intermediate = expert_outputs[unsort_indices] + if padding > 0: + unsort_intermediate = unsort_intermediate[:num_real] + # Bwd of einsum/reshape: + # output[B, E] = sum_K intermediate[B, K, E] * weights[B, K] + # d_intermediate[B, K, E] = d_output[B, E] * weights[B, K] + # d_weights[B, K] = sum_E d_output[B, E] * intermediate[B, K, E] + rw = state["routing_weights"].reshape(-1, topk) + intermediate_3d = unsort_intermediate.reshape(rw.shape[0], topk, -1) + rw_cast = rw.astype(intermediate_3d.dtype) + d_intermediate_3d = jnp.einsum("BE,BK -> BKE", d_output_2d, rw_cast) + d_routing_weights = jnp.einsum("BE,BKE -> BK", d_output_2d, intermediate_3d).astype( + state["routing_weights"].dtype + ) + d_routing_weights = d_routing_weights.reshape(state["routing_weights"].shape) + d_unsort_intermediate = d_intermediate_3d.reshape(num_real, -1) + # Pad back with zeros if the fwd stripped padding. + if padding > 0: + d_unsort_intermediate = jnp.concatenate( + [ + d_unsort_intermediate, + jnp.zeros( + (padding, d_unsort_intermediate.shape[-1]), + dtype=d_unsort_intermediate.dtype, + ), + ], + axis=0, + ) + # Bwd of the gather is gather-by-original-indices: + # sorted = unsort[argsort(sorted_indices)] + # d_sorted = scatter d_unsort via argsort(sorted_indices) + # = d_unsort[sorted_indices] (gather by original sorted_indices, + # which is the inverse of argsort(sorted_indices)). + d_expert_outputs_global = d_unsort_intermediate[state["sorted_indices"]] + else: + # TRITON combine bwd: requires fwd_input (expert_outputs). + num_tokens = state["row_id_map"].shape[0] + n_experts = (state["row_id_map"].shape[1] - 1) // 2 + hidden = d_output_2d.shape[-1] + num_out_tokens = expert_outputs.shape[0] + if state["pad_offsets"] is not None: + d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs_and_unpad( + d_output_2d, + state["row_id_map"], + expert_outputs, + state["merging_probs"], + state["pad_offsets"], + num_tokens, + n_experts, + num_out_tokens, + hidden, + ) + # The kernel only writes positions tokens map to; padded + # positions may contain NaN. Replace with zeros (matches + # ``_token_combine_bwd_rule``). + d_expert_outputs_global = jnp.where( + jnp.isnan(d_expert_outputs_global), 0.0, d_expert_outputs_global + ) + else: + d_expert_outputs_global, d_merging_probs = unpermute_bwd_with_merging_probs( + d_output_2d, + state["row_id_map"], + expert_outputs, + state["merging_probs"], + num_tokens, + n_experts, + num_out_tokens, + hidden, + ) + d_routing_weights = d_merging_probs + + if not ep_active: + return d_expert_outputs_global, d_routing_weights + + # Step 2 (EP) inverse: bwd of reverse ragged_all_to_all is a forward + # ragged_all_to_all using the SAME forward parameters (sender / + # receiver roles swap from the reverse direction back to forward). + in_off_f, send_sz_f, out_off_f, recv_sz_f = compute_ragged_all_to_all_params( + state["all_shards_tokens_per_expert"], shard_id, num_ep + ) + recv_buf_for_bwd = jnp.zeros(post_a2a_buffer_shape, dtype=d_expert_outputs_global.dtype) + d_x_send_back = jax.lax.ragged_all_to_all( + d_expert_outputs_global, + recv_buf_for_bwd, + in_off_f, + send_sz_f, + out_off_f, + recv_sz_f, + axis_name=ep_axis, + ) + # Step 1 (EP) inverse: combine fwd applied is_forward=False; the + # bwd is is_forward=True with the SAME row_id_map. + recv_buffer_rows, hidden = d_x_send_back.shape + d_expert_outputs, _ = sort_chunks_by_map( + d_x_send_back, + state["local_perm_row_id_map"], + None, + recv_buffer_rows, + hidden, + is_forward=True, + ) + return d_expert_outputs, d_routing_weights + + +def _dispatch_bwd( + d_sorted_x: jnp.ndarray, + state: dict, + inputs_2d_shape: Tuple[int, ...], + *, + backend: PermutationBackend, + ep_active: bool, + num_experts: int, + num_experts_per_tok: int, + # Per-shard compile-time-constant shape info (Python ints / int tuples). + # See ``_compute_static_shape_info`` and the note in ``_dispatch`` + # for why these are kwargs rather than state-dict entries. + num_real_tokens: int, + padding_size: int, + pre_a2a_buffer_shape: Tuple[int, int], + # EP-only: + ep_axis: Optional[str], + shard_id: Optional[jnp.ndarray] = None, + num_ep: int = 1, +) -> jnp.ndarray: + """Inverse of :func:`_dispatch` on the cotangent. Returns ``d_inputs_2d``. + + The probs path through dispatch is always discarded (PURE_JAX never + threads probs through dispatch; TRITON technically does but the + caller drops ``permuted_probs``, so its cotangent is structurally + zero). The probs gradient instead flows back through + :func:`_combine_bwd`. + """ + if ep_active: + # Step 4 inverse: dispatch fwd applied is_forward=True; bwd is + # is_forward=False with the SAME row_id_map. + recv_buffer_rows, hidden = d_sorted_x.shape + d_x_recv, _ = sort_chunks_by_map( + d_sorted_x, + state["local_perm_row_id_map"], + None, + recv_buffer_rows, + hidden, + is_forward=False, + ) + # Step 3 inverse: bwd of forward ragged_a2a is the reverse-direction + # ragged_a2a using the SAME params with sender/receiver swapped. + in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( + state["all_shards_tokens_per_expert"], shard_id, num_ep + ) + recv_buf_pre = jnp.zeros(pre_a2a_buffer_shape, dtype=d_x_recv.dtype) + d_sorted_x = jax.lax.ragged_all_to_all( + d_x_recv, + recv_buf_pre, + in_off_r, + send_sz_r, + out_off_r, + recv_sz_r, + axis_name=ep_axis, + ) + + # Step 1 inverse: global permute bwd. + if backend is PermutationBackend.PURE_JAX: + # Fwd was: replicated = repeat(inputs_2d, topk, axis=0) + # padded = pad(replicated, (0, padding_size)) + # sorted = padded[sorted_indices] + # Bwd: d_padded = scatter via sorted_indices + # = d_sorted[argsort(sorted_indices)] + # d_replicated = d_padded[:num_real] + # d_inputs_2d = d_replicated.reshape(T, topk, H).sum(axis=1) + sorted_indices = state["sorted_indices"] + num_real = num_real_tokens + padding = padding_size + topk = num_experts_per_tok + unsort_indices = jnp.argsort(sorted_indices) + d_padded = d_sorted_x[unsort_indices] + if padding > 0: + d_replicated = d_padded[:num_real] + else: + d_replicated = d_padded + num_tokens = inputs_2d_shape[0] + hidden = inputs_2d_shape[-1] + d_inputs_2d = d_replicated.reshape(num_tokens, topk, hidden).sum(axis=1) + return d_inputs_2d + + # TRITON: bwd is unpermute_with_mask_map[_and_unpad]. + num_tokens = inputs_2d_shape[0] + hidden = inputs_2d_shape[-1] + if state["pad_offsets"] is not None: + d_inputs_2d, _ = unpermute_with_mask_map_and_unpad( + d_sorted_x, + state["row_id_map"], + None, + None, + state["pad_offsets"], + num_tokens, + num_experts, + hidden, + ) + else: + d_inputs_2d, _ = unpermute_with_mask_map( + d_sorted_x, + state["row_id_map"], + None, + None, + num_tokens, + num_experts, + hidden, + ) + return d_inputs_2d + + +# ============================================================================= +# Per-shard body +# ============================================================================= + + +def _body_fwd( + captured: dict, + *, + # Statics + num_experts: int, + num_experts_per_tok: int, + activation_type: str, + score_function: ScoreFunction, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: float, + aux_loss_coeff: float, + permutation_backend: PermutationBackend, + align_size: int, + gate_inside_vjp: bool, + quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], + dtype: jnp.dtype, + # EP-only statics + ep_active: bool, + ep_axis: Optional[str], + data_parallelism_axes: Tuple[str, ...], + fsdp_sizes: Tuple[int, ...], + num_ep: int, + num_experts_local: int, + recv_buffer_rows: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: + """Per-shard forward body. Returns ``(output, aux_loss, ctx_dict)``. + + ``aux_loss`` is always materialized (zeros scalar when disabled) so + the ``shard_map``'s ``out_specs`` has a static structure. + """ + if not gate_inside_vjp: + raise NotImplementedError( + "gate_inside_vjp=False is deferred to a follow-up PR; for now" + " the gate GEMM lives inside the MoE VJP." + ) + + x = captured["inputs"] + gate_kernel = captured["gate_kernel"] + wi_0 = captured["wi_0"] + wi_1 = captured["wi_1"] + wo = captured["wo"] + wi_0_bias = captured.get("wi_0_bias") + wi_1_bias = captured.get("wi_1_bias") + wo_bias = captured.get("wo_bias") + expert_bias = captured.get("expert_bias") + + batch_size, sequence_length, hidden = x.shape + + # ---------------- Stage 1: gate ---------------- + gate_kernel_cast = gate_kernel.astype(x.dtype) + gate_logits = jnp.einsum("bsh,he->bse", x, gate_kernel_cast) + logits_2d = gate_logits.reshape(-1, num_experts) + inputs_2d = x.reshape(-1, hidden) + + # ---------------- Stage 2: routing ---------------- + # Under EP, expert_bias is sharded P(ep_axis); the router needs the + # full E-dim view, so all_gather it. + if ep_active and expert_bias is not None: + full_expert_bias = jax.lax.all_gather(expert_bias, axis_name=ep_axis, tiled=True) + else: + full_expert_bias = expert_bias + # Pass an empty array sentinel when expert_bias is unused (the + # underlying primitive expects a real ndarray, not None). + eb_arg = ( + full_expert_bias if full_expert_bias is not None else jnp.zeros((0,), dtype=jnp.float32) + ) + sparse_probs, routing_map, saved_scores = tex.fused_topk_with_score_function_fwd( + logits_2d, + topk=num_experts_per_tok, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, + ) + sparse_probs = sparse_probs.astype(dtype) + + # ---------------- Stage 2b: aux loss ---------------- + if aux_loss_coeff > 0.0: + if ep_active: + collective_axes: Any = ( + ep_axis if not data_parallelism_axes else (ep_axis, *data_parallelism_axes) + ) + global_logits_2d = jax.lax.all_gather( + logits_2d, axis_name=collective_axes, axis=0, tiled=True + ) + _, global_routing_map, _ = tex.fused_topk_with_score_function_fwd( + global_logits_2d, + topk=num_experts_per_tok, + use_pre_softmax=use_pre_softmax, + num_groups=-1 if num_groups is None else num_groups, + group_topk=-1 if group_topk is None else group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=eb_arg, + compute_aux_scores=False, + ) + aux_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) + aux_logits_for_score = global_logits_2d + else: + aux_tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0) + aux_logits_for_score = logits_2d + # Aux-side scores: clean per-expert scores (no grouped routing, + # no bias). compute_aux_scores=True takes a separate path that + # ignores the grouping knobs. + aux_probs, _aux_routing_map, aux_saved_scores = tex.fused_topk_with_score_function_fwd( + aux_logits_for_score.astype(jnp.float32), + topk=num_experts_per_tok, + use_pre_softmax=False, + num_groups=-1, + group_topk=-1, + scaling_factor=1.0, + score_function=score_function, + expert_bias=jnp.zeros((0,), dtype=jnp.float32), + compute_aux_scores=True, + ) + aux_loss, aux_const_buf = tex.fused_moe_aux_loss_fwd( + aux_probs.astype(jnp.float32), + aux_tokens_per_expert.astype(jnp.int32), + topk=num_experts_per_tok, + coeff=aux_loss_coeff, + ) + else: + aux_loss = jnp.zeros((), dtype=dtype) + aux_const_buf = None + aux_tokens_per_expert = None + aux_logits_for_score = None + aux_saved_scores = None + + # ---------------- Stage 3: dispatch ---------------- + shard_id = jax.lax.axis_index(ep_axis) if ep_active else None + sorted_x, dispatch_state = _dispatch( + inputs_2d, + sparse_probs, + routing_map, + backend=permutation_backend, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + align_size=align_size, + ep_active=ep_active, + ep_axis=ep_axis, + num_ep=num_ep, + recv_buffer_rows=recv_buffer_rows, + shard_id=shard_id, + ) + local_group_sizes = dispatch_state["group_sizes"] + + # ---------------- Stage 4: per-expert FFN (inlined) ---------------- + q_set_w0, q_set_w1, q_set_wo = quantizer_sets + if q_set_w0 == noop_quantizer_set: + wi_0 = wi_0.astype(sorted_x.dtype) + if q_set_w1 == noop_quantizer_set: + wi_1 = wi_1.astype(sorted_x.dtype) + if q_set_wo == noop_quantizer_set: + wo = wo.astype(sorted_x.dtype) + + # GEMM 1: layer_w0 = sorted_x @ wi_0 + casted_sorted_x_w0 = tex.grouped_quantize( + sorted_x, q_set_w0.x, local_group_sizes, flatten_axis=-1 + ) + casted_wi_0 = tex.grouped_quantize(wi_0, q_set_w0.kernel, flatten_axis=-1) + layer_w0 = tex.grouped_gemm( + casted_sorted_x_w0.get_tensor(usage=TensorUsage.LHS), + casted_wi_0.get_tensor(usage=TensorUsage.RHS), + contracting_dims=((1,), (1,)), + bias=wi_0_bias, + ) + casted_sorted_x_lhs_trans = casted_sorted_x_w0.get_tensor(usage=TensorUsage.LHS_TRANS) + casted_wi_0_rhs_trans = casted_wi_0.get_tensor(usage=TensorUsage.RHS_TRANS) + if isinstance(casted_sorted_x_lhs_trans, ScaledTensor): + casted_sorted_x_lhs_trans = casted_sorted_x_lhs_trans.checkpoint(q_set_w0.x) + if isinstance(casted_wi_0_rhs_trans, ScaledTensor): + casted_wi_0_rhs_trans = casted_wi_0_rhs_trans.checkpoint(q_set_w0.kernel) + + # GEMM 2: layer_w1 = sorted_x @ wi_1 + casted_sorted_x_w1 = tex.grouped_quantize( + sorted_x, q_set_w1.x, local_group_sizes, flatten_axis=-1 + ) + casted_wi_1 = tex.grouped_quantize(wi_1, q_set_w1.kernel, flatten_axis=-1) + layer_w1 = tex.grouped_gemm( + casted_sorted_x_w1.get_tensor(usage=TensorUsage.LHS), + casted_wi_1.get_tensor(usage=TensorUsage.RHS), + contracting_dims=((1,), (1,)), + bias=wi_1_bias, + ) + casted_wi_1_rhs_trans = casted_wi_1.get_tensor(usage=TensorUsage.RHS_TRANS) + if isinstance(casted_wi_1_rhs_trans, ScaledTensor): + casted_wi_1_rhs_trans = casted_wi_1_rhs_trans.checkpoint(q_set_w1.kernel) + + # Activation: intermediate = act(layer_w0) * layer_w1 + act_fn = _convert_to_activation_function(activation_type) + intermediate = act_fn(layer_w0) * layer_w1 + + # GEMM 3: expert_outputs = intermediate @ wo + casted_intermediate = tex.grouped_quantize( + intermediate, q_set_wo.x, local_group_sizes, flatten_axis=-1 + ) + casted_wo = tex.grouped_quantize(wo, q_set_wo.kernel, flatten_axis=-1) + expert_outputs = tex.grouped_gemm( + casted_intermediate.get_tensor(usage=TensorUsage.LHS), + casted_wo.get_tensor(usage=TensorUsage.RHS), + contracting_dims=((1,), (1,)), + bias=wo_bias, + ) + casted_intermediate_lhs_trans = casted_intermediate.get_tensor(usage=TensorUsage.LHS_TRANS) + casted_wo_rhs_trans = casted_wo.get_tensor(usage=TensorUsage.RHS_TRANS) + if isinstance(casted_intermediate_lhs_trans, ScaledTensor): + casted_intermediate_lhs_trans = casted_intermediate_lhs_trans.checkpoint(q_set_wo.x) + if isinstance(casted_wo_rhs_trans, ScaledTensor): + casted_wo_rhs_trans = casted_wo_rhs_trans.checkpoint(q_set_wo.kernel) + + # ---------------- Stage 5: combine ---------------- + # Compute per-shard static shape info once and pass through both + # _combine and (later) the bwd helpers via kwargs -- never via the + # state dict, which gets pytree-flattened across shard_map and would + # coerce Python ints into JitTracer 0-d arrays. + _static_shape = _compute_static_shape_info( + batch_size=batch_size, + sequence_length=sequence_length, + hidden=hidden, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + align_size=align_size, + ep_active=ep_active, + num_ep=num_ep, + fsdp_sizes=fsdp_sizes, + recv_buffer_rows=recv_buffer_rows, + ) + output = _combine( + expert_outputs, + dispatch_state, + backend=permutation_backend, + ep_active=ep_active, + batch_size=batch_size, + sequence_length=sequence_length, + dtype=dtype, + num_experts_per_tok=num_experts_per_tok, + num_real_tokens=_static_shape["num_real_tokens"], + padding_size=_static_shape["padding_size"], + pre_a2a_buffer_shape=_static_shape["pre_a2a_buffer_shape"], + ep_axis=ep_axis, + shard_id=shard_id, + num_ep=num_ep, + ) + + # ---------------- Build ctx dict ---------------- + ctx: dict = { + "x": x, + "gate_kernel": gate_kernel, + "logits_2d": logits_2d, + "saved_scores": saved_scores, + "routing_map": routing_map, + "dispatch": dispatch_state, + "casted_sorted_x_lhs_trans": casted_sorted_x_lhs_trans, + "casted_wi_0_rhs_trans": casted_wi_0_rhs_trans, + "casted_wi_1_rhs_trans": casted_wi_1_rhs_trans, + "layer_w0": layer_w0, + "layer_w1": layer_w1, + "casted_intermediate_lhs_trans": casted_intermediate_lhs_trans, + "casted_wo_rhs_trans": casted_wo_rhs_trans, + "expert_outputs": expert_outputs, + "local_group_sizes": local_group_sizes, + } + if expert_bias is not None: + ctx["expert_bias"] = expert_bias + if wi_0_bias is not None: + ctx["has_wi_bias"] = True # NOTE: this is python bool; we DON'T store it + # (we only store array leaves in ctx; structural flags travel via statics). + del ctx["has_wi_bias"] + if aux_loss_coeff > 0.0: + ctx["aux_const_buf"] = aux_const_buf + ctx["aux_tokens_per_expert"] = aux_tokens_per_expert + ctx["aux_logits_for_score"] = aux_logits_for_score + ctx["aux_saved_scores"] = aux_saved_scores + + return output, aux_loss, ctx + + +def _body_bwd( + ctx: dict, + dy_pair: Tuple[jnp.ndarray, jnp.ndarray], + *, + num_experts: int, + num_experts_per_tok: int, + activation_type: str, + score_function: ScoreFunction, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: float, + aux_loss_coeff: float, + permutation_backend: PermutationBackend, + align_size: int, + gate_inside_vjp: bool, + quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet], + dtype: jnp.dtype, + ep_active: bool, + ep_axis: Optional[str], + data_parallelism_axes: Tuple[str, ...], + fsdp_sizes: Tuple[int, ...], + num_ep: int, + num_experts_local: int, + recv_buffer_rows: int, + # Static side info (kept here rather than inside ctx because they're + # python flags / shapes, not array leaves): + has_wi_bias: bool, + has_wo_bias: bool, + has_expert_bias: bool, + x_shape: Tuple[int, ...], +) -> dict: + """Per-shard backward body. Returns a dict of grads keyed identically + to the ``captured`` dict consumed by :func:`_body_fwd`.""" + if not gate_inside_vjp: + raise NotImplementedError("gate_inside_vjp=False is deferred to a follow-up PR.") + + d_output, d_aux_loss = dy_pair + q_set_w0, q_set_w1, q_set_wo = quantizer_sets + batch_size, sequence_length, hidden = x_shape + shard_id = jax.lax.axis_index(ep_axis) if ep_active else None + + # Recompute per-shard static shape info from existing statics + # (Python ints / int tuples). Plumbed via kwargs to _combine_bwd + # and _dispatch_bwd -- NOT through the ctx dict, because the + # dict gets pytree-flattened across the bwd shard_map's in_specs + # and Python ints would be coerced into JitTracer 0-d arrays + # (breaking ``if padding > 0`` and ``jnp.zeros(shape)`` callsites). + # ``batch_size`` here is the GLOBAL batch size (captured in + # ``x_shape`` by the outer fwd rule), hence ``batch_is_per_shard=False``. + _static_shape = _compute_static_shape_info( + batch_size=batch_size, + sequence_length=sequence_length, + hidden=hidden, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + align_size=align_size, + ep_active=ep_active, + num_ep=num_ep, + fsdp_sizes=fsdp_sizes, + recv_buffer_rows=recv_buffer_rows, + batch_is_per_shard=False, + ) + + # Compute per-shard input shape: under the EP shard_map body, the + # gradient tensors live at per-shard shape, so the dispatch_bwd + # reshape target and ``d_x_from_dispatch.reshape(x_shape)`` below + # must use the per-shard shape rather than the captured global + # ``x_shape``. + if ep_active: + import math as _math # local import keeps the no-EP path zero-overhead. + + dp_size = _math.prod(fsdp_sizes) if fsdp_sizes else 1 + per_shard_batch = batch_size // (num_ep * dp_size) + per_shard_x_shape: Tuple[int, ...] = (per_shard_batch, sequence_length, hidden) + else: + per_shard_x_shape = x_shape + + # ---------------- Combine bwd ---------------- + d_expert_outputs, d_routing_weights = _combine_bwd( + d_output, + ctx["dispatch"], + ctx["expert_outputs"], + backend=permutation_backend, + ep_active=ep_active, + batch_size=batch_size, + sequence_length=sequence_length, + dtype=dtype, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + num_real_tokens=_static_shape["num_real_tokens"], + padding_size=_static_shape["padding_size"], + post_a2a_buffer_shape=_static_shape["post_a2a_buffer_shape"], + ep_axis=ep_axis, + shard_id=shard_id, + num_ep=num_ep, + ) + + # ---------------- FFN bwd: GEMM 3 (wo) ---------------- + casted_d_eo = tex.grouped_quantize( + d_expert_outputs, q_set_wo.dgrad, ctx["local_group_sizes"], flatten_axis=-1 + ) + d_intermediate = tex.grouped_gemm( + casted_d_eo.get_tensor(usage=TensorUsage.LHS), + ctx["casted_wo_rhs_trans"], + contracting_dims=((1,), (2,)), + ) + d_wo = tex.grouped_gemm( + ctx["casted_intermediate_lhs_trans"], + casted_d_eo.get_tensor(usage=TensorUsage.RHS), + contracting_dims=((0,), (0,)), + ) + d_wo_bias = ( + tex.grouped_dbias(d_expert_outputs, ctx["local_group_sizes"]) if has_wo_bias else None + ) + + # ---------------- Activation bwd ---------------- + # intermediate = act(layer_w0) * layer_w1 + # d(layer_w0) = vjp(act, layer_w0)(d_intermediate * layer_w1) + # d(layer_w1) = d_intermediate * act(layer_w0) + act_fn = _convert_to_activation_function(activation_type) + act_w0, dact_w0_pullback = jax.vjp(act_fn, ctx["layer_w0"]) + d_layer_w1 = d_intermediate * act_w0 + (d_layer_w0,) = dact_w0_pullback(d_intermediate * ctx["layer_w1"]) + + # ---------------- FFN bwd: GEMM 2 (wi_1) ---------------- + casted_d_layer_w1 = tex.grouped_quantize( + d_layer_w1, q_set_w1.dgrad, ctx["local_group_sizes"], flatten_axis=-1 + ) + d_sorted_x_from_w1 = tex.grouped_gemm( + casted_d_layer_w1.get_tensor(usage=TensorUsage.LHS), + ctx["casted_wi_1_rhs_trans"], + contracting_dims=((1,), (2,)), + ) + d_wi_1 = tex.grouped_gemm( + ctx["casted_sorted_x_lhs_trans"], + casted_d_layer_w1.get_tensor(usage=TensorUsage.RHS), + contracting_dims=((0,), (0,)), + ) + d_wi_1_bias = tex.grouped_dbias(d_layer_w1, ctx["local_group_sizes"]) if has_wi_bias else None + + # ---------------- FFN bwd: GEMM 1 (wi_0) ---------------- + casted_d_layer_w0 = tex.grouped_quantize( + d_layer_w0, q_set_w0.dgrad, ctx["local_group_sizes"], flatten_axis=-1 + ) + d_sorted_x_from_w0 = tex.grouped_gemm( + casted_d_layer_w0.get_tensor(usage=TensorUsage.LHS), + ctx["casted_wi_0_rhs_trans"], + contracting_dims=((1,), (2,)), + ) + d_wi_0 = tex.grouped_gemm( + ctx["casted_sorted_x_lhs_trans"], + casted_d_layer_w0.get_tensor(usage=TensorUsage.RHS), + contracting_dims=((0,), (0,)), + ) + d_wi_0_bias = tex.grouped_dbias(d_layer_w0, ctx["local_group_sizes"]) if has_wi_bias else None + + d_sorted_x = d_sorted_x_from_w0 + d_sorted_x_from_w1 + + # ---------------- Dispatch bwd ---------------- + inputs_2d_shape = (per_shard_x_shape[0] * per_shard_x_shape[1], hidden) + d_inputs_2d = _dispatch_bwd( + d_sorted_x, + ctx["dispatch"], + inputs_2d_shape=inputs_2d_shape, + backend=permutation_backend, + ep_active=ep_active, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + num_real_tokens=_static_shape["num_real_tokens"], + padding_size=_static_shape["padding_size"], + pre_a2a_buffer_shape=_static_shape["pre_a2a_buffer_shape"], + ep_axis=ep_axis, + shard_id=shard_id, + num_ep=num_ep, + ) + d_x_from_dispatch = d_inputs_2d.reshape(per_shard_x_shape) + + # ---------------- Routing bwd ---------------- + # The probs cotangent comes from _combine_bwd. For PURE_JAX it's the + # cotangent of routing_weights (post-routing_map_to_selected_experts); + # we need to bridge back to sparse_probs. For TRITON it's already the + # cotangent of merging_probs == sparse_probs. + if d_routing_weights is not None: + if permutation_backend is PermutationBackend.PURE_JAX: + # routing_map_to_selected_experts: + # selected_experts = argsort(routing_map)[..., -topk:] + # weights = take_along_axis(sparse_probs, selected_experts, axis=-1) + # routing_map is bool (non-diff); the gradient of weights + # w.r.t. sparse_probs is a scatter-into-zero along the + # selected_experts indices. + selected_experts = jnp.argsort(ctx["routing_map"], axis=-1)[..., -num_experts_per_tok:] + d_sparse_probs = jnp.zeros_like(ctx["saved_scores"]).astype(d_routing_weights.dtype) + d_sparse_probs = jnp.take_along_axis(d_sparse_probs, selected_experts, axis=-1) + # Actually scatter: build via jnp.zeros + .at[].set + d_sparse_probs = jnp.zeros(ctx["routing_map"].shape, dtype=d_routing_weights.dtype) + d_sparse_probs = d_sparse_probs.at[ + jnp.arange(ctx["routing_map"].shape[0])[:, None], selected_experts + ].set(d_routing_weights) + else: + d_sparse_probs = d_routing_weights.astype(jnp.float32) + else: + d_sparse_probs = jnp.zeros(ctx["routing_map"].shape, dtype=jnp.float32) + + # Topk bwd primitive: returns d_logits (no d_expert_bias). + d_logits_2d_main = tex.fused_topk_with_score_function_bwd( + ctx["routing_map"], + ctx["saved_scores"], + d_sparse_probs.astype(ctx["saved_scores"].dtype), + topk=num_experts_per_tok, + use_pre_softmax=use_pre_softmax, + scaling_factor=scaling_factor, + score_function=score_function, + compute_aux_scores=False, + ) + + # ---------------- Aux loss bwd ---------------- + if aux_loss_coeff > 0.0: + # Step 1: aux_loss bwd -> d_aux_probs + aux_num_tokens = ctx["aux_logits_for_score"].shape[0] + d_aux_probs = tex.fused_moe_aux_loss_bwd( + ctx["aux_const_buf"], + ctx["aux_tokens_per_expert"].astype(jnp.int32), + d_aux_loss.reshape(()), + num_tokens=aux_num_tokens, + ) + # Step 2: aux-side topk bwd (compute_aux_scores=True path). + # The routing_map argument is ignored in this branch (the kernel + # uses saved_scores); pass any shape-correct integer tensor. + d_aux_logits = tex.fused_topk_with_score_function_bwd( + jnp.zeros(ctx["aux_logits_for_score"].shape, dtype=jnp.bool_), + ctx["aux_saved_scores"], + d_aux_probs.astype(ctx["aux_saved_scores"].dtype), + topk=num_experts_per_tok, + use_pre_softmax=False, + scaling_factor=1.0, + score_function=score_function, + compute_aux_scores=True, + ) + # Step 3: under EP the aux logits were all_gathered along + # ``(ep_axis, *data_parallelism_axes)`` (the latter being FSDP + # axes that shard the batch). The bwd is the inverse of that + # multi-axis tiled all_gather: ``dynamic_slice`` to pick out + # this shard's local rows from the global cotangent. + # + # JAX's convention for tiled ``all_gather(axis_name=(a, b, ...))`` + # is row-major over the tuple: the shard at mesh position + # ``(i_a, i_b, ...)`` writes to rows + # ``[(i_a * size_b * ... + i_b * ... + ...) * local_T : + # + local_T)``. We invert that by computing the same flat + # index here and slicing. + if ep_active: + local_T_aux = ctx["logits_2d"].shape[0] + flat_shard = shard_id # ep is the outermost axis in the gather tuple + for ax, sz in zip(data_parallelism_axes, fsdp_sizes): + flat_shard = flat_shard * sz + jax.lax.axis_index(ax) + d_aux_logits_local = jax.lax.dynamic_slice( + d_aux_logits.astype(ctx["logits_2d"].dtype), + start_indices=(flat_shard * local_T_aux, 0), + slice_sizes=(local_T_aux, num_experts), + ) + else: + d_aux_logits_local = d_aux_logits.astype(d_logits_2d_main.dtype) + d_logits_2d = d_logits_2d_main + d_aux_logits_local.astype(d_logits_2d_main.dtype) + else: + d_logits_2d = d_logits_2d_main + + # ---------------- Gate bwd ---------------- + d_gate_logits = d_logits_2d.reshape(per_shard_x_shape[0], per_shard_x_shape[1], num_experts) + gate_kernel_cast = ctx["gate_kernel"].astype(ctx["x"].dtype) + d_x_from_gate = jnp.einsum("bse,he->bsh", d_gate_logits, gate_kernel_cast) + d_gate_kernel = jnp.einsum("bsh,bse->he", ctx["x"], d_gate_logits).astype( + ctx["gate_kernel"].dtype + ) + d_x = d_x_from_gate + d_x_from_dispatch + + # Reduce per-rank partial contributions to match the out_specs + # declared by _build_grads_specs: + # gate_kernel : P() -> psum across (ep, *fsdp) + # wi_0/wi_1/wo : P(ep_axis, ...) -> psum across (*fsdp) only + # inputs : P((ep, fsdp), ...) -> already shard-local, no reduction + if ep_active: + replicate_all = (ep_axis,) + tuple(data_parallelism_axes) + d_gate_kernel = jax.lax.psum(d_gate_kernel, axis_name=replicate_all) + if data_parallelism_axes: + replicate_fsdp = tuple(data_parallelism_axes) + d_wi_0 = jax.lax.psum(d_wi_0, axis_name=replicate_fsdp) + d_wi_1 = jax.lax.psum(d_wi_1, axis_name=replicate_fsdp) + d_wo = jax.lax.psum(d_wo, axis_name=replicate_fsdp) + if has_wi_bias: + d_wi_0_bias = jax.lax.psum(d_wi_0_bias, axis_name=replicate_fsdp) + d_wi_1_bias = jax.lax.psum(d_wi_1_bias, axis_name=replicate_fsdp) + if has_wo_bias: + d_wo_bias = jax.lax.psum(d_wo_bias, axis_name=replicate_fsdp) + + grads: dict = { + "inputs": d_x, + "gate_kernel": d_gate_kernel, + "wi_0": d_wi_0, + "wi_1": d_wi_1, + "wo": d_wo, + } + if has_wi_bias: + grads["wi_0_bias"] = d_wi_0_bias + grads["wi_1_bias"] = d_wi_1_bias + if has_wo_bias: + grads["wo_bias"] = d_wo_bias + if has_expert_bias: + # expert_bias has no gradient through topk (the topk bwd returns + # None for it). Emit a structural zero so the outer rule has + # something to package. + grads["expert_bias"] = jnp.zeros_like(ctx["expert_bias"]) + return grads + + +# ============================================================================= +# Spec builders for shard_map (lockstep with ctx_dict / captured_dict) +# ============================================================================= + + +def _build_in_specs( + ep_axis: str, + batch_pspec_axis: Any, + *, + has_bias: bool, + has_expert_bias: bool, +) -> dict: + """Build the ``in_specs`` dict for the EP fwd shard_map.""" + specs: dict = { + "inputs": P(batch_pspec_axis, None, None), + "gate_kernel": P(), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + } + if has_bias: + for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): + specs[name] = P(ep_axis, None) + if has_expert_bias: + specs["expert_bias"] = P(ep_axis) + return specs + + +def _build_dispatch_specs( + ep_axis: str, + *, + backend: PermutationBackend, + ep_active: bool, +) -> dict: + """Build the spec dict for a DispatchState dict returned by + :func:`_dispatch` from inside a shard_map. Keys must match what + :func:`_dispatch` actually populates for the given (backend, ep_active).""" + specs: dict = {"group_sizes": P()} + if backend is PermutationBackend.PURE_JAX: + specs["sorted_indices"] = P() + specs["routing_weights"] = P() + else: + specs["row_id_map"] = P() + specs["pad_offsets"] = P() + specs["merging_probs"] = P() + if ep_active: + specs["all_shards_tokens_per_expert"] = P() + specs["local_perm_row_id_map"] = P() + # NOTE: per-shard compile-time-constant shape info + # (num_real_tokens, padding_size, pre/post_a2a_buffer_shape) + # is intentionally NOT in the state dict; see _compute_static_shape_info. + return specs + + +def _build_ctx_specs( + ep_axis: str, + batch_pspec_axis: Any, + *, + backend: PermutationBackend, + ep_active: bool, + has_bias: bool, + has_expert_bias: bool, + aux_loss_enabled: bool, +) -> dict: + """Build the spec dict for the ``ctx`` returned by :func:`_body_fwd`.""" + specs: dict = { + # Per-shard local activations along the batch axis. + "x": P(batch_pspec_axis, None, None), + "gate_kernel": P(), + "logits_2d": P(batch_pspec_axis, None), + "saved_scores": P(batch_pspec_axis, None), + "routing_map": P(batch_pspec_axis, None), + "dispatch": _build_dispatch_specs(ep_axis, backend=backend, ep_active=ep_active), + # FFN residuals: the LHS_TRANS / RHS_TRANS variants of + # grouped_quantize have leading "rows"/"experts" dims that are + # already shard-local (post-dispatch). Use P(ep_axis,...) on + # leading dim; that works whether the leaf is a plain ndarray + # or a ScaledTensor (shard_map applies the spec leaf-wise to + # the registered ScaledTensor pytree). + "casted_sorted_x_lhs_trans": P(), + "casted_wi_0_rhs_trans": P(ep_axis, None, None), + "casted_wi_1_rhs_trans": P(ep_axis, None, None), + "layer_w0": P(), + "layer_w1": P(), + "casted_intermediate_lhs_trans": P(), + "casted_wo_rhs_trans": P(ep_axis, None, None), + "expert_outputs": P(), + "local_group_sizes": P(), + } + if has_expert_bias: + specs["expert_bias"] = P(ep_axis) + if aux_loss_enabled: + specs["aux_const_buf"] = P() + specs["aux_tokens_per_expert"] = P() + specs["aux_logits_for_score"] = P() + specs["aux_saved_scores"] = P() + return specs + + +def _build_grads_specs( + ep_axis: str, + batch_pspec_axis: Any, + *, + has_bias: bool, + has_expert_bias: bool, +) -> dict: + """Spec dict for the grads dict returned by :func:`_body_bwd`.""" + return _build_in_specs( + ep_axis, + batch_pspec_axis, + has_bias=has_bias, + has_expert_bias=has_expert_bias, + ) + + +# ============================================================================= +# Top-level VJP rules +# ============================================================================= + + +def _moe_fwd_rule( + # IMPORTANT — calling convention for jax.custom_vjp fwd rule. + # + # JAX uses ``_argnums_partial`` (jax/_src/api_util.py) when wiring up + # the fwd rule. That helper preserves the ORIGINAL positional order + # of the decorated function: dyn (= diff) args sit at their original + # positions and static (= nondiff) args fill the remaining slots in + # nondiff_argnums order. So the fwd rule MUST take args in the + # SAME positional order as ``_moe`` -- diff first (positions 0..8), + # then nondiff (positions 9..28), all POSITIONAL (no ``*,`` -- they + # arrive as positional, not as kwargs). + # + # NOTE: this is the OPPOSITE convention from ``_moe_bwd_rule``, which + # uses ``prepend_static_args`` -- there the static args come FIRST, + # followed by ``ctx`` and ``dy_pair``. + x, + gate_kernel, + wi_0, + wi_1, + wo, + wi_0_bias, + wi_1_bias, + wo_bias, + expert_bias, + num_experts, + num_experts_per_tok, + activation_type, + score_function, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + aux_loss_coeff, + permutation_backend, + align_size, + gate_inside_vjp, + ep_axis, + data_parallelism_axes, + input_axes, + gate_kernel_axes, + wi_kernel_axes, + wo_kernel_axes, + quantizer_sets, + dtype, +): + x = with_sharding_constraint_by_logical_axes(x, input_axes) + ep_active = ep_axis is not None + body_kwargs = dict( + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + activation_type=activation_type, + score_function=score_function, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + aux_loss_coeff=aux_loss_coeff, + permutation_backend=permutation_backend, + align_size=align_size, + gate_inside_vjp=gate_inside_vjp, + quantizer_sets=quantizer_sets, + dtype=dtype, + ep_axis=ep_axis, + data_parallelism_axes=data_parallelism_axes, + ) + captured: dict = { + "inputs": x, + "gate_kernel": gate_kernel, + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + has_bias = wi_0_bias is not None + has_expert_bias = expert_bias is not None + if has_bias: + captured["wi_0_bias"] = wi_0_bias + captured["wi_1_bias"] = wi_1_bias + captured["wo_bias"] = wo_bias + if has_expert_bias: + captured["expert_bias"] = expert_bias + + if not ep_active: + output, aux_loss, ctx = _body_fwd( + captured, + **body_kwargs, + ep_active=False, + fsdp_sizes=(), + num_ep=1, + num_experts_local=num_experts, + recv_buffer_rows=0, + ) + # Carry static side info into ctx for the bwd rule (as Python + # objects on the dict; not part of the tree pytree leaves). + ctx["__static__"] = dict( + has_wi_bias=has_bias, + has_wo_bias=has_bias, + has_expert_bias=has_expert_bias, + x_shape=x.shape, + num_experts_local=num_experts, + recv_buffer_rows=0, + ) + return (output, aux_loss), ctx + + # ---------------- EP path ---------------- + from jax.experimental.shard_map import shard_map + + mesh = _get_mesh() + if mesh is None or mesh.empty: + raise ValueError("moe(...) requires an active jax.sharding.Mesh when ep_axis is set.") + num_ep = mesh.shape[ep_axis] + if num_experts % num_ep != 0: + raise ValueError(f"num_experts={num_experts} must be divisible by EP size={num_ep}") + num_experts_local = num_experts // num_ep + + if not data_parallelism_axes: + batch_pspec_axis: Any = ep_axis + else: + batch_pspec_axis = (ep_axis, *data_parallelism_axes) + dp_size = 1 + for ax in data_parallelism_axes: + dp_size *= mesh.shape[ax] + + global_batch_size, sequence_length, _hidden = x.shape + topk = num_experts_per_tok + if global_batch_size % (num_ep * dp_size) != 0: + raise ValueError(f"batch={global_batch_size} not divisible by ep*dp={num_ep * dp_size}") + recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + if align_size > 0: + recv_buffer_rows += num_experts * (align_size - 1) + + in_specs = _build_in_specs( + ep_axis, + batch_pspec_axis, + has_bias=has_bias, + has_expert_bias=has_expert_bias, + ) + output_spec = P(batch_pspec_axis, None, None) + aux_spec = P() + ctx_spec = _build_ctx_specs( + ep_axis, + batch_pspec_axis, + backend=permutation_backend, + ep_active=True, + has_bias=has_bias, + has_expert_bias=has_expert_bias, + aux_loss_enabled=(aux_loss_coeff > 0.0), + ) + + _fsdp_sizes: Tuple[int, ...] = tuple(mesh.shape[ax] for ax in data_parallelism_axes) + + def _shardmap_body(captured_local): + return _body_fwd( + captured_local, + **body_kwargs, + ep_active=True, + fsdp_sizes=_fsdp_sizes, + num_ep=num_ep, + num_experts_local=num_experts_local, + recv_buffer_rows=recv_buffer_rows, + ) + + output, aux_loss, ctx = shard_map( + _shardmap_body, + mesh=mesh, + in_specs=(in_specs,), + out_specs=(output_spec, aux_spec, ctx_spec), + check_rep=False, + )(captured) + ctx["__static__"] = dict( + has_wi_bias=has_bias, + has_wo_bias=has_bias, + has_expert_bias=has_expert_bias, + x_shape=x.shape, + num_experts_local=num_experts_local, + recv_buffer_rows=recv_buffer_rows, + ) + return (output, aux_loss), ctx + + +def _moe_bwd_rule( + num_experts, + num_experts_per_tok, + activation_type, + score_function, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + aux_loss_coeff, + permutation_backend, + align_size, + gate_inside_vjp, + ep_axis, + data_parallelism_axes, + input_axes, + gate_kernel_axes, + wi_kernel_axes, + wo_kernel_axes, + quantizer_sets, + dtype, + ctx, + dy_pair, +): + static = ctx.pop("__static__") + has_wi_bias = static["has_wi_bias"] + has_wo_bias = static["has_wo_bias"] + has_expert_bias = static["has_expert_bias"] + x_shape = static["x_shape"] + num_experts_local = static["num_experts_local"] + recv_buffer_rows = static["recv_buffer_rows"] + + ep_active = ep_axis is not None + mesh = _get_mesh() if ep_active else None + fsdp_sizes: Tuple[int, ...] = ( + tuple(mesh.shape[ax] for ax in data_parallelism_axes) if ep_active else () + ) + body_kwargs = dict( + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + activation_type=activation_type, + score_function=score_function, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + aux_loss_coeff=aux_loss_coeff, + permutation_backend=permutation_backend, + align_size=align_size, + gate_inside_vjp=gate_inside_vjp, + quantizer_sets=quantizer_sets, + dtype=dtype, + ep_axis=ep_axis, + data_parallelism_axes=data_parallelism_axes, + fsdp_sizes=fsdp_sizes, + num_ep=1 if not ep_active else mesh.shape[ep_axis], + num_experts_local=num_experts_local, + recv_buffer_rows=recv_buffer_rows, + has_wi_bias=has_wi_bias, + has_wo_bias=has_wo_bias, + has_expert_bias=has_expert_bias, + x_shape=x_shape, + ) + + if not ep_active: + grads = _body_bwd(ctx, dy_pair, ep_active=False, **body_kwargs) + # Apply sharding constraints on grads. + grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( + grads["gate_kernel"], gate_kernel_axes + ) + grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) + grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) + grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) + grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) + return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + + from jax.experimental.shard_map import shard_map + + if not data_parallelism_axes: + batch_pspec_axis: Any = ep_axis + else: + batch_pspec_axis = (ep_axis, *data_parallelism_axes) + ctx_spec = _build_ctx_specs( + ep_axis, + batch_pspec_axis, + backend=permutation_backend, + ep_active=True, + has_bias=has_wi_bias, + has_expert_bias=has_expert_bias, + aux_loss_enabled=(aux_loss_coeff > 0.0), + ) + dy_specs = (P(batch_pspec_axis, None, None), P()) + grads_spec = _build_grads_specs( + ep_axis, batch_pspec_axis, has_bias=has_wi_bias, has_expert_bias=has_expert_bias + ) + + def _bwd_body(ctx_local, dy_local): + return _body_bwd(ctx_local, dy_local, ep_active=True, **body_kwargs) + + grads = shard_map( + _bwd_body, + mesh=mesh, + in_specs=(ctx_spec, dy_specs), + out_specs=grads_spec, + check_rep=False, + )(ctx, dy_pair) + + grads["gate_kernel"] = with_sharding_constraint_by_logical_axes( + grads["gate_kernel"], gate_kernel_axes + ) + grads["wi_0"] = with_sharding_constraint_by_logical_axes(grads["wi_0"], wi_kernel_axes) + grads["wi_1"] = with_sharding_constraint_by_logical_axes(grads["wi_1"], wi_kernel_axes) + grads["wo"] = with_sharding_constraint_by_logical_axes(grads["wo"], wo_kernel_axes) + grads["inputs"] = with_sharding_constraint_by_logical_axes(grads["inputs"], input_axes) + return _grads_dict_to_tuple(grads, has_wi_bias, has_wo_bias, has_expert_bias) + + +def _grads_dict_to_tuple( + grads: dict, has_wi_bias: bool, has_wo_bias: bool, has_expert_bias: bool +) -> Tuple: + """Pack the body_bwd's grads dict into the positional tuple JAX expects.""" + return ( + grads["inputs"], + grads["gate_kernel"], + grads["wi_0"], + grads["wi_1"], + grads["wo"], + grads.get("wi_0_bias") if has_wi_bias else None, + grads.get("wi_1_bias") if has_wi_bias else None, + grads.get("wo_bias") if has_wo_bias else None, + grads.get("expert_bias") if has_expert_bias else None, + ) + + +# ============================================================================= +# custom_vjp + public entry +# ============================================================================= + + +@partial(jax.custom_vjp, nondiff_argnums=tuple(range(9, 29))) +def _moe( + x, + gate_kernel, + wi_0, + wi_1, + wo, + wi_0_bias, + wi_1_bias, + wo_bias, + expert_bias, + num_experts, + num_experts_per_tok, + activation_type, + score_function, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + aux_loss_coeff, + permutation_backend, + align_size, + gate_inside_vjp, + ep_axis, + data_parallelism_axes, + input_axes, + gate_kernel_axes, + wi_kernel_axes, + wo_kernel_axes, + quantizer_sets, + dtype, +): + # Call in `_moe`'s own signature order to match what JAX will pass + # the fwd rule via ``_argnums_partial``. See the comment block at + # the top of ``_moe_fwd_rule`` for why this differs from + # ``_moe_bwd_rule``'s convention. + output_pair, _ = _moe_fwd_rule( + x, + gate_kernel, + wi_0, + wi_1, + wo, + wi_0_bias, + wi_1_bias, + wo_bias, + expert_bias, + num_experts, + num_experts_per_tok, + activation_type, + score_function, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + aux_loss_coeff, + permutation_backend, + align_size, + gate_inside_vjp, + ep_axis, + data_parallelism_axes, + input_axes, + gate_kernel_axes, + wi_kernel_axes, + wo_kernel_axes, + quantizer_sets, + dtype, + ) + return output_pair + + +_moe.defvjp(_moe_fwd_rule, _moe_bwd_rule) + + +def moe( + x: jnp.ndarray, + gate_kernel: jnp.ndarray, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray] = None, + wi_1_bias: Optional[jnp.ndarray] = None, + wo_bias: Optional[jnp.ndarray] = None, + expert_bias: Optional[jnp.ndarray] = None, + *, + # Architecture + num_experts: int, + num_experts_per_tok: int, + activation_type: str = "silu", + # Routing + score_function: Union[str, ScoreFunction] = "softmax", + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: float = 1.0, + aux_loss_coeff: float = 0.0, + # Permutation + permutation_backend: PermutationBackend = PermutationBackend.PURE_JAX, + align_size: int = 0, + # Gate placement (Phuong: "perhaps as an option") + gate_inside_vjp: bool = True, + # Parallelism (resolved by caller from MeshResource) + ep_axis: Optional[str] = None, + data_parallelism_axes: Tuple[str, ...] = (), + # Logical axes for sharding constraints + input_axes: Tuple[Optional[str], ...] = (), + gate_kernel_axes: Tuple[Optional[str], ...] = (), + wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp"), + wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed"), + # Quantization + quantizer_sets: Tuple[QuantizerSet, QuantizerSet, QuantizerSet] = ( + noop_quantizer_set, + noop_quantizer_set, + noop_quantizer_set, + ), + dtype: jnp.dtype = jnp.float32, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Run a full MoE block under a single fused custom_vjp. + + Parameters and return are documented at the call site of + ``_MoEBlock.__call__``. See module docstring for design rationale. + """ + if not isinstance(permutation_backend, PermutationBackend): + raise TypeError( + f"permutation_backend must be a PermutationBackend, got {permutation_backend!r}" + ) + # Normalize string score_function ("softmax" / "sigmoid") to the + # ScoreFunction enum once here. The underlying primitive + # ``tex.fused_topk_with_score_function_fwd`` expects an int-coercible + # value (the enum has integer .value), and the public router wrapper + # we bypass also normalizes here. + score_function = _validate_score_function(score_function) + + output, aux_loss = _moe( + x, + gate_kernel, + wi_0, + wi_1, + wo, + wi_0_bias, + wi_1_bias, + wo_bias, + expert_bias, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + activation_type=activation_type, + score_function=score_function, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + aux_loss_coeff=aux_loss_coeff, + permutation_backend=permutation_backend, + align_size=align_size, + gate_inside_vjp=gate_inside_vjp, + ep_axis=ep_axis, + data_parallelism_axes=data_parallelism_axes, + input_axes=input_axes, + gate_kernel_axes=gate_kernel_axes, + wi_kernel_axes=wi_kernel_axes, + wo_kernel_axes=wo_kernel_axes, + quantizer_sets=quantizer_sets, + dtype=dtype, + ) + if aux_loss_coeff <= 0.0: + aux_loss = None + return output, aux_loss From 84166d08ff2cd7598d5e8be87a2cab6a1dd8d392 Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 21 May 2026 15:52:04 -0700 Subject: [PATCH 16/18] test(jax): parametrize MP MoE tests over (recipe, backend) Add MXFP8BlockScaling alongside bf16 to test_fwd_and_bwd / test_aux_loss / test_pure_jax_triton_parity. Tests skip the FP8 recipe on GPUs older than sm_100. Tolerance widened from 5e-2 to 3e-1 for MXFP8 parity to absorb block-scale quantization noise. WIP -- DO NOT MERGE into teddy/moe_block until dlcluster verifies all six (recipe, backend) combos pass. Signed-off-by: tdophung --- tests/jax/test_multiprocess_moe_vjp.py | 97 +++++++++++++++++++++----- 1 file changed, 80 insertions(+), 17 deletions(-) diff --git a/tests/jax/test_multiprocess_moe_vjp.py b/tests/jax/test_multiprocess_moe_vjp.py index ddf04f0ea4..1a23b34b66 100644 --- a/tests/jax/test_multiprocess_moe_vjp.py +++ b/tests/jax/test_multiprocess_moe_vjp.py @@ -135,11 +135,15 @@ def _inject_moe(request): if not request.node.get_closest_marker("triton"): yield return + import transformer_engine.jax as te + from transformer_engine.common import recipe as te_recipe from transformer_engine.jax.flax import _MoEBlock as MoEBlock from transformer_engine.jax.moe import PermutationBackend from transformer_engine.jax.sharding import MeshResource, global_shard_guard mod = sys.modules[__name__] + mod.te = te + mod.te_recipe = te_recipe mod.MoEBlock = MoEBlock mod.PermutationBackend = PermutationBackend mod.MeshResource = MeshResource @@ -147,6 +151,49 @@ def _inject_moe(request): yield +# ``recipe`` parametrize values used across all tests below. ``None`` +# = plain bf16; the named recipes route through TE's autocast and +# exercise the FP8/MXFP8 quantization paths in _body_fwd/_body_bwd. +# Only recipes that work on TE Blackwell are included; older GPUs +# skip via the ``hardware_supports`` guard below. +RECIPE_NAMES = ("bf16", "MXFP8BlockScaling") + + +def _resolve_recipe(name): + """Return ``(use_fp8, recipe_instance)`` for the parametrize id.""" + if name == "bf16": + return False, None + if name == "MXFP8BlockScaling": + return True, te_recipe.MXFP8BlockScaling() # noqa: F821 + raise ValueError(f"unknown recipe name: {name!r}") + + +def _hardware_supports(recipe_name): + """Skip an FP8 recipe on GPUs that don't have the hw for it.""" + if recipe_name == "bf16": + return True + from transformer_engine_jax import get_device_compute_capability + + arch = get_device_compute_capability(0) + if recipe_name == "MXFP8BlockScaling": + return arch >= 100 + return False + + +def _autocast_ctx(recipe_name): + """Context manager that turns FP8 on for non-bf16 recipes.""" + use_fp8, recipe_inst = _resolve_recipe(recipe_name) + return te.autocast(enabled=use_fp8, recipe=recipe_inst) # noqa: F821 + + +def _tol_finite_grad(recipe_name): + """Per-recipe absolute tolerance for parity grad comparison.""" + if recipe_name == "bf16": + return 5e-2 + # MXFP8 grads carry block-scale quantization noise; loosen accordingly. + return 3e-1 + + # ----------------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------------- @@ -245,7 +292,10 @@ class TestMoeVjpMultiprocess: """ @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) - def test_fwd_and_bwd(self, mesh, backend_name): + @pytest.mark.parametrize("recipe_name", RECIPE_NAMES) + def test_fwd_and_bwd(self, mesh, backend_name, recipe_name): + if not _hardware_supports(recipe_name): + pytest.skip(f"recipe {recipe_name} not supported on this GPU") backend = PermutationBackend(backend_name) # noqa: F821 block = _make_block( num_experts=NUM_EXPERTS, @@ -258,20 +308,25 @@ def test_fwd_and_bwd(self, mesh, backend_name): (BATCH, SEQ, HIDDEN), dtype=jnp.bfloat16, ) - variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(1)) + with _autocast_ctx(recipe_name): + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(1)) # Local-shard checks (see _local_shard docstring for why). out_local = _local_shard(output) assert output.dtype == x.dtype assert np.all(np.isfinite(out_local)), "output has NaN/Inf" assert aux is None - grads = _grad_step(block, variables, mesh, x) + with _autocast_ctx(recipe_name): + grads = _grad_step(block, variables, mesh, x) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): g_local = _local_shard(_unwrap(grads["params"][name])) assert np.all(np.isfinite(g_local)), f"{name} grad has NaN/Inf" assert np.any(g_local != 0.0), f"{name} grad is identically zero" @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) - def test_aux_loss(self, mesh, backend_name): + @pytest.mark.parametrize("recipe_name", RECIPE_NAMES) + def test_aux_loss(self, mesh, backend_name, recipe_name): + if not _hardware_supports(recipe_name): + pytest.skip(f"recipe {recipe_name} not supported on this GPU") backend = PermutationBackend(backend_name) # noqa: F821 block = _make_block( num_experts=NUM_EXPERTS, @@ -285,18 +340,23 @@ def test_aux_loss(self, mesh, backend_name): (BATCH, SEQ, HIDDEN), dtype=jnp.bfloat16, ) - variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(5)) + with _autocast_ctx(recipe_name): + variables, output, aux = _init_apply(block, mesh, x, jax.random.PRNGKey(5)) out_local = _local_shard(output) assert np.all(np.isfinite(out_local)), "output has NaN/Inf under aux" assert aux is not None assert aux.shape == () aux_local = _local_shard(aux) assert np.isfinite(aux_local), "aux is NaN/Inf" - grads = _grad_step(block, variables, mesh, x) + with _autocast_ctx(recipe_name): + grads = _grad_step(block, variables, mesh, x) g_gate_local = _local_shard(_unwrap(grads["params"]["gate_kernel"])) assert np.all(np.isfinite(g_gate_local)), "gate grad NaN/Inf under aux" - def test_pure_jax_triton_parity(self, mesh): + @pytest.mark.parametrize("recipe_name", RECIPE_NAMES) + def test_pure_jax_triton_parity(self, mesh, recipe_name): + if not _hardware_supports(recipe_name): + pytest.skip(f"recipe {recipe_name} not supported on this GPU") block_pj = _make_block( num_experts=NUM_EXPERTS, num_experts_per_tok=TOPK, @@ -314,22 +374,25 @@ def test_pure_jax_triton_parity(self, mesh): (BATCH, SEQ, HIDDEN), dtype=jnp.bfloat16, ) - variables, out_pj, _ = _init_apply(block_pj, mesh, x, jax.random.PRNGKey(7)) - with mesh, global_shard_guard( # noqa: F821 - MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) # noqa: F821 - ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): - x_sh = _shard_inputs(x, mesh) - out_tr, _ = jax.jit(block_tr.apply)(variables, x_sh) + tol = _tol_finite_grad(recipe_name) + with _autocast_ctx(recipe_name): + variables, out_pj, _ = _init_apply(block_pj, mesh, x, jax.random.PRNGKey(7)) + with mesh, global_shard_guard( # noqa: F821 + MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS) # noqa: F821 + ), nn_partitioning.axis_rules(LOGICAL_AXIS_RULES): + x_sh = _shard_inputs(x, mesh) + out_tr, _ = jax.jit(block_tr.apply)(variables, x_sh) out_pj_local = _local_shard(out_pj) out_tr_local = _local_shard(out_tr) diff = float(np.max(np.abs(out_pj_local - out_tr_local))) - assert diff < 5e-2, f"forward parity breach: max_abs_diff={diff}" + assert diff < tol, f"forward parity breach: max_abs_diff={diff} (tol={tol})" - grads_pj = _grad_step(block_pj, variables, mesh, x) - grads_tr = _grad_step(block_tr, variables, mesh, x) + with _autocast_ctx(recipe_name): + grads_pj = _grad_step(block_pj, variables, mesh, x) + grads_tr = _grad_step(block_tr, variables, mesh, x) for name in ("gate_kernel", "wi_0", "wi_1", "wo"): g_pj = _local_shard(_unwrap(grads_pj["params"][name])) g_tr = _local_shard(_unwrap(grads_tr["params"][name])) d = float(np.max(np.abs(g_pj - g_tr))) - assert d < 5e-2, f"grad parity breach on {name}: max_abs_diff={d}" + assert d < tol, f"grad parity breach on {name}: max_abs_diff={d} (tol={tol})" From ee9b3ce3335c19be09f141dbbb2eae830eb97acd Mon Sep 17 00:00:00 2001 From: tdophung Date: Thu, 21 May 2026 17:40:49 -0700 Subject: [PATCH 17/18] [JAX] Address review comments transformer_engine/jax/moe.py: - Hoist 'import math' to module top (was two local imports). - Trim the verbose _moe_fwd_rule arg-order block comment. - Update PermutationBackend docstring: TRITON is the recommended default and is faster than PURE_JAX on current hardware. - Rename layer_w0 / layer_w1 to gate_proj_out / up_proj_out so the names reflect what they are (SwiGLU projection outputs, not weights). - moe() now rejects overlapping EP / FSDP axes up front instead of letting JAX produce a duplicate-axis PartitionSpec. transformer_engine/jax/permutation.py: - Drop reference to the temporary MaxText-fork compute_ragged_all_to_all_params helpers. tests/jax/test_moe_vjp.py, tests/jax/test_multiprocess_moe_vjp.py: - Add a module-level Blackwell (sm_100+) skip; grouped GEMM is Blackwell-only today. - Move the 'triton' pytest marker from the class onto the triton parametrize variant only, so the pure_jax variant still runs in environments without Triton. Signed-off-by: tdophung --- tests/jax/test_moe_vjp.py | 46 +++++++---- tests/jax/test_multiprocess_moe_vjp.py | 31 ++++++-- transformer_engine/jax/moe.py | 106 +++++++++++++------------ transformer_engine/jax/permutation.py | 4 +- 4 files changed, 111 insertions(+), 76 deletions(-) diff --git a/tests/jax/test_moe_vjp.py b/tests/jax/test_moe_vjp.py index 92d95bc896..d4cf60973d 100644 --- a/tests/jax/test_moe_vjp.py +++ b/tests/jax/test_moe_vjp.py @@ -38,14 +38,32 @@ import numpy as np import pytest +from transformer_engine_jax import get_device_compute_capability + +# The MoE custom_vjp uses grouped GEMM, which is currently +# Blackwell-only (sm_100+). Skip the whole file on older arches. +if get_device_compute_capability(0) < 100: + pytest.skip( + "MoE custom_vjp tests require Blackwell (sm_100+) for grouped GEMM", + allow_module_level=True, + ) + +# Parametrize values for the dispatch / combine backend. Only the +# ``triton`` variant is gated by the ``triton`` marker (so the +# ``pure_jax`` variant still runs on environments without Triton). +BACKEND_PARAMS = [ + pytest.param("pure_jax", id="pure_jax"), + pytest.param("triton", id="triton", marks=pytest.mark.triton), +] + -# Lazy import (mirrors the gating in the old test file): the underlying -# kernels require triton + the fused-router CUDA kernel. @pytest.fixture(autouse=True, scope="function") def _inject_moe(request): - if not request.node.get_closest_marker("triton"): - yield - return + """Inject MoEBlock / PermutationBackend / moe symbols into the test + module namespace. Done as a fixture rather than a top-level import so + a stray ``pytest tests/jax/`` collection on a build without the TE + JAX bits still produces a clean skip via the module-level guards + above rather than an ImportError at collection time.""" import sys from transformer_engine.jax.flax import _MoEBlock as MoEBlock from transformer_engine.jax.moe import PermutationBackend, moe @@ -256,11 +274,10 @@ def aux_only(params, x): # ----------------------------------------------------------------------------- -@pytest.mark.triton class TestMoeVjpForward: """Forward shape / finiteness / parity vs pure-JAX reference.""" - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) def test_forward_shape_and_finite(self, backend_name): backend = PermutationBackend(backend_name) # noqa: F821 key = jax.random.PRNGKey(0) @@ -273,7 +290,7 @@ def test_forward_shape_and_finite(self, backend_name): assert jnp.all(jnp.isfinite(out)) assert aux is None - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) def test_forward_parity_vs_pure_jax_reference(self, backend_name): backend = PermutationBackend(backend_name) # noqa: F821 key = jax.random.PRNGKey(1) @@ -304,12 +321,11 @@ def test_pure_jax_triton_equivalence(self): np.testing.assert_allclose(np.array(out_pj), np.array(out_tr), atol=2e-5, rtol=2e-5) -@pytest.mark.triton class TestMoeVjpBackward: """Backward parity vs pure-JAX reference (which uses ``jax.vjp`` over plain JAX ops, giving us the canonical pullback).""" - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) def test_grads_finite_and_nonzero(self, backend_name): backend = PermutationBackend(backend_name) # noqa: F821 key = jax.random.PRNGKey(3) @@ -322,7 +338,7 @@ def test_grads_finite_and_nonzero(self, backend_name): assert jnp.all(jnp.isfinite(g)), f"{name} grad has NaN/Inf" assert jnp.any(g != 0.0), f"{name} grad is identically zero" - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) def test_grads_match_pure_jax_reference(self, backend_name): backend = PermutationBackend(backend_name) # noqa: F821 key = jax.random.PRNGKey(4) @@ -361,11 +377,10 @@ def test_grads_match_pure_jax_reference(self, backend_name): ) -@pytest.mark.triton class TestMoeVjpAuxLoss: """Aux-loss path: forward + grad parity.""" - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) def test_aux_loss_returned_and_finite(self, backend_name): backend = PermutationBackend(backend_name) # noqa: F821 key = jax.random.PRNGKey(5) @@ -378,7 +393,7 @@ def test_aux_loss_returned_and_finite(self, backend_name): assert jnp.isfinite(aux) assert jnp.abs(aux) < 1e2 - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) def test_aux_loss_parity_vs_reference(self, backend_name): backend = PermutationBackend(backend_name) # noqa: F821 key = jax.random.PRNGKey(6) @@ -395,7 +410,7 @@ def test_aux_loss_parity_vs_reference(self, backend_name): ) np.testing.assert_allclose(float(aux_te), float(aux_ref), atol=1e-5, rtol=1e-5) - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) def test_aux_loss_grads_propagate_to_logits(self, backend_name): """The aux-loss bwd path must produce non-zero gate-kernel grads when only the aux-loss scalar is differentiated (no main-output @@ -417,7 +432,6 @@ def test_aux_loss_grads_propagate_to_logits(self, backend_name): # ----------------------------------------------------------------------------- -@pytest.mark.triton class TestMoEBlockFlaxWrapper: """Sanity-check the thin Flax wrapper: forward + grad on init.""" diff --git a/tests/jax/test_multiprocess_moe_vjp.py b/tests/jax/test_multiprocess_moe_vjp.py index 1a23b34b66..886e3ac2fc 100644 --- a/tests/jax/test_multiprocess_moe_vjp.py +++ b/tests/jax/test_multiprocess_moe_vjp.py @@ -104,6 +104,24 @@ def _read_mp_options(): allow_module_level=True, ) +from transformer_engine_jax import get_device_compute_capability + +# Grouped GEMM in the MoE custom_vjp currently requires Blackwell +# (sm_100+). Skip the whole file on older arches. +if get_device_compute_capability(0) < 100: + pytest.skip( + "MoE custom_vjp tests require Blackwell (sm_100+) for grouped GEMM", + allow_module_level=True, + ) + +# Parametrize values for the dispatch / combine backend. Only the +# ``triton`` variant carries the ``triton`` marker, so the +# ``pure_jax`` variant still runs on environments without Triton. +BACKEND_PARAMS = [ + pytest.param("pure_jax", id="pure_jax"), + pytest.param("triton", id="triton", marks=pytest.mark.triton), +] + NUM_DEVICES_REQUIRED = 4 EP_AXIS = "ep" @@ -132,9 +150,11 @@ def mesh(): @pytest.fixture(autouse=True, scope="function") def _inject_moe(request): - if not request.node.get_closest_marker("triton"): - yield - return + """Inject TE / MoEBlock symbols into the test module namespace. + Done lazily so a stray ``pytest tests/jax/`` collection on a build + without the TE JAX bits still produces a clean skip via the + module-level guards above rather than an ImportError at collection + time.""" import transformer_engine.jax as te from transformer_engine.common import recipe as te_recipe from transformer_engine.jax.flax import _MoEBlock as MoEBlock @@ -285,13 +305,12 @@ def _local_shard(x): TOPK = 2 -@pytest.mark.triton class TestMoeVjpMultiprocess: """Multiprocess (one-GPU-per-process) correctness checks for the unified MoE custom_vjp. """ - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) @pytest.mark.parametrize("recipe_name", RECIPE_NAMES) def test_fwd_and_bwd(self, mesh, backend_name, recipe_name): if not _hardware_supports(recipe_name): @@ -322,7 +341,7 @@ def test_fwd_and_bwd(self, mesh, backend_name, recipe_name): assert np.all(np.isfinite(g_local)), f"{name} grad has NaN/Inf" assert np.any(g_local != 0.0), f"{name} grad is identically zero" - @pytest.mark.parametrize("backend_name", ["pure_jax", "triton"]) + @pytest.mark.parametrize("backend_name", BACKEND_PARAMS) @pytest.mark.parametrize("recipe_name", RECIPE_NAMES) def test_aux_loss(self, mesh, backend_name, recipe_name): if not _hardware_supports(recipe_name): diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index f7b0880091..fe923a4914 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -48,6 +48,7 @@ bwd rule. None of these helpers form a custom_vjp boundary. """ +import math from enum import Enum from functools import partial from typing import Any, Callable, NewType, Optional, Tuple, Union @@ -104,10 +105,11 @@ class PermutationBackend(Enum): """Token-dispatch / combine backend used by :func:`moe`. - * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain XLA; - typically faster than ``TRITON`` in current testing because XLA can - fuse the ops with surrounding work. - * ``TRITON``: TE's fused Triton kernels. + * ``TRITON``: TE's fused Triton kernels. Faster than ``PURE_JAX`` + on current hardware and the recommended default. + * ``PURE_JAX``: ``jnp.argsort`` + gather paths compiled as plain + XLA; useful as a numerical reference and on builds without + Triton available. """ PURE_JAX = "pure_jax" @@ -167,8 +169,8 @@ class PermutationBackend(Enum): # "casted_sorted_x_lhs_trans" ScaledTensor or ndarray # "casted_wi_0_rhs_trans" ScaledTensor or ndarray # "casted_wi_1_rhs_trans" ScaledTensor or ndarray -# "layer_w0" ndarray (pre-activation) -# "layer_w1" ndarray +# "gate_proj_out" ndarray (pre-activation) +# "up_proj_out" ndarray # "casted_intermediate_lhs_trans" ScaledTensor or ndarray # "casted_wo_rhs_trans" ScaledTensor or ndarray # "expert_outputs" ndarray (FFN output, needed for TRITON @@ -249,8 +251,6 @@ def _compute_static_shape_info( ``(recv_buffer_rows, hidden)`` when EP is active, ``None`` otherwise. """ - import math - if ep_active and not batch_is_per_shard: dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 per_shard_batch = batch_size // (num_ep * dp_size) @@ -976,12 +976,12 @@ def _body_fwd( if q_set_wo == noop_quantizer_set: wo = wo.astype(sorted_x.dtype) - # GEMM 1: layer_w0 = sorted_x @ wi_0 + # GEMM 1: gate_proj_out = sorted_x @ wi_0 casted_sorted_x_w0 = tex.grouped_quantize( sorted_x, q_set_w0.x, local_group_sizes, flatten_axis=-1 ) casted_wi_0 = tex.grouped_quantize(wi_0, q_set_w0.kernel, flatten_axis=-1) - layer_w0 = tex.grouped_gemm( + gate_proj_out = tex.grouped_gemm( casted_sorted_x_w0.get_tensor(usage=TensorUsage.LHS), casted_wi_0.get_tensor(usage=TensorUsage.RHS), contracting_dims=((1,), (1,)), @@ -994,12 +994,12 @@ def _body_fwd( if isinstance(casted_wi_0_rhs_trans, ScaledTensor): casted_wi_0_rhs_trans = casted_wi_0_rhs_trans.checkpoint(q_set_w0.kernel) - # GEMM 2: layer_w1 = sorted_x @ wi_1 + # GEMM 2: up_proj_out = sorted_x @ wi_1 casted_sorted_x_w1 = tex.grouped_quantize( sorted_x, q_set_w1.x, local_group_sizes, flatten_axis=-1 ) casted_wi_1 = tex.grouped_quantize(wi_1, q_set_w1.kernel, flatten_axis=-1) - layer_w1 = tex.grouped_gemm( + up_proj_out = tex.grouped_gemm( casted_sorted_x_w1.get_tensor(usage=TensorUsage.LHS), casted_wi_1.get_tensor(usage=TensorUsage.RHS), contracting_dims=((1,), (1,)), @@ -1009,9 +1009,9 @@ def _body_fwd( if isinstance(casted_wi_1_rhs_trans, ScaledTensor): casted_wi_1_rhs_trans = casted_wi_1_rhs_trans.checkpoint(q_set_w1.kernel) - # Activation: intermediate = act(layer_w0) * layer_w1 + # Activation: intermediate = act(gate_proj_out) * up_proj_out act_fn = _convert_to_activation_function(activation_type) - intermediate = act_fn(layer_w0) * layer_w1 + intermediate = act_fn(gate_proj_out) * up_proj_out # GEMM 3: expert_outputs = intermediate @ wo casted_intermediate = tex.grouped_quantize( @@ -1076,8 +1076,8 @@ def _body_fwd( "casted_sorted_x_lhs_trans": casted_sorted_x_lhs_trans, "casted_wi_0_rhs_trans": casted_wi_0_rhs_trans, "casted_wi_1_rhs_trans": casted_wi_1_rhs_trans, - "layer_w0": layer_w0, - "layer_w1": layer_w1, + "gate_proj_out": gate_proj_out, + "up_proj_out": up_proj_out, "casted_intermediate_lhs_trans": casted_intermediate_lhs_trans, "casted_wo_rhs_trans": casted_wo_rhs_trans, "expert_outputs": expert_outputs, @@ -1168,9 +1168,7 @@ def _body_bwd( # must use the per-shard shape rather than the captured global # ``x_shape``. if ep_active: - import math as _math # local import keeps the no-EP path zero-overhead. - - dp_size = _math.prod(fsdp_sizes) if fsdp_sizes else 1 + dp_size = math.prod(fsdp_sizes) if fsdp_sizes else 1 per_shard_batch = batch_size // (num_ep * dp_size) per_shard_x_shape: Tuple[int, ...] = (per_shard_batch, sequence_length, hidden) else: @@ -1215,45 +1213,45 @@ def _body_bwd( ) # ---------------- Activation bwd ---------------- - # intermediate = act(layer_w0) * layer_w1 - # d(layer_w0) = vjp(act, layer_w0)(d_intermediate * layer_w1) - # d(layer_w1) = d_intermediate * act(layer_w0) + # intermediate = act(gate_proj_out) * up_proj_out + # d(gate_proj_out) = vjp(act, gate_proj_out)(d_intermediate * up_proj_out) + # d(up_proj_out) = d_intermediate * act(gate_proj_out) act_fn = _convert_to_activation_function(activation_type) - act_w0, dact_w0_pullback = jax.vjp(act_fn, ctx["layer_w0"]) - d_layer_w1 = d_intermediate * act_w0 - (d_layer_w0,) = dact_w0_pullback(d_intermediate * ctx["layer_w1"]) + act_gate_proj_out, dact_gate_proj_pullback = jax.vjp(act_fn, ctx["gate_proj_out"]) + d_up_proj_out = d_intermediate * act_gate_proj_out + (d_gate_proj_out,) = dact_gate_proj_pullback(d_intermediate * ctx["up_proj_out"]) # ---------------- FFN bwd: GEMM 2 (wi_1) ---------------- - casted_d_layer_w1 = tex.grouped_quantize( - d_layer_w1, q_set_w1.dgrad, ctx["local_group_sizes"], flatten_axis=-1 + casted_d_up_proj_out = tex.grouped_quantize( + d_up_proj_out, q_set_w1.dgrad, ctx["local_group_sizes"], flatten_axis=-1 ) d_sorted_x_from_w1 = tex.grouped_gemm( - casted_d_layer_w1.get_tensor(usage=TensorUsage.LHS), + casted_d_up_proj_out.get_tensor(usage=TensorUsage.LHS), ctx["casted_wi_1_rhs_trans"], contracting_dims=((1,), (2,)), ) d_wi_1 = tex.grouped_gemm( ctx["casted_sorted_x_lhs_trans"], - casted_d_layer_w1.get_tensor(usage=TensorUsage.RHS), + casted_d_up_proj_out.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_1_bias = tex.grouped_dbias(d_layer_w1, ctx["local_group_sizes"]) if has_wi_bias else None + d_wi_1_bias = tex.grouped_dbias(d_up_proj_out, ctx["local_group_sizes"]) if has_wi_bias else None # ---------------- FFN bwd: GEMM 1 (wi_0) ---------------- - casted_d_layer_w0 = tex.grouped_quantize( - d_layer_w0, q_set_w0.dgrad, ctx["local_group_sizes"], flatten_axis=-1 + casted_d_gate_proj_out = tex.grouped_quantize( + d_gate_proj_out, q_set_w0.dgrad, ctx["local_group_sizes"], flatten_axis=-1 ) d_sorted_x_from_w0 = tex.grouped_gemm( - casted_d_layer_w0.get_tensor(usage=TensorUsage.LHS), + casted_d_gate_proj_out.get_tensor(usage=TensorUsage.LHS), ctx["casted_wi_0_rhs_trans"], contracting_dims=((1,), (2,)), ) d_wi_0 = tex.grouped_gemm( ctx["casted_sorted_x_lhs_trans"], - casted_d_layer_w0.get_tensor(usage=TensorUsage.RHS), + casted_d_gate_proj_out.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_0_bias = tex.grouped_dbias(d_layer_w0, ctx["local_group_sizes"]) if has_wi_bias else None + d_wi_0_bias = tex.grouped_dbias(d_gate_proj_out, ctx["local_group_sizes"]) if has_wi_bias else None d_sorted_x = d_sorted_x_from_w0 + d_sorted_x_from_w1 @@ -1495,8 +1493,8 @@ def _build_ctx_specs( "casted_sorted_x_lhs_trans": P(), "casted_wi_0_rhs_trans": P(ep_axis, None, None), "casted_wi_1_rhs_trans": P(ep_axis, None, None), - "layer_w0": P(), - "layer_w1": P(), + "gate_proj_out": P(), + "up_proj_out": P(), "casted_intermediate_lhs_trans": P(), "casted_wo_rhs_trans": P(ep_axis, None, None), "expert_outputs": P(), @@ -1534,20 +1532,8 @@ def _build_grads_specs( def _moe_fwd_rule( - # IMPORTANT — calling convention for jax.custom_vjp fwd rule. - # - # JAX uses ``_argnums_partial`` (jax/_src/api_util.py) when wiring up - # the fwd rule. That helper preserves the ORIGINAL positional order - # of the decorated function: dyn (= diff) args sit at their original - # positions and static (= nondiff) args fill the remaining slots in - # nondiff_argnums order. So the fwd rule MUST take args in the - # SAME positional order as ``_moe`` -- diff first (positions 0..8), - # then nondiff (positions 9..28), all POSITIONAL (no ``*,`` -- they - # arrive as positional, not as kwargs). - # - # NOTE: this is the OPPOSITE convention from ``_moe_bwd_rule``, which - # uses ``prepend_static_args`` -- there the static args come FIRST, - # followed by ``ctx`` and ``dy_pair``. + # Args MUST match the positional order of ``_moe`` (diff first, + # then nondiff). See ``_moe_bwd_rule`` for the opposite convention. x, gate_kernel, wi_0, @@ -1647,6 +1633,24 @@ def _moe_fwd_rule( raise ValueError(f"num_experts={num_experts} must be divisible by EP size={num_ep}") num_experts_local = num_experts // num_ep + # Reject overlapping EP / FSDP axes. Listing ep_axis in + # data_parallelism_axes would produce a duplicate-axis PartitionSpec + # ((ep, ep, ...)) which JAX rejects, and would also double-count + # num_ep in dp_size (under-sizing recv_buffer_rows by a factor of + # num_ep). Catch it up front with a clear error. + for ax in data_parallelism_axes: + if ax not in mesh.shape: + raise ValueError( + f"data_parallelism_axes contains {ax!r} but mesh has" + f" axes {tuple(mesh.shape.keys())}" + ) + if ax == ep_axis: + raise ValueError( + f"data_parallelism_axes={data_parallelism_axes!r} contains the EP" + f" axis {ep_axis!r}; EP is implicit in the batch sharding and must" + " not also be listed as a data-parallel axis." + ) + if not data_parallelism_axes: batch_pspec_axis: Any = ep_axis else: diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 157575a441..d9afae066b 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -1018,9 +1018,7 @@ def pure_jax_token_combine( # which is why every slice into ``all_shards_tokens_per_expert`` uses # :func:`jax.lax.dynamic_slice`. # -# These functions are pure JAX (no MaxText / TE dependencies) and equivalent -# to :func:`maxtext.layers.te_permutation.compute_ragged_all_to_all_params` -# / :func:`compute_reverse_ragged_all_to_all_params`. +# These functions are pure JAX (no TE-internal dependencies). def compute_ragged_all_to_all_params( From 2b69d72a06f4fdad69c751616f1293c3ef402427 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 May 2026 00:42:09 +0000 Subject: [PATCH 18/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/moe.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/moe.py b/transformer_engine/jax/moe.py index fe923a4914..8cf0ee1aec 100644 --- a/transformer_engine/jax/moe.py +++ b/transformer_engine/jax/moe.py @@ -1235,7 +1235,9 @@ def _body_bwd( casted_d_up_proj_out.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_1_bias = tex.grouped_dbias(d_up_proj_out, ctx["local_group_sizes"]) if has_wi_bias else None + d_wi_1_bias = ( + tex.grouped_dbias(d_up_proj_out, ctx["local_group_sizes"]) if has_wi_bias else None + ) # ---------------- FFN bwd: GEMM 1 (wi_0) ---------------- casted_d_gate_proj_out = tex.grouped_quantize( @@ -1251,7 +1253,9 @@ def _body_bwd( casted_d_gate_proj_out.get_tensor(usage=TensorUsage.RHS), contracting_dims=((0,), (0,)), ) - d_wi_0_bias = tex.grouped_dbias(d_gate_proj_out, ctx["local_group_sizes"]) if has_wi_bias else None + d_wi_0_bias = ( + tex.grouped_dbias(d_gate_proj_out, ctx["local_group_sizes"]) if has_wi_bias else None + ) d_sorted_x = d_sorted_x_from_w0 + d_sorted_x_from_w1