Chore: refactor pt-expt compile #5504
for more information, see https://pre-commit.ci
📝 Walkthrough
Adds a new shared compile_utils module (prime helpers, trace-time padding, FX detach-strip, FX rebuild) and updates SeZMModel and training trace/compile to import and use these helpers; training now selects a forbidden-set-aware safe-prime trace frame size and pads frame-only inputs before make_fx and compile.
Changes
Compile Utilities and Shape Specialization
Sequence Diagram (trace & compile flow)
sequenceDiagram
participant Trainer
participant make_fx
participant DetachFix as strip_saved_tensor_detach
participant Rebuilder as rebuild_graph_module
participant Compiler as torch.compile
Trainer->>make_fx: call make_fx(forward_lower) with padded inputs
make_fx->>DetachFix: produce traced_lower GraphModule
DetachFix->>Rebuilder: remove saved-tensor detaches and rewrite uses
Rebuilder->>Compiler: rebuild graph and return clean GraphModule
Compiler->>Trainer: compile(GraphModule, dynamic=True, backend="inductor")
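The first step in the diagram feeds padded inputs to make_fx. A minimal sketch of the pad-then-trim idea follows, using plain Python lists as stand-ins for tensors so it runs without PyTorch. The names `pad_leading_dim` and `trim_leading_dim` are hypothetical, and padding by repeating the last frame is an assumption; the actual helper in this PR may pad differently.

```python
def pad_leading_dim(frames, target):
    """Pad a list of frames up to `target` entries.

    Assumption: padding repeats the last real frame. The real helper
    may use zeros or another filler.
    """
    if not frames:
        raise ValueError("need at least one frame to pad from")
    if len(frames) > target:
        raise ValueError("target must be >= current number of frames")
    return frames + [frames[-1]] * (target - len(frames))


def trim_leading_dim(frames, original):
    """Drop the padded tail so outputs match the original frame count."""
    return frames[:original]


# Two real frames, padded to a trace-time size of 5, then trimmed back.
real = [[0.0, 1.0], [2.0, 3.0]]
padded = pad_leading_dim(real, 5)
restored = trim_leading_dim(padded, len(real))
```

Only the trace-time shape changes here; trimming recovers the original frames unchanged.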
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
🧹 Nitpick comments (1)
deepmd/pt/utils/compile_utils.py (1)
87-88: 💤 Low value
Consider adding a type guard for extra robustness.
While `aten.detach.default` inputs are always Nodes in make_fx-generated graphs, adding an `isinstance` check would prevent a potential `AttributeError` if this helper is ever called on malformed graphs.
🛡️ Optional defensive check
     def _is_detach(n: torch.fx.Node) -> bool:
    -    return n.op == "call_function" and n.target == _DETACH
    +    return isinstance(n, torch.fx.Node) and n.op == "call_function" and n.target == _DETACH
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/utils/compile_utils.py` around lines 87 - 88, Update _is_detach to defensively check that the input is a torch.fx.Node before accessing attributes: change the parameter type to a more permissive Any (or keep current) and add an isinstance(n, torch.fx.Node) guard so you only evaluate n.op and n.target when n is a Node; optionally change the return annotation to typing.TypeGuard[torch.fx.Node] if you want an actual type guard. Ensure the function still checks n.op == "call_function" and n.target == _DETACH after the isinstance check and reference the _is_detach name, torch.fx.Node, and _DETACH in your change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 4c2615d5-fadf-4031-a417-656104bcad39
📒 Files selected for processing (3)
deepmd/pt/model/model/sezm_model.py
deepmd/pt/utils/compile_utils.py
deepmd/pt_expt/train/training.py
Pull request overview
This PR refactors PyTorch tracing/compilation helper logic into a shared utility module so both the SeZM model compile path and the pt_expt training compile path reuse the same implementations.
Changes:
- Added `deepmd/pt/utils/compile_utils.py` with shared helpers for trace-time prime shape selection, input padding/trimming, detach-chain stripping, and FX graph rebuilding.
- Updated `deepmd/pt/model/model/sezm_model.py` to import and use the shared helpers, removing the previous in-file implementations.
- Updated `deepmd/pt_expt/train/training.py` to import and use the shared helpers and to coerce multiple trace-time dimensions (nf/nloc/nall) to collision-resistant primes.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| deepmd/pt/utils/compile_utils.py | New shared tracing/compile utility module (prime sizing, padding, detach stripping, graph rebuilding). |
| deepmd/pt/model/model/sezm_model.py | Replaced local helper implementations with imports from the new shared module. |
| deepmd/pt_expt/train/training.py | Replaced local FX graph post-processing helpers with shared imports; updated trace-input coercion to prime sizes for multiple dims. |
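To illustrate what coercing several trace-time dimensions to collision-resistant primes might mean, here is a small pure-Python sketch. The signature of `next_safe_prime` and the helper `pick_trace_sizes` are assumptions; the conversation names `next_safe_prime` but does not show its implementation, and the real helper may apply further constraints. The idea shown is simply to pick, for each dimension, the smallest prime at or above the real size that no other dimension has already taken.

```python
def is_prime(n):
    """Trial-division primality test; fine for small trace sizes."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True


def next_safe_prime(n, forbidden=frozenset()):
    """Smallest prime >= n that is not in `forbidden` (hypothetical signature)."""
    candidate = max(n, 2)
    while not is_prime(candidate) or candidate in forbidden:
        candidate += 1
    return candidate


def pick_trace_sizes(sizes):
    """Assign each named dimension a distinct prime trace size."""
    used = set()
    chosen = {}
    for name, n in sizes.items():
        p = next_safe_prime(n, used)
        used.add(p)
        chosen[name] = p
    return chosen


trace_sizes = pick_trace_sizes({"nf": 1, "nloc": 2, "nall": 2})
```

Keeping the chosen sizes distinct means two unrelated dimensions never share a value at trace time, which is presumably what "collision-resistant" refers to in the description above.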
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #5504 +/- ##
==========================================
- Coverage 81.52% 81.52% -0.01%
==========================================
Files 872 872
Lines 97964 97967 +3
Branches 4241 4240 -1
==========================================
Hits 79865 79865
- Misses 16795 16801 +6
+ Partials 1304 1301 -3
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 395: The comment in training.py contains EN DASH characters in the phrase
"50–500 and 200–5000+" which triggers Ruff RUF003; update that comment (around
the block where the line mentions real data counts) to use HYPHEN-MINUS instead,
i.e. change "50–500 and 200–5000+" to "50-500 and 200-5000+", ensuring no other
EN DASH characters remain in the same comment or nearby comments.
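A quick way to confirm that no EN DASH characters remain is to scan the source text for U+2013. The sketch below is a standalone check, not part of Ruff; the function names are hypothetical, and it scans all text rather than only comments, so it is stricter than RUF003 itself.

```python
EN_DASH = "\u2013"


def find_en_dashes(source):
    """Return (line, column) pairs, 1-based, for every EN DASH in the text."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for col, ch in enumerate(line, start=1):
            if ch == EN_DASH:
                hits.append((lineno, col))
    return hits


def replace_en_dashes(source):
    """Replace every EN DASH with HYPHEN-MINUS."""
    return source.replace(EN_DASH, "-")


# The phrase quoted in the review comment, written as a Python comment line.
sample = "# 50\u2013500 and 200\u20135000+"
fixed = replace_en_dashes(sample)
```

After the replacement, `find_en_dashes(fixed)` returns an empty list, which is the condition the review comment asks for.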
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: d06cb1c3-6798-4545-9084-600cf0f7aaab
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
Resolve conflicts from PR deepmodeling#5503 (dpa4), which introduced deepmd/pt/utils/compile_compat.py, a superset of this branch's compile_utils.py. Consolidate onto compile_compat:
- sezm_model.py: take upstream import block + return paren
- training.py: import next_safe_prime/trace_pad_dim/rebuild_graph_module/strip_saved_tensor_detach from compile_compat (aliased to local names)
- remove redundant compile_utils.py

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…move_all)

The merge consolidated training.py onto compile_compat.strip_saved_tensor_detach, which is *selective*: it preserves user-explicit .detach() calls. The traced training fn opens with `coord.detach().requires_grad_(True)`, so the selective strip left that boundary detach in place, severing the second-order gradient path and producing the compiled-vs-uncompiled force mismatch (DPA2 test).

Rather than duplicate an aggressive remover, add a keyword-only `remove_all` flag to the single compile_compat.strip_saved_tensor_detach:
- SeZM inference path (sezm_model.py): default remove_all=False -> selective, preserving legitimate user .detach() calls (dpa4 behaviour).
- pt_expt training path (training.py): remove_all=True -> strip every detach, correct because the trace is fed already-detached, grad-enabled inputs.

One shared implementation, behaviour selected per call site.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
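The difference between the selective default and `remove_all=True` can be sketched on a toy graph. This is not the compile_compat implementation: `ToyNode` is a stand-in for an FX node, and the `saved` tag used to decide which detaches the selective mode removes is invented purely for illustration, since the real selection rule is not shown in this conversation. What the sketch does show is the keyword-only flag and how bypassing a detach rewires its users.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    name: str
    op: str                                   # "input", "detach", or "mul"
    args: list = field(default_factory=list)  # names of producer nodes
    saved: bool = False                       # illustrative tag, not a real FX attribute


def strip_detach(nodes, *, remove_all=False):
    """Bypass detach nodes in a topologically ordered node list.

    Selective mode (default) removes only detaches tagged `saved`;
    remove_all=True removes every detach.
    """
    alias = {}  # removed node name -> name of the value that replaces it
    kept = []
    for node in nodes:
        # Rewire inputs that pointed at an already-removed detach.
        args = [alias.get(a, a) for a in node.args]
        if node.op == "detach" and (remove_all or node.saved):
            alias[node.name] = args[0]
            continue
        kept.append(ToyNode(node.name, node.op, args, node.saved))
    return kept


graph = [
    ToyNode("coord", "input"),
    ToyNode("d0", "detach", ["coord"]),            # user-written boundary detach
    ToyNode("d1", "detach", ["d0"], saved=True),   # tagged as saved-tensor detach
    ToyNode("out", "mul", ["d1", "d1"]),
]

selective = strip_detach(graph)
aggressive = strip_detach(graph, remove_all=True)
```

In the selective result the boundary detach `d0` survives and `out` reads from it; with `remove_all=True` both detaches are gone and `out` reads directly from `coord`, which mirrors the training-path behaviour the commit message describes.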
Summary by CodeRabbit
Refactor
Bug Fixes