MXFP8 + FSDP2 checkpoint resume crashes in reset_sharded_param - add mxfp8 recipe to fully shard #2951
Conversation
Greptile Summary

This PR extends the fully_shard.py quantized-model-init example to use the MXFP8BlockScaling recipe.
Confidence Score: 3/5

Not safe to merge as-is: the example will crash at checkpoint resume because the underlying PyTorch `reset_sharded_param` fix is missing from this PR. There is one P1 finding (the example crashes at the exact feature it is adding) and one P2 (repeated recipe instantiation). The P1 directly breaks the end-to-end scenario this PR is intended to demonstrate. The score is 3/5, below the P1 ceiling of 4, because the breakage is in the central new code path. Files flagged: `examples/pytorch/quantized_model_init/fully_shard.py` (checkpoint-resume section) and `tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py` (mislabeled xfail).

Important Files Changed
Sequence Diagram

sequenceDiagram
participant Script as fully_shard.py
participant TE as TransformerEngine
participant DCP as torch.distributed.checkpoint
participant FSDP as FSDP2 (FSDPParam)
participant PyTorch as PyTorch internals
Script->>TE: quantized_model_init(recipe=MXFP8BlockScaling())
TE-->>Script: model with MXFP8Tensor params (meta device)
Script->>FSDP: fully_shard(model)
Script->>TE: module.reset_parameters() — materializes shards on GPU
Script->>TE: autocast(recipe=MXFP8BlockScaling()) — training loop
Script->>DCP: dcp.save(model.state_dict())
DCP-->>Script: checkpoint written to disk
Script->>DCP: dcp.load(state_to_load)
Script->>PyTorch: model.load_state_dict(state_to_load["model"])
PyTorch->>TE: copy_() on MXFP8Tensor — re-quantizes, new storage allocated
PyTorch->>FSDP: reset_sharded_param() post-load hook
FSDP->>PyTorch: data_ptr() on old (invalidated) storage
PyTorch-->>FSDP: RuntimeError: invalid python storage
Note over FSDP,PyTorch: Workaround (try/except) NOT present in this PR
Script->>TE: autocast(recipe=MXFP8BlockScaling()) — post-checkpoint step
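A condensed sketch of the save/resume portion of this flow, using `torch.distributed.checkpoint`. The checkpoint path is a hypothetical placeholder, and `model` is assumed to already be built under `quantized_model_init` and wrapped with `fully_shard` as shown above:

```python
import torch.distributed.checkpoint as dcp

CKPT_DIR = "/tmp/mxfp8_fsdp2_ckpt"  # hypothetical path

# Save the sharded (DTensor) state dict through DCP.
dcp.save({"model": model.state_dict()}, checkpoint_id=CKPT_DIR)

# Resume: fill a template state dict from disk, then load it into the model.
state_to_load = {"model": model.state_dict()}
dcp.load(state_to_load, checkpoint_id=CKPT_DIR)

# load_state_dict() calls copy_() on each MXFP8Tensor shard (re-quantize, new
# storage); FSDP2's reset_sharded_param post-load hook then touches the stale
# storage and raises "invalid python storage".
model.load_state_dict(state_to_load["model"])
```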
`optimizer.zero_grad(set_to_none=True)`

`- with te.autocast(enabled=True):`
`+ with te.autocast(enabled=True, recipe=MXFP8BlockScaling()):`
Repeated recipe instantiation per call
`MXFP8BlockScaling()` is instantiated three separate times (once in `quantized_model_init`, once here in the training loop, and once in the post-checkpoint step). `MXFP8BlockScaling` is currently stateless, so this is functionally safe, but the conventional pattern in TE examples is to create a single recipe object at the top and reuse it, as in the sketch below. That avoids any risk if the recipe gains per-instance state in the future and keeps the configuration change localized to one place.
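A minimal sketch of that single-recipe pattern, following the call names used in this example (`quantized_model_init`, `te.autocast`); `build_model`, `next_batch`, the optimizer, and the loop bounds are hypothetical placeholders:

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling

# One recipe object, created once and reused everywhere the example needs it.
mxfp8_recipe = MXFP8BlockScaling()

with te.quantized_model_init(recipe=mxfp8_recipe):
    model = build_model()  # hypothetical helper

# ... fully_shard(model), optimizer construction, checkpoint setup ...

for step in range(num_steps):
    optimizer.zero_grad(set_to_none=True)
    with te.autocast(enabled=True, recipe=mxfp8_recipe):
        loss = model(next_batch()).sum()
    loss.backward()
    optimizer.step()
```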
Bug
The `fully_shard.py` example crashes on checkpoint resume when using `MXFP8BlockScaling`. This affects any MXFP8 + FSDP2 training run that needs to resume from a checkpoint. Based on the existing `test_dcp_output_parity` xfails, NVFP4 is affected by the same bug.

We hit this while running a Llama 7B MXFP8 convergence run with quantized model init + FusedAdam master weights. Checkpoint save succeeds, but `model.load_state_dict()` on resume crashes with:

`RuntimeError: Attempted to access the data pointer on an invalid python storage.`
at `torch/distributed/fsdp/_fully_shard/_fsdp_param.py:890` in `reset_sharded_param()`.

Root Cause

`model.load_state_dict()` calls `copy_()` on MXFP8Tensor params, which re-quantizes and allocates new internal storage. FSDP2's `reset_sharded_param` post-load hook then calls `data_ptr()` on the now-invalidated old storage. PyTorch has a `# TODO: need to support tensor subclass` comment at the crash site.
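The try/except workaround the sequence diagram note refers to lives on the PyTorch side, not in this PR. Purely as an illustration of the idea (not a proper fix, and assuming the private module path from the traceback; internals may differ across PyTorch versions), a monkeypatch sketch could look like:

```python
# Illustrative only: tolerate the stale-storage error in FSDP2's post-load hook.
from torch.distributed.fsdp._fully_shard import _fsdp_param

_orig_reset_sharded_param = _fsdp_param.FSDPParam.reset_sharded_param

def _tolerant_reset_sharded_param(self, *args, **kwargs):
    try:
        return _orig_reset_sharded_param(self, *args, **kwargs)
    except RuntimeError as err:
        # MXFP8Tensor.copy_() re-quantized and re-allocated storage, so the old
        # data_ptr() is gone; skipping here only avoids the crash, it does not
        # re-register the new storage with FSDP2.
        if "invalid python storage" not in str(err):
            raise

_fsdp_param.FSDPParam.reset_sharded_param = _tolerant_reset_sharded_param
```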
Mislabeled xfails in TE tests

In `tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py::test_dcp_output_parity`, the MXFP8 and NVFP4 cases are xfailed as if FusedAdam were at fault, but the actual failure is the `data_ptr()` crash on invalid storage, not a FusedAdam issue. Commenting out the xfail and running the test confirms the same `data_ptr()` crash ("... wrapper subclass with invalid storage.").
Both MXFP8 and NVFP4 are blocked by the same underlying bug.
Reproduce