Skip to content

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974

Open
vthumbe1503 wants to merge 17 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix
Open

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
vthumbe1503 wants to merge 17 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 commented May 11, 2026

Description

Fixes DCP Sync checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes
Fixes NVFP4 allgather + dequant numerical errors for fsdp2. Turns out this was due to us not setting the fsdp group as the amax reduction group in the quantizer

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Untyped_storage implementation needed for FSDP2 + DCP

    • untyped_storage is now defined for the base QuantizedTensor to return empty storage. Untyped_storage refers to the backing storage that we use to create all the internal tensors. Since we use make_wrapper_subclass to create TE QuantizedTensors, we use dont have any backing storage associated with the tensor. data_ptr on our Custom QuantizedTensor also returns 0.
    • The main issue is that FSDP2 maintains sharded param tensor for checkpointing. It does so by calling view(-1) on our Quantized sharded model parameters. We return back a dequantized 1D tensor in TE. So, the sharded tensor that FSDP2 maintains for checkpointing is BF16 and Quantized sharded param is our custom FP8 tensor. It evaluates untyped_storage(BF16 sharded tensor reloaded from disk) == untyped_storage(Quantized sharded parameter) to see if the same_tensor. With us returning empty storage now, this would never be equal to sharded tensor's untyped storage.
  • DCP Aync/Sync Checkpoint loading

    • For Sync cases previously we were going through the route of dequantization to BF16 before saving to disk, which happened through the to_new_empty function
    • For both syn/async, dequantizing is not ideal. And so we now have .cpu() and .to() implemented for QuantizedTensor which dont go through dequantization and rather just copy inner tensors of QuantizedTensor to cpu if needed in blocking/non-blocking way.
  • NVFP4 Allgather Correctness issues

    • Allgather with FSDP2 was very far away from fp32 allgather for the same values. This was due to us not setting the amax reduction group in the quantizer.
  • TE_DType Serialization issues with DCP Checkpointing

    • DCP uses torch.load(weights_only=True), whose Unpickler rejects every GLOBAL reference that isn't in add_safe_globals — and getattr is intentionally not allow-listed.
    • So we override the default enum reduction in pybind:
default:      (getattr, (tex.DType, "kFloat8E4M3"))   # needs getattr + tex.DType allow-listed
pybind override: (tex.DType, (int_value,))            # only needs tex.DType allow-listed

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR fixes several DCP checkpoint loading bugs for FSDP2 with quantized models, resolves NVFP4 allgather numerical errors, and removes the getattr safe-global hole from the previous iteration.

  • untyped_storage() / _to_copy in QuantizedTensor: Returns an empty zero-byte storage so FSDP2's staged-tensor identity check never falsely matches a quantized param; a new _to_copy dispatch handler preserves the quantized subclass when moving to CPU rather than dequantizing.
  • __reduce_ex__ refactor across all tensor types: Each type now points to a module-level _make_*_in_reduce_ex function added to torch.serialization.add_safe_globals, replacing the classmethod approach that required getattr to be safe-listed.
  • tex.DType pybind11 pickle override: Custom __reduce__/__reduce_ex__ emit (tex.DType, (int,)) so only the class itself needs to be safe-listed.
  • NVFP4 amax_reduction_group fix: base.py now sets the FSDP shard group on NVFP4Quantizer, fixing the allgather amax reduction that caused large numerical error.

Confidence Score: 5/5

Safe to merge; changes are surgical bug fixes targeting well-defined FSDP2 + DCP failure modes with no impact on the forward/backward training path.

All four fix areas are self-contained and well-reasoned. The getattr safe-global hole from the prior review is resolved. The only finding is a truncated comment.

No files require special attention; the storage and tensor files are changed consistently across all four quantization types.

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds untyped_storage() returning zero-byte storage (fixes FSDP2 identity checks), new _to_copy dispatch handler (preserves QuantizedTensor subclass when moving devices), and returns QuantizedTensor from cpu() instead of dequantizing.
transformer_engine/pytorch/init.py Adds torch.serialization.add_safe_globals for all tensor types, storage classes, quantizers, module-level reconstruct functions, and tex.DType; getattr is no longer added.
transformer_engine/common/util/pybind_helper.h Overrides __reduce_ex__ and __reduce__ for the TE_DType pybind11 enum to emit (tex.DType, (int_value,)) instead of the default (getattr, (cls, name)) form.
transformer_engine/pytorch/module/base.py Extends the amax_reduction_group assignment to cover NVFP4Quantizer in addition to Float8CurrentScalingQuantizer.
transformer_engine/pytorch/tensor/float8_tensor.py Adds __reduce_ex__ pointing at the new module-level _make_float8_tensor_in_reduce_ex, removing the old CPU-dequantize fallback; device is now explicitly passed at every construction site.
transformer_engine/pytorch/tensor/mxfp8_tensor.py Replaces classmethod reference in __reduce_ex__ with module-level _make_mxfp8_tensor_in_reduce_ex, adds device=tensor.device throughout dispatch handlers.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Same refactor as MXFP8/Float8 tensors; additionally fixes device=rowwise_data.device in post_all_gather reconstruction.
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Refactors __reduce_ex__ to use module-level function, removes the per-buffer untyped_storage() override (now handled by the base class), threads device= through all reconstruction and dispatch sites.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Removes device from get_metadata() since all construction paths now supply it explicitly; Float8Tensor already had a pure-PyTorch CPU dequantize fallback.
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py Adds CPU-to-CUDA-to-CPU bounce in _FromMXFP8Func.forward so that dequantizing a CPU-resident MXFP8Tensor still produces correct results via tex.dequantize.
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py Same CPU-to-CUDA bounce pattern as MXFP8; one truncated comment at line 55 needs completing.

Reviews (10): Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..." | Re-trigger Greptile

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to resolve this comment after going thoroughly over noop_cat consequences on Quantizedtensors

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior is unchanged with the change. And I would argue the implementation now is more correct with the change. untyped_storage() default implementation from QuantizedTensor(torch.Tensor) before this change, gives a storage with two properties.

  1. storage.nbytes() returns bytes based on the fake_dtype that we use to register our QuantizedTensor as a torchTensor using make_wrapper_subclass method of torch.

  2. storage.data_ptr() gives an error saying it is an invalid storage and there is no data_ptr()

Both of them is not ideal.
The first one is grossly incrorrect due to two reasons. First we manage the backing storage for the inner tensors of QuantizedTensor and torch has no idea about it. Second nbytes based on fake_dtype is misleading since that might not actually be the number of bytes we actually allocate.
Second one is causing problems with FSDP2 now since it expects some storage for identity check.

For QuantizedTensor, noop_cat today always returns an actual torch.cat which goes through a dequantization luckily due to this condition being true. This condition is going to be true now with the change as well since nbytes() would return 0.

If we do QuantizedTensor.data_ptr() today it gives you 0. QuantizedTensor.untyped_storage().data_ptr() will give invalid storage error which is inconsistent. And giving empty storage as empty storage will fix this inconsitency.

As far as idenity checking goes, FSDP2 does all the comparisong logic only if data_ptr() is not 0. And it also doesnt really make sense to compare two empty storages.

Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Outdated
Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Outdated
Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py
@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need dequant + quant here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are doing it anymore

Comment thread transformer_engine/pytorch/quantized_tensor.py Outdated
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).

>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
Suggested change
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
target_size = size

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the torch dispatch function now. So we dont have size here

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

Comment thread tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py Outdated
Comment thread transformer_engine/pytorch/quantized_tensor.py Outdated
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 force-pushed the fsdp2_dcp_laod_fix branch from 3589ffa to 4197bee Compare May 13, 2026 04:00
@vthumbe1503 vthumbe1503 requested a review from ksivaman as a code owner May 13, 2026 04:00
pre-commit-ci Bot and others added 2 commits May 13, 2026 04:01
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/__init__.py Outdated
vthumbe1503 and others added 6 commits May 13, 2026 04:19
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503 vthumbe1503 added the bug Something isn't working label May 18, 2026
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

timmoon10
timmoon10 previously approved these changes May 19, 2026
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM

Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py Outdated
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.16.0 bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants