Skip to content

Add MXFP8 quantized_model_init memory profiler for FSDP2 qinit analysis#3008

Draft
savitha-eng wants to merge 5 commits into
NVIDIA:mainfrom
savitha-eng:savitha/mxfp8-memory-profiler
Draft

Add MXFP8 quantized_model_init memory profiler for FSDP2 qinit analysis#3008
savitha-eng wants to merge 5 commits into
NVIDIA:mainfrom
savitha-eng:savitha/mxfp8-memory-profiler

Conversation

@savitha-eng
Copy link
Copy Markdown

@savitha-eng savitha-eng commented May 18, 2026

Summary

Standalone memory profiler script for diagnosing MXFP8 quantized_model_init memory behavior with FSDP2. Creates one or more te.TransformerLayer blocks with 8B-scale dimensions, wraps with FSDP2 fully_shard, and runs forward+backward+step iterations while recording PyTorch memory history.

Issue observed: When using quantized_model_init + FSDP2, MXFP8 quantized weight tensors from mxfp8_tensor.py:quantize_impl are never freed. FSDP2 calls .view(numel,) to flatten params, which triggers _ViewFunc dequantize fallback, and the intermediate tensors leak. With --num-layers 4, the leaked memory accumulates per layer.

Quick repro (requires 2+ GPUs)

# BF16 baseline (control — no leak)
torchrun --nproc-per-node 2 examples/pytorch/quantized_model_init/single_block_memory_profile.py --mode bare-fsdp2

# MXFP8 + qinit + FSDP2 (shows leaked tensors)
torchrun --nproc-per-node 2 examples/pytorch/quantized_model_init/single_block_memory_profile.py --mode mxfp8-fsdp2

# 4 layers (shows cross-layer accumulation)
torchrun --nproc-per-node 2 examples/pytorch/quantized_model_init/single_block_memory_profile.py --mode mxfp8-fsdp2 --num-layers 4

Snapshots saved to /tmp/single_block_snapshots/ — view at https://pytorch.org/memory_viz

Available modes

Mode Description
bare BF16 baseline, no FP8, no FSDP2
mxfp8 MXFP8 + quantized_model_init, no FSDP2
fp8-no-qinit FP8 autocast without qinit, no FSDP2
mxfp8-no-qinit MXFP8 autocast without qinit, no FSDP2
bare-fsdp2 BF16 + FSDP2 (control)
mxfp8-fsdp2 MXFP8 + qinit + FSDP2 (repro)
fp8-no-qinit-fsdp2 FP8 autocast + FSDP2, no qinit
mxfp8-no-qinit-fsdp2 MXFP8 autocast + FSDP2, no qinit

Additional flags: --model-size {8b,70b}, --num-layers N, --no-hpiv, --recipe {mxfp8,float8block,auto}

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • Add examples/pytorch/quantized_model_init/single_block_memory_profile.py — self-contained memory profiler with 8 modes for comparing BF16 vs MXFP8 vs FP8 autocast, with and without FSDP2
  • No changes to TE library code

@savitha-eng savitha-eng changed the title Add MXFP8 single-block memory profiler for FSDP2 qinit analysis Add MXFP8 quantized_model_init memory profiler for FSDP2 qinit analysis May 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant