
MXFP8 + FSDP2 checkpoint resume crashes in reset_sharded_param - add mxfp8 recipe to fully_shard #2951

Open
savitha-eng wants to merge 2 commits into NVIDIA:main from savitha-eng:savitha/mxfp8-fsdp2-checkpoint-resume-bug

Conversation


@savitha-eng savitha-eng commented May 1, 2026

Bug

The fully_shard.py example crashes on checkpoint resume when using MXFP8BlockScaling. This affects any MXFP8 + FSDP2 training run that needs to resume from a checkpoint. Based on the existing test_dcp_output_parity xfails, NVFP4 is affected by the same bug.

We hit this while running a Llama 7B MXFP8 convergence run with quantized model init + FusedAdam master weights. Checkpoint save succeeds, but model.load_state_dict() on resume crashes with:

RuntimeError: Attempted to access the data pointer on an invalid python storage.

at torch/distributed/fsdp/_fully_shard/_fsdp_param.py:890 in reset_sharded_param().

Root Cause

model.load_state_dict() calls copy_() on MXFP8Tensor params, which re-quantizes and allocates new internal storage. FSDP2's reset_sharded_param post-load hook then calls data_ptr() on the now-invalidated old storage.

PyTorch has a # TODO: need to support tensor subclass comment at the crash site.
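
The error itself is easy to trigger in isolation. The sketch below is hypothetical: it assumes only that MXFP8Tensor behaves like a wrapper tensor subclass built with PyTorch's private torch.Tensor._make_wrapper_subclass helper, whose outer storage is a placeholder with no real allocation (ToyWrapper and the surrounding code are ours, not TE's):

import torch

# Hypothetical stand-in for a quantized wrapper subclass like MXFP8Tensor.
class ToyWrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        # The outer tensor carries shape/dtype but owns no real allocation.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device
        )

    def __init__(self, inner):
        self._inner = inner  # the real data lives in internal tensors

w = ToyWrapper(torch.randn(4))
# This is the access reset_sharded_param performs on the sharded param data:
w.untyped_storage().data_ptr()
# -> RuntimeError: Attempted to access the data pointer on an invalid python storage.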

Mislabeled xfails in TE tests

In tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py::test_dcp_output_parity:

  • The MXFP8 xfail references PR Optimize fp8 block scaling Allgather for FSDP2 #2789 (a FusedAdam kernel bug), but this is incorrect: the actual failure is the data_ptr() crash on invalid storage, not a FusedAdam issue. Commenting out the xfail and running the test confirms the same data_ptr() crash.
  • The NVFP4 xfail correctly describes the issue: "DCP load_state_dict triggers reset_sharded_param() which calls data_ptr() on wrapper subclass with invalid storage."

Both MXFP8 and NVFP4 are blocked by the same underlying bug.

Reproduce

torchrun --nproc_per_node=2 examples/pytorch/quantized_model_init/fully_shard.py

Requires Blackwell GPUs (B200). The only change to the example is adding recipe=MXFP8BlockScaling() to the quantized_model_init and te.autocast calls.
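
For reference, a sketch of the change; the surrounding example code is elided, and the import path for MXFP8BlockScaling follows the usual TE convention (transformer_engine.common.recipe), which we assume here rather than quote from the diff:

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# 1) Deferred/quantized model init now carries the MXFP8 recipe.
with te.quantized_model_init(recipe=MXFP8BlockScaling()):
    model = build_model()  # hypothetical; the example builds the model on the meta device

# ... fully_shard(), reset_parameters(), optimizer setup as in the example ...

# 2) Both autocast regions (training loop and post-checkpoint step) pass the recipe.
with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):
    loss = model(batch)  # hypothetical forward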

Workaround we used to train the model with checkpoint resumption: 

We patched FSDPParam.reset_sharded_param to wrap the data_ptr() comparison in try/except RuntimeError, setting same_local_tensor = False on failure. This just tells FSDP to re-record _sharded_param_data, which is always correct.

# Patched excerpt of FSDPParam.reset_sharded_param
# (torch/distributed/fsdp/_fully_shard/_fsdp_param.py, PyTorch 2.7).
if type(self._sharded_param_data) is torch.Tensor:
    try:
        same_local_tensor = (
            self._sharded_param_data.untyped_storage().data_ptr() > 0
            and self._sharded_param_data.untyped_storage().data_ptr()
            == local_tensor.untyped_storage().data_ptr()
        )
    except RuntimeError:
        # data_ptr() raises on a quantized wrapper subclass whose storage
        # was invalidated by copy_(); fall back to re-recording the shard.
        same_local_tensor = False
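
The same guard, lifted out as a standalone helper and runnable against the ToyWrapper sketch from the Root Cause section (the helper name is ours):

def storages_alias(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Mirrors the patched comparison: treat "cannot read data_ptr" as
    # "not aliased", which makes FSDP re-record _sharded_param_data.
    try:
        ptr = a.untyped_storage().data_ptr()
        return ptr > 0 and ptr == b.untyped_storage().data_ptr()
    except RuntimeError:
        return False

storages_alias(ToyWrapper(torch.randn(4)), torch.randn(4))  # -> False, no crash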

Environment

- TE: 2.15.0 (te-main)
- PyTorch: 2.7.0
- GPUs: B200 (Blackwell)

@savitha-eng savitha-eng marked this pull request as ready for review May 1, 2026 08:17

greptile-apps Bot commented May 1, 2026

Greptile Summary

This PR extends examples/pytorch/quantized_model_init/fully_shard.py to use MXFP8BlockScaling — adding the recipe to quantized_model_init and both te.autocast calls — to document an MXFP8 + FSDP2 checkpoint-resume crash.

  • Crash not fixed: model.load_state_dict() at line 187 will still raise RuntimeError: Attempted to access the data pointer on an invalid python storage inside FSDPParam.reset_sharded_param. The try/except workaround described in the PR description is absent from the diff, so the example fails at the exact scenario it is meant to demonstrate.
  • Mislabeled xfail not updated: run_fsdp2_fused_adam.py::test_dcp_output_parity still carries an xfail that attributes the MXFP8 failure to the FusedAdam kernel (PR Optimize fp8 block scaling Allgather for FSDP2 #2789) rather than the data_ptr() crash; the PR description identifies this as incorrect but the test is unchanged.

Confidence Score: 3/5

Not safe to merge as-is — the example will crash at checkpoint resume because the underlying PyTorch reset_sharded_param fix is missing from this PR.

One P1 finding (example crashes at the exact feature it is adding) and one P2 (repeated recipe instantiation). The P1 directly breaks the end-to-end scenario this PR is intended to demonstrate. Score is 3/5 (below the P1 ceiling of 4) because the breakage is in the central new code path.

examples/pytorch/quantized_model_init/fully_shard.py — checkpoint-resume section; also tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py for the mislabeled xfail.

Important Files Changed

Filename: examples/pytorch/quantized_model_init/fully_shard.py
Overview: Adds MXFP8BlockScaling recipe to quantized_model_init and both autocast calls; the checkpoint-resume path (model.load_state_dict) will crash because the underlying reset_sharded_param data_ptr() bug is not fixed in this PR.

Sequence Diagram

sequenceDiagram
    participant Script as fully_shard.py
    participant TE as TransformerEngine
    participant DCP as torch.distributed.checkpoint
    participant FSDP as FSDP2 (FSDPParam)
    participant PyTorch as PyTorch internals

    Script->>TE: quantized_model_init(recipe=MXFP8BlockScaling())
    TE-->>Script: model with MXFP8Tensor params (meta device)
    Script->>FSDP: fully_shard(model)
    Script->>TE: module.reset_parameters() — materializes shards on GPU
    Script->>TE: autocast(recipe=MXFP8BlockScaling()) — training loop
    Script->>DCP: dcp.save(model.state_dict())
    DCP-->>Script: checkpoint written to disk
    Script->>DCP: dcp.load(state_to_load)
    Script->>PyTorch: model.load_state_dict(state_to_load["model"])
    PyTorch->>TE: copy_() on MXFP8Tensor — re-quantizes, new storage allocated
    PyTorch->>FSDP: reset_sharded_param() post-load hook
    FSDP->>PyTorch: data_ptr() on old (invalidated) storage
    PyTorch-->>FSDP: RuntimeError: invalid python storage
    Note over FSDP,PyTorch: Workaround (try/except) NOT present in this PR
    Script->>TE: autocast(recipe=MXFP8BlockScaling()) — post-checkpoint step

Comments Outside Diff (1)

  1. examples/pytorch/quantized_model_init/fully_shard.py, line 185-188 (link)

    P1 Example crashes at checkpoint resume without the PyTorch workaround

    model.load_state_dict() on an MXFP8 model calls copy_() on MXFP8Tensor params, which re-quantizes and allocates new internal storage. FSDP2's reset_sharded_param post-load hook then calls data_ptr() on the now-invalidated old storage, crashing with RuntimeError: Attempted to access the data pointer on an invalid python storage.

    The PR description describes a workaround (patching FSDPParam.reset_sharded_param in PyTorch with a try/except RuntimeError around the data_ptr() comparison), but that patch is not included in this diff. Without it the example fails at this line on any MXFP8+FSDP2 run, which is exactly the scenario this example is meant to demonstrate. The same root cause affects the test_dcp_output_parity xfail for MXFP8 in run_fsdp2_fused_adam.py, which is also not updated here.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

optimizer.zero_grad(set_to_none=True)

-with te.autocast(enabled=True):
+with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):

P2 Repeated recipe instantiation per call

MXFP8BlockScaling() is instantiated three separate times (once in quantized_model_init, once here in the training loop, and once in the post-checkpoint step). While MXFP8BlockScaling is currently stateless so this is functionally safe, the conventional pattern in TE examples is to create a single recipe object at the top and reuse it. That avoids any risk if the recipe gains per-instance state in the future and keeps the configuration change localized to one place.
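
A sketch of that pattern (variable name ours):

recipe = MXFP8BlockScaling()  # one instance, created once at the top

with te.quantized_model_init(recipe=recipe):
    ...

with te.autocast(enabled=True, recipe=recipe):
    ...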


