feat(pt/ptexpt): freeze parameters when infer only#5512
Conversation
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR optimizes PyTorch ModelWrapper inference by temporarily freezing model parameters to reduce autograd graph size, and replaces the prior “inference-only forward flag” behavior with an explicit skip_loss path for training wrappers.
Changes:
- Add a parameter-freezing context manager used during pure inference in both
ptandpt_exptwrappers. - Introduce
skip_lossin theptwrapper to get predictions without loss construction while keeping parameter gradients enabled (used byKFWrapper). - Add unit tests covering parameter freezing/restoration, exception safety, multitask head selection, and
skip_lossgradient behavior.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt_expt/test_wrapper.py | Adds tests for experimental wrapper’s inference parameter-freezing behavior. |
| source/tests/pt/test_wrapper.py | Adds tests for stable wrapper’s inference freezing and skip_loss gradient retention. |
| deepmd/pt_expt/train/wrapper.py | Freezes parameters during inference via a context manager; refactors forward to use a no-loss helper. |
| deepmd/pt/train/wrapper.py | Replaces inference_only arg with skip_loss; adds parameter-freezing context and no-loss helper. |
| deepmd/pt/optimizer/KFWrapper.py | Updates calls to the wrapper to use skip_loss=True instead of the removed flag. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 WalkthroughWalkthroughThe PR refactors ModelWrapper implementations in both PT and PT_EXPT variants to support temporary parameter freezing during inference and loss-free predictions. A new ChangesParameter Freezing and Loss-Skip Inference
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
source/tests/pt/test_wrapper.py (1)
1-14: Run the repo-required lint/format + targeted pytest commands before merge.Please validate this file with
ruff format --check .,ruff check ., and targeted pytest invocation for single-case iteration.As per coding guidelines:
**/*.py: “run `ruff check .`” and “`ruff format .`”;**/tests/**/*.py: “Use `pytest` for testing single test cases … instead of full test suite.”🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/pt/test_wrapper.py` around lines 1 - 14, Run the repo-required formatting and lint checks and run the specific test case: execute "ruff format --check ." and "ruff check ." to ensure this test file and project conform to style rules, then run a targeted pytest for the single test in this file (e.g., using pytest -k test_wrapper or pytest path::TestCase::test_method for the specific test) to validate ModelWrapper behavior; fix any reported lint/format issues and any failing assertions in tests referencing ModelWrapper so the file passes ruff and the targeted pytest invocation.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@source/tests/pt/test_wrapper.py`:
- Around line 1-14: Run the repo-required formatting and lint checks and run the
specific test case: execute "ruff format --check ." and "ruff check ." to ensure
this test file and project conform to style rules, then run a targeted pytest
for the single test in this file (e.g., using pytest -k test_wrapper or pytest
path::TestCase::test_method for the specific test) to validate ModelWrapper
behavior; fix any reported lint/format issues and any failing assertions in
tests referencing ModelWrapper so the file passes ruff and the targeted pytest
invocation.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 56792c8d-dde7-473c-8bfd-66535c391b3b
📒 Files selected for processing (5)
deepmd/pt/optimizer/KFWrapper.pydeepmd/pt/train/wrapper.pydeepmd/pt_expt/train/wrapper.pysource/tests/pt/test_wrapper.pysource/tests/pt_expt/test_wrapper.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5512 +/- ##
=======================================
Coverage 81.52% 81.53%
=======================================
Files 872 872
Lines 97964 98007 +43
Branches 4241 4241
=======================================
+ Hits 79865 79907 +42
- Misses 16795 16798 +3
+ Partials 1304 1302 -2 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Summary by CodeRabbit
New Features
Improvements
Tests