Skip to content

feat(pt/ptexpt): freeze parameters when infer only#5512

Open
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/infer
Open

feat(pt/ptexpt): freeze parameters when infer only#5512
OutisLi wants to merge 2 commits into
deepmodeling:masterfrom
OutisLi:pr/infer

Conversation

@OutisLi

@OutisLi OutisLi commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

  • New Features

    • Support returning model predictions without constructing a loss (skip_loss).
  • Improvements

    • Inference mode now temporarily freezes model parameters to avoid unintended gradient tracking while preserving coordinate gradients and predictions.
    • Centralized loss-free forward path for clearer behavior between training and inference.
  • Tests

    • Added unit tests covering inference freezing, restoration after exceptions, multitask selective freezing, and skip_loss behavior.

Copilot AI review requested due to automatic review settings June 11, 2026 04:56
@dosubot dosubot Bot added the new feature label Jun 11, 2026

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR optimizes PyTorch ModelWrapper inference by temporarily freezing model parameters to reduce autograd graph size, and replaces the prior “inference-only forward flag” behavior with an explicit skip_loss path for training wrappers.

Changes:

  • Add a parameter-freezing context manager used during pure inference in both pt and pt_expt wrappers.
  • Introduce skip_loss in the pt wrapper to get predictions without loss construction while keeping parameter gradients enabled (used by KFWrapper).
  • Add unit tests covering parameter freezing/restoration, exception safety, multitask head selection, and skip_loss gradient behavior.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
source/tests/pt_expt/test_wrapper.py Adds tests for experimental wrapper’s inference parameter-freezing behavior.
source/tests/pt/test_wrapper.py Adds tests for stable wrapper’s inference freezing and skip_loss gradient retention.
deepmd/pt_expt/train/wrapper.py Freezes parameters during inference via a context manager; refactors forward to use a no-loss helper.
deepmd/pt/train/wrapper.py Replaces inference_only arg with skip_loss; adds parameter-freezing context and no-loss helper.
deepmd/pt/optimizer/KFWrapper.py Updates calls to the wrapper to use skip_loss=True instead of the removed flag.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread deepmd/pt/train/wrapper.py
Comment thread deepmd/pt/train/wrapper.py Outdated
@coderabbitai

coderabbitai Bot commented Jun 11, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b7093170-69af-421b-b2e0-00aa5a211f70

📥 Commits

Reviewing files that changed from the base of the PR and between 3419621 and 62de8c8.

📒 Files selected for processing (1)
  • deepmd/pt/train/wrapper.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/train/wrapper.py

📝 Walkthrough

Walkthrough

The PR refactors ModelWrapper implementations in both PT and PT_EXPT variants to support temporary parameter freezing during inference and loss-free predictions. A new _frozen_parameter_context() temporarily disables gradient tracking on model parameters while preserving coordinate-dependent output gradients. The optimizer integration uses the new interface, and comprehensive tests validate freezing semantics and parameter restoration.

Changes

Parameter Freezing and Loss-Skip Inference

Layer / File(s) Summary
PT ModelWrapper forward refactoring with frozen-parameter context
deepmd/pt/train/wrapper.py
Added contextmanager and Generator imports. Updated forward signature from inference_only to skip_loss parameter. Refactored forward logic to route inference through _frozen_parameter_context() that temporarily disables and restores parameter requires_grad. Introduced _forward_without_loss() helper to obtain predictions without loss computation.
PT_EXPT ModelWrapper forward refactoring with frozen-parameter context
deepmd/pt_expt/train/wrapper.py
Added contextmanager and Generator imports. Refactored forward method to execute inference inside _frozen_parameter_context(). Loss computed only when inference_only is false and label is provided; otherwise returns predictions with (None, None) for loss pair. New _forward_without_loss() helper encapsulates prediction-only logic.
Optimizer integration with skip_loss flag
deepmd/pt/optimizer/KFWrapper.py
Updated KFOptimizerWrapper.update_energy(), update_force(), and update_denoise_coord() to pass skip_loss=True instead of inference_only=True to model invocations. No changes to gradient computation, distributed reduction, or backward/step logic.
PT ModelWrapper behavior validation tests
source/tests/pt/test_wrapper.py
New test suite validates that wrapper freezes parameters during inference while preserving output equivalence, restores mixed requires_grad flags after forward, restores flags after exceptions, selectively freezes multitask heads via task_key, and that skip_loss=True preserves training gradients for backward propagation.
PT_EXPT ModelWrapper behavior validation tests
source/tests/pt_expt/test_wrapper.py
New test suite validates pt_expt wrapper freezing semantics with toy model and loss callable. Tests verify parameter freezing matches reference predictions, mixed flag restoration, exception recovery, selective multitask freezing, and that training mode without label enables gradient backprop.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 13.79% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'feat(pt/ptexpt): freeze parameters when infer only' accurately describes the main change—implementing parameter freezing during inference mode across PyTorch and experimental PyTorch modules.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/pt/test_wrapper.py (1)

1-14: Run the repo-required lint/format + targeted pytest commands before merge.

Please validate this file with ruff format --check ., ruff check ., and targeted pytest invocation for single-case iteration.

As per coding guidelines: **/*.py: “run `ruff check .`” and “`ruff format .`”; **/tests/**/*.py: “Use `pytest` for testing single test cases … instead of full test suite.”

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@source/tests/pt/test_wrapper.py` around lines 1 - 14, Run the repo-required
formatting and lint checks and run the specific test case: execute "ruff format
--check ." and "ruff check ." to ensure this test file and project conform to
style rules, then run a targeted pytest for the single test in this file (e.g.,
using pytest -k test_wrapper or pytest path::TestCase::test_method for the
specific test) to validate ModelWrapper behavior; fix any reported lint/format
issues and any failing assertions in tests referencing ModelWrapper so the file
passes ruff and the targeted pytest invocation.

Source: Coding guidelines

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@source/tests/pt/test_wrapper.py`:
- Around line 1-14: Run the repo-required formatting and lint checks and run the
specific test case: execute "ruff format --check ." and "ruff check ." to ensure
this test file and project conform to style rules, then run a targeted pytest
for the single test in this file (e.g., using pytest -k test_wrapper or pytest
path::TestCase::test_method for the specific test) to validate ModelWrapper
behavior; fix any reported lint/format issues and any failing assertions in
tests referencing ModelWrapper so the file passes ruff and the targeted pytest
invocation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 56792c8d-dde7-473c-8bfd-66535c391b3b

📥 Commits

Reviewing files that changed from the base of the PR and between 890e38a and 3419621.

📒 Files selected for processing (5)
  • deepmd/pt/optimizer/KFWrapper.py
  • deepmd/pt/train/wrapper.py
  • deepmd/pt_expt/train/wrapper.py
  • source/tests/pt/test_wrapper.py
  • source/tests/pt_expt/test_wrapper.py

@codecov

codecov Bot commented Jun 11, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 91.66667% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.53%. Comparing base (890e38a) to head (62de8c8).

Files with missing lines Patch % Lines
deepmd/pt/train/wrapper.py 93.54% 2 Missing ⚠️
deepmd/pt_expt/train/wrapper.py 92.30% 2 Missing ⚠️
deepmd/pt/optimizer/KFWrapper.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master    #5512   +/-   ##
=======================================
  Coverage   81.52%   81.53%           
=======================================
  Files         872      872           
  Lines       97964    98007   +43     
  Branches     4241     4241           
=======================================
+ Hits        79865    79907   +42     
- Misses      16795    16798    +3     
+ Partials     1304     1302    -2     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@OutisLi OutisLi added this pull request to the merge queue Jun 12, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to no response for status checks Jun 13, 2026
@OutisLi OutisLi added this pull request to the merge queue Jun 13, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to no response for status checks Jun 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants