Add torch_compile flag for training networks#28
Open
wenxin0319 wants to merge 2 commits into
Open
Conversation
Author
|
@juliusberner Could you please take a look at my PR? Thanks! |
juliusberner
suggested changes
Jun 7, 2026
| ddp_find_unused_parameters: bool = True | ||
|
|
||
| # enable torch.compile for training networks | ||
| torch_compile: bool = False |
Contributor
There was a problem hiding this comment.
Can we make this more general, e.g.:
# torch.compile mode for inference speedup ("default", "reduce-overhead", "max-autotune")
# None disables torch.compile.
torch_compile_mode: Optional[str] = None
| # instantiate all necessary nets and submodules | ||
| self.build_model() | ||
|
|
||
| # optionally compile networks with torch.compile |
Contributor
There was a problem hiding this comment.
I think compilation should happen after the FSDP/DDP wrapping
| synchronize() | ||
| torch.cuda.empty_cache() | ||
|
|
||
| def _apply_torch_compile(self): |
Contributor
There was a problem hiding this comment.
Can we define a compile_dict similar to the fsdp_dict/model_dict, that contains all modules that should be compiled? We could also add the VAE there (note that we would need to search for submodules of the VAE, since the VAEs themselves are not instances of torch.nn.module).
f930ee8 to
eb763ac
Compare
- Replace torch_compile: bool with torch_compile_mode: Optional[str] (e.g. "default", "reduce-overhead", "max-autotune"; None disables). - Introduce a compile_dict property (like fsdp_dict/model_dict) holding all modules to compile. DMD2 extends it with teacher/fake_score. The base also includes the VAE: VAE wrappers are not nn.Modules, so we search their attributes for the underlying nn.Module submodule(s). - Compile in place via nn.Module.compile() and apply it from the trainer *after* DDP/FSDP wrapping so torch.compile composes with the wrappers. - Update tests accordingly (in-place compile detection, compile_dict and VAE submodule discovery).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
FastGen currently relies on diffusers-based model execution, which leaves performance on the table during training.
This PR adds an opt-in torch_compile flag that wraps training networks with torch.compile, enabling PyTorch's compiler optimizations (operator fusion, memory planning, kernel autotuning) for significant speedups on common models.
Benchmark (QwenImage, 20.43B params, NVIDIA H100, bfloat16, 512x512):
Setting │ Time/iter │ Std
Baseline (no compile) │ 0.694s │ 0.094s
torch.compile (max-autotune) │ 0.447s │ 0.014s
which is Speedup 1.55x (55% faster)
Compiled iterations also show much lower variance (0.014s vs 0.094s), meaning more consistent training throughput. The one-time compilation overhead (~5-10 min with max-autotune) is amortized over the full training run.
Changes:
Usage:
Set torch_compile=True in model config to enable.