[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly#3011
[PyTorch] NVFP4 RHT cast-fusion: emit GEMM-swizzled scale factors directly#3011cael-ling wants to merge 5 commits into
Conversation
…ectly Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two standalone swizzle kernels (rowwise + columnwise) whose only job was to move scale factors into the layout cuBLAS LT consumes. The cast-fusion kernel already had a `kEnableSwizzleSFOutput` switch for that, but the framework never set the matching `with_gemm_swizzled_scales` flag on NVFP4 outputs -- it was a `false` with a TODO. This PR wires it through. Changes: * Single + grouped Hadamard cast-fusion kernels: drive `kEnableSwizzleSFOutput` from `output.with_gemm_swizzled_scales`. * NVFP4Quantizer create_tensor / convert_and_update_tensor / bulk_allocate_nvfp4_tensors: set the flag when `optimize_for_gemm && with_rht && shape eligible`, with eligibility in a new static helper NVFP4Quantizer::is_eligible_for_rht_cast_fusion (rows%64==0 && cols%128==0 && SM100/110) shared by all three sites. * Belt-and-suspenders NVTE_CHECK in quantize_with_rht_unfused_helper in case a future low-level caller bypasses the gate. The shape gate is part of this PR (not a follow-up) because LLaMA-class shapes like (8192, 11328) have K%128==64. Without the gate the framework would set the flag, dispatch would fall to the unfused path that can't emit swizzled SF, and the process would abort. With the gate, ineligible shapes silently fall back to the original code path. Numbers (GB200 SM100, bf16, rowwise+columnwise, RHT, per-quantize median, `quant + swizzle` path -- what te.Linear actually runs): (8192, 5120) 108.6 -> 81.9 us 1.33x eligible (8192, 11328) 236.3 -> 236.3 us 1.00x ineligible, gate clamped (11328, 8192) 114.4 -> 93.2 us 1.23x eligible (14336,16384) 232.1 -> 197.5 us 1.18x eligible 11/12 production-class shapes get 1.18x - 1.36x. The one ineligible shape gets 1.00x (= unchanged, no regression). `quant_only` is unchanged across all shapes -- the savings come entirely from eliminating the standalone swizzle pass, not from a faster quant kernel. Repro: benchmarks/benchmark_rht_cast_swizzle_fusion.py Tests: * new tests/pytorch/nvfp4/test_nvfp4_rht_quantize_swizzle_fusion.py: byte-equal SF / FP4 data / amax vs swizzled reference; plus 5 cases verifying the shape gate clamps correctly and that quantizer(x) on an ineligible shape does not raise. * tests/pytorch/nvfp4/test_nvfp4_group_quantize.py: added optimize_for_gemm parametrization for the legacy grouped path. * test_nvfp4_group_quantize_graph_safe.py passes unchanged (graph-safe variant already had the wiring). Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR wires the existing
Confidence Score: 5/5Safe to merge — the core eligibility gate, kernel dispatch, and group-consistency checks all look correct, and the unfused fallback path is protected by both the framework gate and a new defensive NVTE_CHECK. The logic is well-contained: No files require special attention for merge safety; the test guard omission in Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["NVFP4Quantizer(x)\noptimize_for_gemm=True\nwith_rht=True"] --> B["create_tensor / convert_and_update_tensor\nwith_gemm_swizzled_scales =\n optimize_for_gemm && with_rht &&\n is_eligible_for_rht_cast_fusion(shape)"]
B --> C{Shape eligible?\nrows%64==0 &&\ncols%128==0 &&\nSM 100..110}
C -- Yes --> D["with_gemm_swizzled_scales = true\n(output tensor pre-allocated)"]
C -- No --> E["with_gemm_swizzled_scales = false\n(output tensor pre-allocated)"]
D --> F["quantize_impl\neligible_for_rht_cast_fusion =\n dtype==BF16 && shape OK"]
E --> F
F --> G{eligible_for_rht_cast_fusion?}
G -- Yes --> H["hadamard_transform_cast_fusion\nkEnableSwizzleSFOutput driven by\noutput_.with_gemm_swizzled_scales\nEmits GEMM-swizzled SF directly"]
G -- No, with_rht=True --> I["quantize_with_rht_unfused_helper\nNVTE_CHECK(!swizzled)\nEmits compact SF"]
G -- No, with_rht=False --> J["nvte_quantize_v2\nEmits compact SF"]
H --> K["tex.swizzle_scales_for_gemm_(t)\nEarly-returns (no-op)\nSF already in GEMM layout"]
I --> L["tex.swizzle_scales_for_gemm_(t)\nLaunches standalone\nswizzle_{row,col}_scaling_kernel"]
J --> L
K --> M["cuBLAS LT NVFP4 GEMM\n1.18x-1.36x faster vs baseline"]
L --> M
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| const bool with_gemm_swizzled_scales = quantizer_cpp_list[0]->optimize_for_gemm && | ||
| quantizer_cpp_list[0]->with_rht && | ||
| all_tensors_rht_cast_fusion_eligible; |
There was a problem hiding this comment.
optimize_for_gemm and with_rht read only from first quantizer without validation
with_gemm_swizzled_scales is derived exclusively from quantizer_cpp_list[0], so if any later quantizer in the group has a different optimize_for_gemm or with_rht value, its tensors are silently allocated with the wrong SF layout. The shape-eligibility loop below correctly iterates every tensor, but there is no matching check that all quantizers agree on optimize_for_gemm/with_rht. The split-quantize path at line 1276 documents this assumption explicitly (// Assume all quantizers have identical config); the same note or an NVTE_CHECK loop here would make the contract visible and consistent.
There was a problem hiding this comment.
with_gemm_swizzled_scales was derived from quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking that other quantizers in the group agreed; if any later quantizer had a different value, its tensors would be silently allocated with the wrong SF layout.
Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:
- adds an explicit comment block calling out the group-wide
identical-config assumption and which fields this PR enforces
vs. which are pre-existing; - adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
and with_rht across the group (the two fields the
with_gemm_swizzled_scales gate depends on), with error messages
that print the offending tensor index and the disagreeing values; - extracts the [0] reads into group_optimize_for_gemm and
group_with_rht locals so the same value feeds both the check
and the gate.
Reviewer feedback: with_gemm_swizzled_scales was derived from
quantizer_cpp_list[0]->optimize_for_gemm / with_rht without checking
that other quantizers in the group agreed; if any later quantizer
had a different value, its tensors would be silently allocated with
the wrong SF layout.
Following the precedent of the split-quantize path at line 1276
(// Assume all quantizers have identical config), this commit:
* adds an explicit comment block calling out the group-wide
identical-config assumption and which fields this PR enforces
vs. which are pre-existing;
* adds an NVTE_CHECK loop enforcing identical optimize_for_gemm
and with_rht across the group (the two fields the
with_gemm_swizzled_scales gate depends on), with error messages
that print the offending tensor index and the disagreeing values;
* extracts the [0] reads into group_optimize_for_gemm and
group_with_rht locals so the same value feeds both the check
and the gate.
Other from-[0] reads (rowwise_usage, row_scaled_nvfp4,
columnwise_usage, scaling_mode, dtype) are pre-existing assumptions
and remain out of scope for this PR.
Signed-off-by: Cael Ling <caell@nvidia.com>
| // Quantization parameters. Like the NVFP4 split-quantize path | ||
| // (see split_quantize_nvfp4_impl in this file), we assume all | ||
| // quantizers in the group share an identical config and read | ||
| // group-wide flags from quantizer_cpp_list[0]. The grouped RHT | ||
| // cast-fusion kernel honours a single with_gemm_swizzled_scales | ||
| // boolean across the whole group, so optimize_for_gemm and with_rht | ||
| // must in particular agree across all quantizers; the NVTE_CHECK | ||
| // loop below enforces that for the fields the swizzled-SF gate | ||
| // depends on. (The other group-wide reads from [0] -- rowwise_usage, | ||
| // row_scaled_nvfp4, columnwise_usage, scaling_mode, dtype -- are | ||
| // pre-existing assumptions and out of scope for this PR.) |
There was a problem hiding this comment.
I'm not super happy about the style of those comments - they reference multiple other files, and while right now the comment matches the reality, it will easily drift. We should concentrate on commenting the invariants and assumptions needed for this file only.
There was a problem hiding this comment.
New commit removed the prose about dispatch internals and caller responsibilities.
| // Only the RHT cast-fusion quant kernel supports direct swizzled SF | ||
| // emission. Other NVFP4 quant kernels (e.g. nvte_quantize_v2 -> | ||
| // quantize_nvfp4.cuh, quantize_transpose_nvfp4.cuh) NVTE_CHECK reject | ||
| // a swizzled-flagged output, so we gate on with_rht to avoid silent | ||
| // data corruption / hard aborts on non-RHT paths. Additionally we | ||
| // require *all* tensors in the group to be shape-eligible for RHT | ||
| // cast-fusion, because the grouped kernel honours a single boolean | ||
| // and the unfused fallback rejects swizzled output (see NVTE_CHECK | ||
| // at group_row_cast_col_hadamard_transform_cast_fusion.cu and | ||
| // quantize_with_rht_unfused_helper). |
There was a problem hiding this comment.
New commit removed the prose about dispatch internals and caller responsibilities.
| * Matches the dispatch logic in NVFP4Quantizer::quantize_impl. | ||
| * The dtype check (BF16) is implicit -- with_rht=True requires | ||
| * BF16 input by construction, so callers gate on with_rht first. | ||
| * When false, the dispatch falls back to quantize_with_rht_unfused | ||
| * which cannot emit GEMM-swizzled SF; framework gates that opt | ||
| * into with_gemm_swizzled_scales must therefore also check this | ||
| * to avoid mismatched-flag aborts in the fallback path. |
There was a problem hiding this comment.
Again, this is mostly talking about the internal implementation choices rather than what that function actually does (which is covered by the first sentence).
There was a problem hiding this comment.
New commit removed the prose about dispatch internals and caller responsibilities.
| * into with_gemm_swizzled_scales must therefore also check this | ||
| * to avoid mismatched-flag aborts in the fallback path. | ||
| */ | ||
| static bool is_eligible_for_rht_cast_fusion(size_t rows, size_t cols); |
There was a problem hiding this comment.
Shouldn't it take arbitrary shape rather than assuming it will be 2D?
There was a problem hiding this comment.
Good point. Changed the signature to take the full tensor shape (const std::vector<size_t>& shape) and moved the get_2d_dims(...) flatten inside the function. All four call sites (create_tensor, convert_and_update_tensor, quantize_impl, and the grouped path in cast.cpp) now pass the shape directly without pre-flattening. The bulk loop in cast.cpp also no longer calls get_2d_dims per iteration since the function takes care of it.
| // Must mirror the eligibility check in NVFP4Quantizer::quantize_impl | ||
| // (search for "eligible_for_rht_cast_fusion" in this file). The dtype | ||
| // check (BF16) is implicit: with_rht is only valid for BF16 input by | ||
| // construction. |
There was a problem hiding this comment.
Why does it have to mirror the check in that other function? Considering that both of these functions are in the same file and in the same class, can't we just call one from the other to keep a single source of truth?
There was a problem hiding this comment.
Correct, the proper fix is to call one from the other. quantize_impl now delegates the shape/arch predicate to NVFP4Quantizer::is_eligible_for_rht_cast_fusion(...) instead of re-inlining the same check. The BF16 dtype guard stays as an explicit && at the call site because it's specific to quantize_impl (the allocation callers don't have an input tensor to check). I also replaced the hand-rolled rows = product(input.shape[:-1]) loop with get_2d_dims(input.shape()) so the flattening rule isn't duplicated either. The shape/arch eligibility now has a single source of truth.
| // neither of which supports emitting SF in the GEMM-swizzled layout (their | ||
| // backing kernels NVTE_CHECK reject swizzled-flagged output). Surface a clean | ||
| // error here instead of letting it abort deep inside the kernel with an | ||
| // opaque message. JAX hard-asserts eligibility upfront; PyTorch matches that |
There was a problem hiding this comment.
Why do we mention JAX in the PyTorch source files?
There was a problem hiding this comment.
New commit dropped the JAX reference and the surrounding narration. The remaining 2-line comment just explains why this NVTE_CHECK is here.
|
Please also handle the convert_and_update_tensor path since it also needs changes. |
| bool all_tensors_rht_cast_fusion_eligible = true; | ||
| for (size_t i = 0; i < num_tensors; ++i) { | ||
| const auto [rows, cols] = get_2d_dims(shape_list[i]); | ||
| if (!NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)) { |
There was a problem hiding this comment.
The grouped kernel that supports the swizzle will only run for rows being divisible by 128, but this function will allow tensors divisible by 64.
There was a problem hiding this comment.
Good catch — this was a real bug, not just an over-permissive style.
Before this fix, is_eligible_for_rht_cast_fusion(shape) used a single row-alignment constraint of rows % 64 == 0 (the single-tensor RHT cast-fusion kernel's entry check at row_cast_col_hadamard_transform_cast_fusion.cu:1161). The bulk-allocation path in cast.cpp was calling this same lax check, so shapes like rows in {64, 192, 320, ...} — all satisfying % 64 == 0 — would pass eligibility, get with_gemm_swizzled_scales=True, and then hard-abort inside the grouped kernel whose entry asserts first_logical_dim % 128 == 0
(graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385).
The fix adds a for_grouped_kernel parameter on is_eligible_for_rht_cast_fusion so callers select the constraint
that matches the kernel they will actually invoke:
false(default):rows % 64 == 0, single-tensor kerneltrue:rows % 128 == 0, grouped kernel
The bulk-allocation caller in cast.cpp passes /*for_grouped_kernel=*/true; the three single-tensor callers
(create_tensor, convert_and_update_tensor, quantize_impl) keep the default false. Shapes with rows in {64, 192, 320, ...} now correctly fail the grouped-path eligibility and fall back to the unfused path instead of reaching the grouped kernel.
| // (search for "eligible_for_rht_cast_fusion" in this file). The dtype | ||
| // check (BF16) is implicit: with_rht is only valid for BF16 input by | ||
| // construction. | ||
| return rows % 64 == 0 && cols % 128 == 0 && transformer_engine::cuda::sm_arch() >= 100 && |
There was a problem hiding this comment.
Why is the rows % 64 == 0 a requirement here rather than rows % 128 == 0?
There was a problem hiding this comment.
The 64 here is correct for the single-tensor cast-fusion kernel — its entry check is NVTE_CHECK(M % 64 == 0, ...) at row_cast_col_hadamard_transform_cast_fusion.cu:1161. The 128 you're thinking of is the grouped kernel's stricter requirement at graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu:1385.
Functional fix:
- `bulk_allocate_nvfp4_tensors` previously used the single-tensor RHT
eligibility check (`rows % 64 == 0`), but the grouped kernel asserts
`first_logical_dim % 128 == 0` at entry. Shapes with rows in
{64, 192, 320, ...} would pass eligibility, set
`with_gemm_swizzled_scales=True`, and then hard-abort inside the
grouped kernel with an opaque NVTE_CHECK message. Adding a
`for_grouped_kernel` parameter on `is_eligible_for_rht_cast_fusion`
selects the correct row alignment: 64 for the single-tensor kernel,
128 for the grouped variant. Only the bulk-allocation caller passes
`true`; the three single-tensor callers keep the default `false`.
Refactors:
- `is_eligible_for_rht_cast_fusion` now takes the full tensor shape
(`std::vector<size_t>`) and flattens internally with `get_2d_dims`,
so the four call sites no longer pre-flatten and duplicate the
flatten rule.
- `quantize_impl` delegates the shape/arch eligibility to
`is_eligible_for_rht_cast_fusion` instead of inlining the same
predicate, and its hand-rolled `rows = product(shape[:-1])` loop is
replaced with `get_2d_dims(input.shape())`. The shape/arch
eligibility now has a single source of truth.
Comment cleanups:
- Trimmed verbose comments in `bulk_allocate_nvfp4_tensors`,
`create_tensor`, `convert_and_update_tensor`, and
`quantize_with_rht_unfused_helper`. Removed cross-references to
other functions/files, code narration of subsequent lines, the JAX
reference in PyTorch source, and the "see X for rationale" pattern.
- Doxygen on `is_eligible_for_rht_cast_fusion` reduced to a single
brief sentence.
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Done. Both |
Description
Before this PR every NVFP4 RHT-cast-fusion quantize was followed by two standalone swizzle kernels (rowwise + columnwise) whose only job was to move scale factors into the layout cuBLAS LT consumes. The cast-fusion kernel already had a
kEnableSwizzleSFOutputswitch for that, but the framework never set the matchingwith_gemm_swizzled_scalesflag onNVFP4 outputs -- it was a
falsewith aTODO. This PR wires it through and saves ~25 us per quantize on LLaMA-class shapes (1.18x – 1.36x on thequant + swizzlepath thatte.Linearruns).Type of change
Changes
Kernel side (
transformer_engine/common/hadamard_transform/):row_cast_col_hadamard_transform_cast_fusion.cu&group_row_cast_col_hadamard_transform_cast_fusion.cu: drive theexisting
kEnableSwizzleSFOutputtemplate parameter fromoutput.with_gemm_swizzled_scales. The grouped kernel additionallyNVTE_CHECKs the flag is consistent across all tensors in a group(it honours a single boolean).
no change.
Framework side (
transformer_engine/pytorch/csrc/):NVFP4Quantizer::is_eligible_for_rht_cast_fusion(rows, cols)mirroring the dispatch-time eligibility check in
NVFP4Quantizer::quantize_impl(rows%64==0 && cols%128==0 && SM100/110).NVFP4Quantizer::create_tensor,NVFP4Quantizer::convert_and_update_tensor,and
bulk_allocate_nvfp4_tensorsnow setwith_gemm_swizzled_scales = optimize_for_gemm && with_rht && shape_eligible.For the grouped allocator the flag is True only if every tensor in
the group is eligible.
NVTE_CHECK(!out.with_gemm_swizzled_scales)atthe entry of
quantize_with_rht_unfused_helper. The framework gatealready keeps user code from tripping it; this only fires if a future
low-level caller bypasses the gate.
Performance
SM100a, bf16 input, rowwise + columnwise SF, RHT + post-RHT amax.
Per-quantize wall-clock median via
torch.utils.benchmark.Timer.blocked_autorange.quant + swizzle=quantizer(x); tex.swizzle_scales_for_gemm_(t)--exactly what
te.Linearruns before its GEMM.quant + swizzlepath.(8192, 11328)shows 1.00x as expected;the gate clamped, the unfused fallback ran, and the result is byte-
identical to baseline (no regression, no crash).
quant_onlyis unchanged on all shapes within noise -- writingswizzled SF inside the cast-fusion kernel is essentially free; the
entire win comes from eliminating the standalone swizzle pass.
Repro:
benchmarks/benchmark_rht_cast_swizzle_fusion.py(also has a--profilemode for ncu / nsys).Checklist: