GGEMM+srelu kernels for MxFP8 Nemotron #2981
/te-ci pytorch
Please sign off your commits @sraman-rgb
Greptile Summary
This PR refactors the fused GroupedMLP kernel hierarchy into a shared base class and adds
Confidence Score: 5/5
The refactor is well-structured and the SReLU kernel wiring follows the established GLU pattern closely; the two flagged items are clarifying questions rather than confirmed failures. The class hierarchy generalisation is clean, dscales_tensor is always an allocated tensor, the recompute-FC2-input path is guarded by multiple independent checks, and test coverage spans both unit-level ScaledSReLU and the full grouped-MLP integration.
forward_grouped_mlp.py (prob_tensor dtype) and _common.py (_nvidia_cudnn_frontend_supports_wgrad guard)
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Fuser
participant GLUFwd as ForwardGroupedMLP_CuTeGEMMGLU_MXFP8
participant SReLUFwd as ForwardGroupedMLP_CuTeGEMMUnary_MXFP8
participant SReLUBwd as BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8
participant cuDNN as cuDNN FE Kernels
Fuser->>GLUFwd: fuse_forward_ops GLU pattern
GLUFwd->>cuDNN: grouped_gemm_glu_wrapper_sm100
cuDNN-->>GLUFwd: fc2_in scales and activation_in
Fuser->>SReLUFwd: fuse_forward_srelu_ops SReLU pattern
SReLUFwd->>cuDNN: grouped_gemm_srelu_wrapper_sm100
cuDNN-->>SReLUFwd: fc2_in scales and activation_in
Note over SReLUFwd: Save activation_in and scales
Note over SReLUFwd: optionally skip saving fc2_x
Fuser->>SReLUBwd: fuse_backward_srelu_ops
SReLUBwd->>cuDNN: grouped_gemm_dsrelu_wrapper_sm100
cuDNN-->>SReLUBwd: FC1 dy tensors and grad_scales
cuDNN-->>SReLUBwd: optional recomputed FC2 input
SReLUBwd->>cuDNN: grouped_gemm_wgrad for FC1 and FC2
Reviews (8): Last reviewed commit: "Address grouped MLP ScaledSReLU review c..."
8373402 to 765d2e9
Signed-off-by: sraman-rgb <sraman@nvidia.com>
765d2e9 to 43093cc
timmoon10 left a comment
Overall looks good, but we've gotten to the point where we need to start thinking about how to gracefully handle adding new activations. It seems that every model has a different activation function.
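One possible direction for the concern above is a small registry that maps an activation name to the fused kernels it needs, so a new activation becomes a new entry rather than a new class. The sketch below is hypothetical: `ActivationSpec`, `register_activation`, and `lookup_activation` are illustrative names that do not exist in this PR, and only the two SReLU kernel names are taken from the sequence diagram.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class ActivationSpec:
    """Hypothetical description of one fusable activation."""

    name: str
    forward_kernel: str   # name of the fused forward wrapper
    backward_kernel: str  # name of the fused backward wrapper
    is_gated: bool        # True for GLU-style activations


# Hypothetical module-level registry, keyed by activation name.
_ACTIVATION_SPECS: Dict[str, ActivationSpec] = {}


def register_activation(spec: ActivationSpec) -> ActivationSpec:
    """Add a spec to the registry, refusing silent overwrites."""
    if spec.name in _ACTIVATION_SPECS:
        raise ValueError(f"activation {spec.name!r} is already registered")
    _ACTIVATION_SPECS[spec.name] = spec
    return spec


def lookup_activation(name: str) -> Optional[ActivationSpec]:
    """Return the spec for `name`, or None if no fused path is registered."""
    return _ACTIVATION_SPECS.get(name)


# Kernel names below are the ones shown in the sequence diagram.
register_activation(
    ActivationSpec(
        name="srelu",
        forward_kernel="grouped_gemm_srelu_wrapper_sm100",
        backward_kernel="grouped_gemm_dsrelu_wrapper_sm100",
        is_gated=False,
    )
)
```

With something like this, the fuser could call `lookup_activation` and fall back to the unfused path when it returns `None`. Whether that fits the existing class hierarchy is an open question.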
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
912b1d9 to 46b3169
vthumbe1503 left a comment
LGTM. We might want to wait for the cuDNN release and until appropriate cuDNN guards are added.
else:
    try:
        validate_grouped_mlp_dims(window[0], window[1], window[2])
    except (TypeError, ValueError):
        matches_pattern = False
We should eventually disable SReLU fusion here based on the cuDNN version, before the merge.
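A minimal sketch of what such a guard could look like, assuming the installed version is available as a string. The helper names `parse_version` and `srelu_fusion_supported` are hypothetical, and the minimum version is left as a parameter because the thread does not state which release is required.

```python
import itertools
from typing import Optional, Tuple


def parse_version(text: str, width: int = 3) -> Tuple[int, ...]:
    """Parse the leading numeric fields of a version string.

    Suffixes such as '.post1' or '+cu12' are ignored, and the result is
    zero-padded to `width` fields so that '1.2' and '1.2.0' compare equal.
    """
    parts = []
    for piece in text.split("."):
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
    parts = parts[:width]
    parts += [0] * (width - len(parts))
    return tuple(parts)


def srelu_fusion_supported(installed: Optional[str], minimum: str) -> bool:
    """Return True only if a backend is installed and is at least `minimum`.

    `minimum` is a placeholder argument: the required release is not
    given in this thread.
    """
    if installed is None:
        return False
    return parse_version(installed) >= parse_version(minimum)
```

In the pattern-matching branch above, the result could be folded into `matches_pattern` so that an older backend falls through to the unfused path.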
scales.detach().to(dtype=dtype).reshape(-1, 1, 1)
if scales is not None
else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device)
This might be a holdover from before, right? We expect the scales passed in to never be None, so can we revert this change?