Add torch_compile flag for training networks#28

Open: wenxin0319 wants to merge 2 commits into NVlabs:main from wenxin0319:main

@wenxin0319

FastGen currently relies on diffusers-based model execution, which leaves performance on the table during training.

This PR adds an opt-in torch_compile flag that wraps training networks with torch.compile, enabling PyTorch's compiler optimizations (operator fusion, memory planning, kernel autotuning). In the QwenImage benchmark below, this yields a 1.55x speedup per training iteration.

Benchmark (QwenImage, 20.43B params, NVIDIA H100, bfloat16, 512x512):

Setting │ Time/iter │ Std
Baseline (no compile) │ 0.694s │ 0.094s
torch.compile (max-autotune) │ 0.447s │ 0.014s

This is a 1.55x speedup (0.694s / 0.447s): about 55% higher throughput, or about 36% less time per iteration.

Compiled iterations also show a much lower standard deviation (0.014s vs 0.094s), meaning more consistent training throughput. The one-time compilation overhead (~5-10 min with max-autotune) is amortized over the full training run.
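
For reference, the speedup figures can be reproduced from the two timings in the table. This is only a small arithmetic sketch; the variable names are illustrative and are not taken from bench_compile.py.

```python
# Timings copied from the benchmark table (seconds per iteration).
baseline_s = 0.694
compiled_s = 0.447

speedup = baseline_s / compiled_s            # ratio of iteration times
throughput_gain = speedup - 1.0              # extra iterations per unit of time
time_saved = 1.0 - compiled_s / baseline_s   # fraction of time removed per iteration

print(f"speedup: {speedup:.2f}x")
print(f"throughput gain: {throughput_gain:.0%}")
print(f"time saved per iteration: {time_saved:.0%}")
```

The two percentages describe the same measurement from different angles: throughput rises by roughly half, while each iteration takes roughly a third less time.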

Changes:

  • Add torch_compile: bool = False config option in BaseModelConfig
  • Add _apply_torch_compile() in FastGenModel that compiles the main network (self.net)
  • Override _apply_torch_compile() in DMD2Model to also compile teacher and fake_score networks
  • Add comprehensive tests covering compile on/off for both SFT and DMD2 models, including training step validation
  • Add bench_compile.py benchmark script for measuring compile speedup

Usage:
Set torch_compile=True in the model config to enable compilation.

@wenxin0319

Author

@juliusberner Could you please take a look at my PR? Thanks!

@juliusberner juliusberner left a comment

Contributor

Thanks a lot for the PR and the benchmarking, I left a few comments!

Comment thread fastgen/configs/config.py Outdated
ddp_find_unused_parameters: bool = True

# enable torch.compile for training networks
torch_compile: bool = False

Contributor

Can we make this more general, e.g.:

# torch.compile mode for inference speedup ("default", "reduce-overhead", "max-autotune")
# None disables torch.compile.
torch_compile_mode: Optional[str] = None

Author

Thank you, I will do that
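
A minimal sketch of what the suggested option could look like, assuming a dataclass-style config. `ModelConfigSketch`, `KNOWN_COMPILE_MODES`, and `compile_enabled` are hypothetical names for illustration; the real BaseModelConfig may be defined differently, and the validation step is an assumption rather than something requested in the review.

```python
from dataclasses import dataclass
from typing import Optional

# Modes named in the review comment; torch.compile may accept others.
KNOWN_COMPILE_MODES = ("default", "reduce-overhead", "max-autotune")


@dataclass
class ModelConfigSketch:
    """Hypothetical stand-in for BaseModelConfig; only the compile field is shown."""

    # None disables torch.compile; otherwise the string is passed as `mode`.
    torch_compile_mode: Optional[str] = None

    def __post_init__(self) -> None:
        mode = self.torch_compile_mode
        if mode is not None and mode not in KNOWN_COMPILE_MODES:
            raise ValueError(f"unknown torch_compile_mode: {mode!r}")

    @property
    def compile_enabled(self) -> bool:
        # A single Optional[str] field replaces the earlier bool flag.
        return self.torch_compile_mode is not None
```

With this shape, `ModelConfigSketch()` leaves compilation off and `ModelConfigSketch(torch_compile_mode="max-autotune")` selects the mode used in the benchmark.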

Comment thread fastgen/methods/model.py Outdated
# instantiate all necessary nets and submodules
self.build_model()

# optionally compile networks with torch.compile

Contributor

I think compilation should happen after the FSDP/DDP wrapping

synchronize()
torch.cuda.empty_cache()

def _apply_torch_compile(self):

Contributor

Can we define a compile_dict, similar to the fsdp_dict/model_dict, that contains all modules that should be compiled? We could also add the VAE there (note that we would need to search for submodules of the VAE, since the VAEs themselves are not instances of torch.nn.Module).
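
The submodule search mentioned here could be sketched roughly as follows. To keep the example runnable without PyTorch, it is written against a generic `module_type`; in practice that argument would be `torch.nn.Module`. `find_submodules`, `StandInModule`, and `StandInVAE` are hypothetical names, and the search is deliberately shallow (instance attributes only).

```python
from typing import Dict


def find_submodules(wrapper: object, module_type: type, prefix: str = "vae") -> Dict[str, object]:
    """Return instance attributes of `wrapper` that are instances of `module_type`."""
    found: Dict[str, object] = {}
    for name, value in vars(wrapper).items():
        if isinstance(value, module_type):
            # Key by a dotted name so entries can be merged into a compile_dict.
            found[f"{prefix}.{name}"] = value
    return found


class StandInModule:
    """Placeholder for torch.nn.Module in this sketch."""


class StandInVAE:
    """Placeholder for a VAE wrapper that is not itself a module."""

    def __init__(self) -> None:
        self.encoder = StandInModule()
        self.decoder = StandInModule()
        self.name = "vae"  # non-module attribute, should be skipped


vae_modules = find_submodules(StandInVAE(), StandInModule)
print(sorted(vae_modules))
```

A wrapper that nests its modules more deeply, or that uses `__slots__`, would need a different traversal than `vars()`.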

wenxin0319 force-pushed the main branch 2 times, most recently from f930ee8 to eb763ac (June 9, 2026 03:49)
- Replace torch_compile: bool with torch_compile_mode: Optional[str]
  (e.g. "default", "reduce-overhead", "max-autotune"; None disables).
- Introduce a compile_dict property (like fsdp_dict/model_dict) holding
  all modules to compile. DMD2 extends it with teacher/fake_score. The
  base also includes the VAE: VAE wrappers are not nn.Modules, so we
  search their attributes for the underlying nn.Module submodule(s).
- Compile in place via nn.Module.compile() and apply it from the trainer
  *after* DDP/FSDP wrapping so torch.compile composes with the wrappers.
- Update tests accordingly (in-place compile detection, compile_dict and
  VAE submodule discovery).
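
As a rough illustration of the in-place step described in this commit message, the sketch below walks a compile_dict and calls `compile(mode=...)` on each entry, skipping everything when the mode is None. `StandInModule` and `apply_compile` are hypothetical; the stand-in only mimics the call shape of `torch.nn.Module.compile`, and in the real trainer this would run after the DDP/FSDP wrapping.

```python
from typing import Dict, Optional


class StandInModule:
    """Stand-in with the same in-place `compile(**kwargs)` call shape as nn.Module.compile."""

    def __init__(self) -> None:
        self.compiled_with: Optional[dict] = None

    def compile(self, **kwargs) -> None:
        # A real module would compile its forward here; the sketch just records the call.
        self.compiled_with = kwargs


def apply_compile(compile_dict: Dict[str, StandInModule], mode: Optional[str]) -> int:
    """Compile every module in place and return how many were compiled."""
    if mode is None:  # None disables compilation entirely
        return 0
    for module in compile_dict.values():
        module.compile(mode=mode)
    return len(compile_dict)


# Hypothetical DMD2-style dictionary: main net plus teacher and fake_score.
modules = {"net": StandInModule(), "teacher": StandInModule(), "fake_score": StandInModule()}
skipped = apply_compile(modules, None)
compiled = apply_compile(modules, "max-autotune")
```

Because the call is in place, the trainer keeps its existing references to the wrapped modules instead of swapping in new objects.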