[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974vthumbe1503 wants to merge 17 commits into
Conversation
Greptile SummaryThis PR fixes several DCP checkpoint loading bugs for FSDP2 with quantized models, resolves NVFP4 allgather numerical errors, and removes the
Confidence Score: 5/5Safe to merge; changes are surgical bug fixes targeting well-defined FSDP2 + DCP failure modes with no impact on the forward/backward training path. All four fix areas are self-contained and well-reasoned. The getattr safe-global hole from the prior review is resolved. The only finding is a truncated comment. No files require special attention; the storage and tensor files are changed consistently across all four quantization types. Important Files Changed
Reviews (10): Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..." | Re-trigger Greptile |
| def untyped_storage(self) -> torch.UntypedStorage: | ||
| """Return an empty UntypedStorage on the tensor's device. | ||
|
|
||
| ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real | ||
| backing storage of its own; the actual bytes live in the inner | ||
| buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are | ||
| an implementation detail of the quantization scheme. Need to define | ||
| this method to avoid DCP staging errors with FSDP2. | ||
| """ | ||
| return torch.UntypedStorage(0, device=self.device) |
There was a problem hiding this comment.
Empty storage breaks shared-storage detection in existing callers
QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.
There was a problem hiding this comment.
The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.
There was a problem hiding this comment.
Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.
There was a problem hiding this comment.
Need to resolve this comment after going thoroughly over noop_cat consequences on Quantizedtensors
There was a problem hiding this comment.
The behavior is unchanged with the change. And I would argue the implementation now is more correct with the change. untyped_storage() default implementation from QuantizedTensor(torch.Tensor) before this change, gives a storage with two properties.
-
storage.nbytes() returns bytes based on the fake_dtype that we use to register our QuantizedTensor as a torchTensor using make_wrapper_subclass method of torch.
-
storage.data_ptr() gives an error saying it is an invalid storage and there is no data_ptr()
Both of them is not ideal.
The first one is grossly incrorrect due to two reasons. First we manage the backing storage for the inner tensors of QuantizedTensor and torch has no idea about it. Second nbytes based on fake_dtype is misleading since that might not actually be the number of bytes we actually allocate.
Second one is causing problems with FSDP2 now since it expects some storage for identity check.
For QuantizedTensor, noop_cat today always returns an actual torch.cat which goes through a dequantization luckily due to this condition being true. This condition is going to be true now with the change as well since nbytes() would return 0.
If we do QuantizedTensor.data_ptr() today it gives you 0. QuantizedTensor.untyped_storage().data_ptr() will give invalid storage error which is inconsistent. And giving empty storage as empty storage will fix this inconsitency.
As far as idenity checking goes, FSDP2 does all the comparisong logic only if data_ptr() is not 0. And it also doesnt really make sense to compare two empty storages.
|
/te-ci L1 pytorch |
| msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}", | ||
| ) | ||
| elif recipe_name == "NVFP4BlockScaling": | ||
| # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances |
There was a problem hiding this comment.
Why do we need dequant + quant here?
There was a problem hiding this comment.
We are doing it anymore
| # torch DCP staging via ``x.new_empty(..., device="cpu")``), we | ||
| # save the high-precision values in a plain CPU dense tensor. | ||
| # For the DCP load path, we will re-quantize the high-precision values. | ||
| target_size = torch.Size(size) if len(size) > 0 else tensor.size() |
There was a problem hiding this comment.
An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).
>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
| target_size = torch.Size(size) if len(size) > 0 else tensor.size() | |
| target_size = size |
There was a problem hiding this comment.
Changed the torch dispatch function now. So we dont have size here
| def untyped_storage(self) -> torch.UntypedStorage: | ||
| """Return an empty UntypedStorage on the tensor's device. | ||
|
|
||
| ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real | ||
| backing storage of its own; the actual bytes live in the inner | ||
| buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are | ||
| an implementation detail of the quantization scheme. Need to define | ||
| this method to avoid DCP staging errors with FSDP2. | ||
| """ | ||
| return torch.UntypedStorage(0, device=self.device) |
There was a problem hiding this comment.
The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
3589ffa to
4197bee
Compare
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
for more information, see https://pre-commit.ci
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci L1 pytorch |
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
|
/te-ci L1 pytorch |
|
/te-ci L1 pytorch |
Description
Fixes DCP Sync checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes
Fixes NVFP4 allgather + dequant numerical errors for fsdp2. Turns out this was due to us not setting the fsdp group as the amax reduction group in the quantizer
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Untyped_storage implementation needed for FSDP2 + DCP
DCP Aync/Sync Checkpoint loading
NVFP4 Allgather Correctness issues
TE_DType Serialization issues with DCP Checkpointing
Checklist: