Skip to content

[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly#3011

Open
cael-ling wants to merge 5 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-rht-cast-fusion-swizzled-sf-output
Open

[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly#3011
cael-ling wants to merge 5 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-rht-cast-fusion-swizzled-sf-output

Conversation

@cael-ling
Copy link
Copy Markdown
Contributor

@cael-ling cael-ling commented May 19, 2026

Description

Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two standalone swizzle kernels (rowwise + columnwise) whose only job was to move scale factors into the layout cuBLAS LT consumes. The cast-fusion kernel already had a kEnableSwizzleSFOutput switch for that, but the framework never set the matching with_gemm_swizzled_scales flag on
NVFP4 outputs -- it was a false with a TODO. This PR wires it through and saves ~25 us per quantize on LLaMA-class shapes (1.18x – 1.36x on the quant + swizzle path that te.Linear runs).

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Kernel side (transformer_engine/common/hadamard_transform/):

  • row_cast_col_hadamard_transform_cast_fusion.cu &
    group_row_cast_col_hadamard_transform_cast_fusion.cu: drive the
    existing kEnableSwizzleSFOutput template parameter from
    output.with_gemm_swizzled_scales. The grouped kernel additionally
    NVTE_CHECKs the flag is consistent across all tensors in a group
    (it honours a single boolean).
  • The graph-safe grouped variant already had this wired correctly --
    no change.

Framework side (transformer_engine/pytorch/csrc/):

  • New static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)
    mirroring the dispatch-time eligibility check in
    NVFP4Quantizer::quantize_impl (rows%64==0 && cols%128==0 && SM100/110).
  • NVFP4Quantizer::create_tensor, NVFP4Quantizer::convert_and_update_tensor,
    and bulk_allocate_nvfp4_tensors now set
    with_gemm_swizzled_scales = optimize_for_gemm && with_rht && shape_eligible.
    For the grouped allocator the flag is True only if every tensor in
    the group is eligible.
  • Belt-and-suspenders NVTE_CHECK(!out.with_gemm_swizzled_scales) at
    the entry of quantize_with_rht_unfused_helper. The framework gate
    already keeps user code from tripping it; this only fires if a future
    low-level caller bypasses the gate.

Performance

SM100a, bf16 input, rowwise + columnwise SF, RHT + post-RHT amax.
Per-quantize wall-clock median via torch.utils.benchmark.Timer.blocked_autorange.
quant + swizzle = quantizer(x); tex.swizzle_scales_for_gemm_(t) --
exactly what te.Linear runs before its GEMM.

shape baseline SUT saved speedup note
(8192, 5120) 108.6 us 81.9 us 26.6 us 1.33x eligible
(8192, 10240) 107.8 us 90.2 us 17.5 us 1.19x eligible
(8192, 2560) 107.7 us 79.9 us 27.8 us 1.35x eligible
(8192, 11328) 236.3 us 236.3 us 0.0 us 1.00x ineligible
(8192, 3584) 106.0 us 78.6 us 27.4 us 1.35x eligible
(5120, 8192) 101.2 us 76.0 us 25.3 us 1.33x eligible
(10240, 8192) 107.8 us 90.4 us 17.4 us 1.19x eligible
(2560, 8192) 101.4 us 74.9 us 26.4 us 1.35x eligible
(11328, 8192) 114.4 us 93.2 us 21.2 us 1.23x eligible
(3584, 8192) 101.6 us 74.9 us 26.7 us 1.36x eligible
(4096, 16384) 100.2 us 75.0 us 25.2 us 1.34x eligible
(14336, 16384) 232.1 us 197.5 us 34.6 us 1.18x eligible
  • 11/12 shapes get 1.18x – 1.36x on the quant + swizzle path.
  • The single ineligible shape (8192, 11328) shows 1.00x as expected;
    the gate clamped, the unfused fallback ran, and the result is byte-
    identical to baseline (no regression, no crash).
  • quant_only is unchanged on all shapes within noise -- writing
    swizzled SF inside the cast-fusion kernel is essentially free; the
    entire win comes from eliminating the standalone swizzle pass.
    Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py (also has a
    --profile mode for ncu / nsys).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…ectly

Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two
standalone swizzle kernels (rowwise + columnwise) whose only job was to
move scale factors into the layout cuBLAS LT consumes. The cast-fusion
kernel already had a `kEnableSwizzleSFOutput` switch for that, but the
framework never set the matching `with_gemm_swizzled_scales` flag on
NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through.

Changes:
* Single + grouped Hadamard cast-fusion kernels: drive
  `kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`.
* NVFP4Quantizer create_tensor / convert_and_update_tensor /
  bulk_allocate_nvfp4_tensors: set the flag when
  `optimize_for_gemm && with_rht && shape eligible`, with eligibility
  in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion
  (rows%64==0 && cols%128==0 && SM100/110) shared by all three sites.
* Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper
  in case a future low-level caller bypasses the gate.

The shape gate is part of this PR (not a follow-up) because LLaMA-class
shapes like (8192, 11328) have K%128==64. Without the gate the framework
would set the flag, dispatch would fall to the unfused path that can't
emit swizzled SF, and the process would abort. With the gate, ineligible
shapes silently fall back to the original code path.

Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median,
`quant + swizzle` path -- what te.Linear actually runs):

  (8192,  5120)    108.6 ->  81.9 us   1.33x   eligible
  (8192, 11328)    236.3 -> 236.3 us   1.00x   ineligible, gate clamped
  (11328, 8192)    114.4 ->  93.2 us   1.23x   eligible
  (14336,16384)    232.1 -> 197.5 us   1.18x   eligible

11/12 production-class shapes get 1.18x - 1.36x. The one ineligible
shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged
across all shapes -- the savings come entirely from eliminating the
standalone swizzle pass, not from a faster quant kernel.

Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py

Tests:
* new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py:
  byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases
  verifying the shape gate clamps correctly and that quantizer(x) on an
  ineligible shape does not raise.
* tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added
  optimize_for_gemm parametrization for the legacy grouped path.
* test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe
  variant already had the wiring).

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ksivaman as a code owner May 19, 2026 03:49
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 19, 2026

Greptile Summary

This PR wires the existing kEnableSwizzleSFOutput template switch in the NVFP4 RHT cast-fusion kernels to the per-tensor with_gemm_swizzled_scales flag, removing the two standalone swizzle kernel launches that followed every quantize call in te.Linear. The framework side adds a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion and plumbs the flag through create_tensor, convert_and_update_tensor, and bulk_allocate_nvfp4_tensors, yielding a 1.18x–1.36x speedup on the quant + swizzle path for eligible LLaMA-class shapes.

  • Kernel side (row_cast_col_hadamard_transform_cast_fusion.cu, group_row_cast_col_hadamard_transform_cast_fusion.cu): replaces the hardcoded false / TODO with output.with_gemm_swizzled_scales; the grouped kernel adds an NVTE_CHECK that all tensors in a group share the same flag.
  • Framework side (quantizer.cpp, cast.cpp): introduces is_eligible_for_rht_cast_fusion (shape + SM100/110 check), gates with_gemm_swizzled_scales on optimize_for_gemm && with_rht && shape_eligible, and adds a belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper to catch any future misconfigured low-level callers.
  • Tests / benchmarks: new single-tensor fidelity test (test_nvfp4_rht_quantize_swizzle_fusion.py), extended group-quantize test with optimize_for_gemm parametrization, plus two benchmark scripts.

Confidence Score: 5/5

Safe to merge — the core eligibility gate, kernel dispatch, and group-consistency checks all look correct, and the unfused fallback path is protected by both the framework gate and a new defensive NVTE_CHECK.

The logic is well-contained: is_eligible_for_rht_cast_fusion correctly mirrors the dispatch condition in quantize_impl, the group allocator enforces consistent quantizer configs before ANDing in per-shape eligibility, and the kernel side simply reads the flag already set on the tensor. The one unresolved issue (missing skipif on test_nvfp4_rht_swizzle_fusion_shape_gate) was flagged in a prior review cycle and does not affect production correctness.

No files require special attention for merge safety; the test guard omission in test_nvfp4_rht_quantize_swizzle_fusion.py is a CI hygiene concern rather than a correctness risk.

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/quantizer.cpp Core framework change: new is_eligible_for_rht_cast_fusion helper and with_gemm_swizzled_scales wiring in create_tensor, convert_and_update_tensor, and quantize_with_rht_unfused_helper; logic is correct and the earlier NVTE_ERROR for non-BF16 RHT inputs prevents any mismatch from reaching the new defensive check.
transformer_engine/pytorch/csrc/extensions/cast.cpp Grouped-path with_gemm_swizzled_scales logic: correctly validates that all quantizers agree on optimize_for_gemm/with_rht, and per-shape eligibility uses the for_grouped_kernel=true (128-row alignment) variant; the previous P1 concern about missing validation was addressed by the author.
transformer_engine/common/hadamard_transform/row_cast_col_hadamard_transform_cast_fusion.cu Removes the hardcoded false TODO and reads output_.with_gemm_swizzled_scales to drive the existing kEnableSwizzleSFOutput compile-time switch; minimal and correct change.
transformer_engine/common/hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu Replaces hardcoded false with output_list[0]->with_gemm_swizzled_scales and adds an NVTE_CHECK loop enforcing all outputs share the same flag, consistent with the single-launch constraint.
tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py New single-tensor fidelity test; test_nvfp4_rht_quantize_swizzle_fusion is correctly guarded with @pytest.mark.skipif, but test_nvfp4_rht_swizzle_fusion_shape_gate is missing the same guard (flagged in a prior review cycle).
tests/pytorch/nvfp4/test_nvfp4_group_quantize.py Adds optimize_for_gemm parametrization with well-justified skip conditions for the unfused/ineligible-shape fallback; swizzle reference applied to cropped-but-no-padding-needed scale tensors is correct for the constrained shape set.
transformer_engine/pytorch/csrc/common.h Adds the is_eligible_for_rht_cast_fusion static method declaration; straightforward header addition.
benchmarks/benchmark_rht_cast_swizzle_fusion.py New benchmark script comparing quant-only vs quant+swizzle paths across production shapes; clean and well-documented.
benchmarks/profile_rht_cast_swizzle_fusion.py New profiling/verification script using torch.profiler to confirm standalone swizzle kernels disappear; the out-of-order import re after a function definition was flagged in a prior review cycle.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["NVFP4Quantizer(x)\noptimize_for_gemm=True\nwith_rht=True"] --> B["create_tensor / convert_and_update_tensor\nwith_gemm_swizzled_scales =\n  optimize_for_gemm && with_rht &&\n  is_eligible_for_rht_cast_fusion(shape)"]

    B --> C{Shape eligible?\nrows%64==0 &&\ncols%128==0 &&\nSM 100..110}

    C -- Yes --> D["with_gemm_swizzled_scales = true\n(output tensor pre-allocated)"]
    C -- No --> E["with_gemm_swizzled_scales = false\n(output tensor pre-allocated)"]

    D --> F["quantize_impl\neligible_for_rht_cast_fusion =\n  dtype==BF16 && shape OK"]
    E --> F

    F --> G{eligible_for_rht_cast_fusion?}

    G -- Yes --> H["hadamard_transform_cast_fusion\nkEnableSwizzleSFOutput driven by\noutput_.with_gemm_swizzled_scales\nEmits GEMM-swizzled SF directly"]
    G -- No, with_rht=True --> I["quantize_with_rht_unfused_helper\nNVTE_CHECK(!swizzled)\nEmits compact SF"]
    G -- No, with_rht=False --> J["nvte_quantize_v2\nEmits compact SF"]

    H --> K["tex.swizzle_scales_for_gemm_(t)\nEarly-returns (no-op)\nSF already in GEMM layout"]
    I --> L["tex.swizzle_scales_for_gemm_(t)\nLaunches standalone\nswizzle_{row,col}_scaling_kernel"]
    J --> L

    K --> M["cuBLAS LT NVFP4 GEMM\n1.18x-1.36x faster vs baseline"]
    L --> M
Loading

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +751 to +753
const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm &&
quantizer_cpp_list[0]->with_rht &&
all_tensors_rht_cast_fusion_eligible;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 optimize_for_gemm and with_rht read only from first quantizer without validation

with_gemm_swizzled_scales is derived exclusively from quantizer_cpp_list[0], so if any later quantizer in the group has a different optimize_for_gemm or with_rht value, its tensors are silently allocated with the wrong SF layout. The shape-eligibility loop below correctly iterates every tensor, but there is no matching check that all quantizers agree on optimize_for_gemm/with_rht. The split-quantize path at line 1276 documents this assumption explicitly (// Assume all quantizers have identical config); the same note or an NVTE_CHECK loop here would make the contract visible and consistent.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with_gemm_swizzled_scales was derived from quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking that other quantizers in the group agreed; if any later quantizer had a different value, its tensors would be silently allocated with the wrong SF layout.

Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:

  • adds an explicit comment block calling out the group-wide
    identical-config assumption and which fields this PR enforces
    vs. which are pre-existing;
  • adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
    and with_rht across the group (the two fields the
    with_gemm_swizzled_scales gate depends on), with error messages
    that print the offending tensor index and the disagreeing values;
  • extracts the [0] reads into group_optimize_for_gemm and
    group_with_rht locals so the same value feeds both the check
    and the gate.

Reviewer feedback: with_gemm_swizzled_scales was derived from
quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking
that other quantizers in the group agreed; if any later quantizer
had a different value, its tensors would be silently allocated with
the wrong SF layout.
Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:
  * adds an explicit comment block calling out the group-wide
    identical-config assumption and which fields this PR enforces
    vs. which are pre-existing;
  * adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
    and with_rht across the group (the two fields the
    with_gemm_swizzled_scales gate depends on), with error messages
    that print the offending tensor index and the disagreeing values;
  * extracts the [0] reads into group_optimize_for_gemm and
    group_with_rht locals so the same value feeds both the check
    and the gate.
Other from-[0] reads (rowwise_usage, row_scaled_nvfp4,
columnwise_usage, scaling_mode, dtype) are pre-existing assumptions
and remain out of scope for this PR.
Signed-off-by: Cael Ling <caell@nvidia.com>
Comment on lines +722 to +732
// Quantization parameters. Like the NVFP4 split-quantize path
// (see split_quantize_nvfp4_impl in this file), we assume all
// quantizers in the group share an identical config and read
// group-wide flags from quantizer_cpp_list[0]. The grouped RHT
// cast-fusion kernel honours a single with_gemm_swizzled_scales
// boolean across the whole group, so optimize_for_gemm and with_rht
// must in particular agree across all quantizers; the NVTE_CHECK
// loop below enforces that for the fields the swizzled-SF gate
// depends on. (The other group-wide reads from [0] -- rowwise_usage,
// row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype -- are
// pre-existing assumptions and out of scope for this PR.)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super happy about the style of those comments - they reference multiple other files, and while right now the comment matches the reality, it will easily drift. We should concentrate on commenting the invariants and assumptions needed for this file only.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit removed the prose about dispatch internals and caller responsibilities.

Comment on lines +744 to +753
// Only the RHT cast-fusion quant kernel supports direct swizzled SF
// emission. Other NVFP4 quant kernels (e.g. nvte_quantize_v2 ->
// quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject
// a swizzled-flagged output, so we gate on with_rht to avoid silent
// data corruption / hard aborts on non-RHT paths. Additionally we
// require *all* tensors in the group to be shape-eligible for RHT
// cast-fusion, because the grouped kernel honours a single boolean
// and the unfused fallback rejects swizzled output (see NVTE_CHECK
// at group_row_cast_col_hadamard_transform_cast_fusion.cu and
// quantize_with_rht_unfused_helper).
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit removed the prose about dispatch internals and caller responsibilities.

Comment on lines +377 to +383
* Matches the dispatch logic in NVFP4Quantizer::quantize_impl.
* The dtype check (BF16) is implicit -- with_rht=True requires
* BF16 input by construction, so callers gate on with_rht first.
* When false, the dispatch falls back to quantize_with_rht_unfused
* which cannot emit GEMM-swizzled SF; framework gates that opt
* into with_gemm_swizzled_scales must therefore also check this
* to avoid mismatched-flag aborts in the fallback path.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this is mostly talking about the internal implementation choices rather than what that function actually does (which is covered by the first sentence).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit removed the prose about dispatch internals and caller responsibilities.

* into with_gemm_swizzled_scales must therefore also check this
* to avoid mismatched-flag aborts in the fallback path.
*/
static bool is_eligible_for_rht_cast_fusion(size_t rows, size_t cols);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it take arbitrary shape rather than assuming it will be 2D?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Changed the signature to take the full tensor shape (const std::vector<size_t>& shape) and moved the get_2d_dims(...) flatten inside the function. All four call sites (create_tensor, convert_and_update_tensor, quantize_impl, and the grouped path in cast.cpp) now pass the shape directly without pre-flattening. The bulk loop in cast.cpp also no longer calls get_2d_dims per iteration since the function takes care of it.

Comment on lines +1764 to +1767
// Must mirror the eligibility check in NVFP4Quantizer::quantize_impl
// (search for "eligible_for_rht_cast_fusion" in this file). The dtype
// check (BF16) is implicit: with_rht is only valid for BF16 input by
// construction.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does it have to mirror the check in that other function? Considering that both of these functions are in the same file and in the same class, can't we just call one from the other to keep a single source of truth?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, the proper fix is to call one from the other. quantize_impl now delegates the shape/arch predicate to NVFP4Quantizer::is_eligible_for_rht_cast_fusion(...) instead of re-inlining the same check. The BF16 dtype guard stays as an explicit && at the call site because it's specific to quantize_impl (the allocation callers don't have an input tensor to check). I also replaced the hand-rolled rows = product(input.shape[:-1]) loop with get_2d_dims(input.shape()) so the flattening rule isn't duplicated either. The shape/arch eligibility now has a single source of truth.

// neither of which supports emitting SF in the GEMM-swizzled layout (their
// backing kernels NVTE_CHECK reject swizzled-flagged output). Surface a clean
// error here instead of letting it abort deep inside the kernel with an
// opaque message. JAX hard-asserts eligibility upfront; PyTorch matches that
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we mention JAX in the PyTorch source files?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New commit dropped the JAX reference and the surrounding narration. The remaining 2-line comment just explains why this NVTE_CHECK is here.

@ptrendx
Copy link
Copy Markdown
Member

ptrendx commented May 19, 2026

Please also handle the convert_and_update_tensor path since it also needs changes.

bool all_tensors_rht_cast_fusion_eligible = true;
for (size_t i = 0; i < num_tensors; ++i) {
const auto [rows, cols] = get_2d_dims(shape_list[i]);
if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The grouped kernel that supports the swizzle will only run for rows being divisible by 128, but this function will allow tensors divisible by 64.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — this was a real bug, not just an over-permissive style.

Before this fix, is_eligible_for_rht_cast_fusion(shape) used a single row-alignment constraint of rows % 64 == 0 (the single-tensor RHT cast-fusion kernel's entry check at row_cast_col_hadamard_transform_cast_fusion.cu:1161). The bulk-allocation path in cast.cpp was calling this same lax check, so shapes like rows in {64, 192, 320, ...} — all satisfying % 64 == 0 — would pass eligibility, get with_gemm_swizzled_scales=True, and then hard-abort inside the grouped kernel whose entry asserts first_logical_dim % 128 == 0
(graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385).

The fix adds a for_grouped_kernel parameter on is_eligible_for_rht_cast_fusion so callers select the constraint
that matches the kernel they will actually invoke:

  • false (default): rows % 64 == 0, single-tensor kernel
  • true: rows % 128 == 0, grouped kernel

The bulk-allocation caller in cast.cpp passes /*for_grouped_kernel=*/true; the three single-tensor callers
(create_tensor, convert_and_update_tensor, quantize_impl) keep the default false. Shapes with rows in {64, 192, 320, ...} now correctly fail the grouped-path eligibility and fall back to the unfused path instead of reaching the grouped kernel.

// (search for "eligible_for_rht_cast_fusion" in this file). The dtype
// check (BF16) is implicit: with_rht is only valid for BF16 input by
// construction.
return rows % 64 == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 &&
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the rows % 64 == 0 a requirement here rather than rows % 128 == 0?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 64 here is correct for the single-tensor cast-fusion kernel — its entry check is NVTE_CHECK(M % 64 == 0, ...) at row_cast_col_hadamard_transform_cast_fusion.cu:1161. The 128 you're thinking of is the grouped kernel's stricter requirement at graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385.

cael-ling and others added 2 commits May 19, 2026 20:14
Functional fix:
- `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT
  eligibility check (`rows % 64 == 0`), but the grouped kernel asserts
  `first_logical_dim % 128 == 0` at entry. Shapes with rows in
  {64, 192, 320, ...} would pass eligibility, set
  `with_gemm_swizzled_scales=True`, and then hard-abort inside the
  grouped kernel with an opaque NVTE_CHECK message. Adding a
  `for_grouped_kernel` parameter on `is_eligible_for_rht_cast_fusion`
  selects the correct row alignment: 64 for the single-tensor kernel,
  128 for the grouped variant. Only the bulk-allocation caller passes
  `true`; the three single-tensor callers keep the default `false`.
Refactors:
- `is_eligible_for_rht_cast_fusion` now takes the full tensor shape
  (`std::vector<size_t>`) and flattens internally with `get_2d_dims`,
  so the four call sites no longer pre-flatten and duplicate the
  flatten rule.
- `quantize_impl` delegates the shape/arch eligibility to
  `is_eligible_for_rht_cast_fusion` instead of inlining the same
  predicate, and its hand-rolled `rows = product(shape[:-1])` loop is
  replaced with `get_2d_dims(input.shape())`. The shape/arch
  eligibility now has a single source of truth.
Comment cleanups:
- Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`,
  `create_tensor`, `convert_and_update_tensor`, and
  `quantize_with_rht_unfused_helper`. Removed cross-references to
  other functions/files, code narration of subsequent lines, the JAX
  reference in PyTorch source, and the "see X for rationale" pattern.
- Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single
  brief sentence.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling
Copy link
Copy Markdown
Contributor Author

Please also handle the convert_and_update_tensor path since it also needs changes.

Done. Both create_tensor and convert_and_update_tensor now have the same 2-line comment on the gating; removed the previous "See NVFP4Quantizer::create_tensor for the rationale" cross-reference. I also trimmed create_tensor's long rationale block (which referenced specific .cu/.cuh filenames and quantize_with_rht_unfused's internal behavior) in the same pass, so the two functions are consistent.

@cael-ling cael-ling requested a review from ptrendx May 21, 2026 01:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants