Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions deepmd/pt/optimizer/KFWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
def update_energy(
self, inputs: dict, Etot_label: torch.Tensor, update_prefactor: float = 1
) -> None:
model_pred, _, _ = self.model(**inputs, inference_only=True)
model_pred, _, _ = self.model(**inputs, skip_loss=True)
Etot_predict = model_pred["energy"]
natoms_sum = int(inputs["atype"].shape[-1])
self.optimizer.set_grad_prefactor(natoms_sum)
Expand Down Expand Up @@ -66,7 +66,7 @@ def update_force(

for i in range(index.shape[0]):
self.optimizer.zero_grad()
model_pred, _, _ = self.model(**inputs, inference_only=True)
model_pred, _, _ = self.model(**inputs, skip_loss=True)
Etot_predict = model_pred["energy"]
natoms_sum = int(inputs["atype"].shape[-1])
force_predict = model_pred["force"]
Expand Down Expand Up @@ -105,7 +105,7 @@ def update_denoise_coord(

for i in range(index.shape[0]):
self.optimizer.zero_grad()
model_pred, _, _ = self.model(**inputs, inference_only=True)
model_pred, _, _ = self.model(**inputs, skip_loss=True)
updated_coord = model_pred["updated_coord"]
natoms_sum = int(inputs["atype"].shape[-1])
error_tmp = clean_coord[:, index[i]] - updated_coord[:, index[i]]
Expand Down
81 changes: 64 additions & 17 deletions deepmd/pt/train/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from collections.abc import (
Generator,
)
from contextlib import (
contextmanager,
)
from typing import (
Any,
)
Expand Down Expand Up @@ -161,7 +167,7 @@ def forward(
cur_lr: torch.Tensor | None = None,
label: torch.Tensor | None = None,
task_key: torch.Tensor | None = None,
inference_only: bool = False,
skip_loss: bool = False,
Comment thread
OutisLi marked this conversation as resolved.
do_atomic_virial: bool = False,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
Expand All @@ -188,23 +194,64 @@ def forward(
if has_spin:
input_dict["spin"] = spin

if self.inference_only or inference_only:
model_pred = self.model[task_key](**input_dict)
if self.modifier is not None:
modifier_pred = self.modifier(**input_dict)
for k, v in modifier_pred.items():
model_pred[k] = model_pred[k] + v
# A loss-free wrapper is a pure inference object, so parameters can be
# treated as constants while coordinate gradients remain enabled.
if self.inference_only:
with self._frozen_parameter_context():
model_pred = self._forward_without_loss(task_key, input_dict)
return model_pred, None, None
else:
natoms = atype.shape[-1]
model_pred, loss, more_loss = self.loss[task_key](
input_dict,
self.model[task_key],
label,
natoms=natoms,
learning_rate=cur_lr,
)
return model_pred, loss, more_loss
# Training wrappers may request predictions without loss construction
# and still backpropagate those predictions into model parameters
# (for example, KFWrapper updates).
if skip_loss:
model_pred = self._forward_without_loss(task_key, input_dict)
return model_pred, None, None

natoms = atype.shape[-1]
model_pred, loss, more_loss = self.loss[task_key](
input_dict,
self.model[task_key],
label,
natoms=natoms,
learning_rate=cur_lr,
)
return model_pred, loss, more_loss

@contextmanager
def _frozen_parameter_context(self) -> Generator[None, None, None]:
"""
Freeze model parameters during pure inference.

Conservative inference still differentiates model outputs with respect
to coordinates to obtain forces and virials. Parameter gradients are not
part of that contract, so disabling them trims the autograd graph while
leaving the coordinate-gradient path intact.
"""
params = tuple(self.parameters())
requires_grad = tuple(param.requires_grad for param in params)
if not any(requires_grad):
yield
return
for param in params:
param.requires_grad_(False)
try:
yield
finally:
for param, flag in zip(params, requires_grad, strict=True):
param.requires_grad_(flag)

def _forward_without_loss(
self,
task_key: str,
input_dict: dict[str, Any],
) -> Any:
"""Return predictions without constructing a loss."""
model_pred = self.model[task_key](**input_dict)
if self.modifier is not None:
modifier_pred = self.modifier(**input_dict)
for key, value in modifier_pred.items():
model_pred[key] = model_pred[key] + value
return model_pred

def set_extra_state(self, state: dict) -> None:
self.model_params = state["model_params"]
Expand Down
63 changes: 52 additions & 11 deletions deepmd/pt_expt/train/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from collections.abc import (
Generator,
)
from contextlib import (
contextmanager,
)
from typing import (
Any,
)
Expand Down Expand Up @@ -193,19 +199,54 @@ def forward(
"charge_spin": charge_spin,
}

model_pred = self.model[task_key](**input_dict)
if self.inference_only:
with self._frozen_parameter_context():
model_pred = self._forward_without_loss(task_key, input_dict)
return model_pred, None, None

if self.inference_only or label is None:
model_pred = self._forward_without_loss(task_key, input_dict)
if label is None:
return model_pred, None, None
else:
natoms = atype.shape[-1]
loss, more_loss = self.loss[task_key](
cur_lr,
natoms,
model_pred,
label,
)
return model_pred, loss, more_loss

natoms = atype.shape[-1]
loss, more_loss = self.loss[task_key](
cur_lr,
natoms,
model_pred,
label,
)
return model_pred, loss, more_loss

@contextmanager
def _frozen_parameter_context(self) -> Generator[None, None, None]:
"""
Freeze model parameters during pure inference.

Inference still differentiates model outputs with respect to
coordinates for force and virial evaluation. Parameter gradients are not
needed in that path, so disabling them keeps the autograd graph smaller
without changing coordinate derivatives.
"""
params = tuple(self.parameters())
requires_grad = tuple(param.requires_grad for param in params)
if not any(requires_grad):
yield
return
for param in params:
param.requires_grad_(False)
try:
yield
finally:
for param, flag in zip(params, requires_grad, strict=True):
param.requires_grad_(flag)

def _forward_without_loss(
self,
task_key: str,
input_dict: dict[str, Any],
) -> dict[str, torch.Tensor]:
"""Return model predictions without constructing a loss."""
return self.model[task_key](**input_dict)

def set_extra_state(self, state: dict) -> None:
self.model_params = state.get("model_params", {})
Expand Down
136 changes: 136 additions & 0 deletions source/tests/pt/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Tests for PyTorch model wrapper behavior."""

from __future__ import (
annotations,
)

import unittest

import torch

from deepmd.pt.train.wrapper import (
ModelWrapper,
)


class _LinearToyModel(torch.nn.Module):
def __init__(self, *, fail_forward: bool = False) -> None:
super().__init__()
self.linear = torch.nn.Linear(3, 1, bias=False, device="cpu")
self.scale = torch.nn.Parameter(torch.ones((), device="cpu"))
self.fail_forward = fail_forward
self.last_requires_grad: tuple[bool, ...] | None = None

def forward(
self,
coord: torch.Tensor,
atype: torch.Tensor,
box: torch.Tensor | None = None,
do_atomic_virial: bool = False,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
charge_spin: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
del atype, box, do_atomic_virial, fparam, aparam, charge_spin
self.last_requires_grad = tuple(
param.requires_grad for param in self.parameters()
)
if self.fail_forward:
raise RuntimeError("intentional toy failure")
coord_req = coord.clone().requires_grad_(True)
atom_energy = self.scale * self.linear(coord_req).sum(dim=1, keepdim=True)
energy = atom_energy.sum(dim=1)
force = -torch.autograd.grad(energy.sum(), coord_req, create_graph=True)[0]
return {
"atom_energy": atom_energy,
"energy": energy,
"force": force,
}


class _EnergyLoss(torch.nn.Module):
def forward(
self,
input_dict: dict[str, torch.Tensor],
model: torch.nn.Module,
label: torch.Tensor | None,
natoms: int,
learning_rate: torch.Tensor | None = None,
) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict]:
del label, natoms, learning_rate
pred = model(**input_dict)
loss = pred["energy"].sum()
return pred, loss, {}


class TestModelWrapper(unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(20240611)
self.coord = torch.randn(2, 5, 3, device="cpu")
self.atype = torch.zeros(2, 5, dtype=torch.long, device="cpu")

def test_inference_wrapper_freezes_parameters_without_changing_predictions(
self,
) -> None:
model = _LinearToyModel()
reference_model = _LinearToyModel()
reference_model.load_state_dict(model.state_dict())
wrapper = ModelWrapper(model)
reference_wrapper = ModelWrapper(reference_model, _EnergyLoss())

ref, _, _ = reference_wrapper(self.coord, self.atype, skip_loss=True)
out, _, _ = wrapper(self.coord, self.atype)

self.assertEqual(model.last_requires_grad, (False, False))
self.assertEqual(reference_model.last_requires_grad, (True, True))
self.assertTrue(all(param.requires_grad for param in wrapper.parameters()))
for key in ("atom_energy", "energy", "force"):
torch.testing.assert_close(out[key], ref[key])

def test_inference_wrapper_restores_mixed_parameter_flags(self) -> None:
model = _LinearToyModel()
model.linear.weight.requires_grad_(False)
wrapper = ModelWrapper(model)

wrapper(self.coord, self.atype)

self.assertEqual(model.last_requires_grad, (False, False))
self.assertFalse(model.linear.weight.requires_grad)
self.assertTrue(model.scale.requires_grad)

def test_inference_wrapper_restores_parameters_after_exception(self) -> None:
model = _LinearToyModel(fail_forward=True)
wrapper = ModelWrapper(model)

with self.assertRaisesRegex(RuntimeError, "intentional toy failure"):
wrapper(self.coord, self.atype)

self.assertEqual(model.last_requires_grad, (False, False))
self.assertTrue(all(param.requires_grad for param in wrapper.parameters()))

def test_multitask_inference_wrapper_freezes_selected_head(self) -> None:
model_a = _LinearToyModel()
model_b = _LinearToyModel()
wrapper = ModelWrapper({"a": model_a, "b": model_b})

wrapper(self.coord, self.atype, task_key="b")

self.assertIsNone(model_a.last_requires_grad)
self.assertEqual(model_b.last_requires_grad, (False, False))
self.assertTrue(all(param.requires_grad for param in wrapper.parameters()))

def test_skip_loss_keeps_training_gradients(self) -> None:
model = _LinearToyModel()
wrapper = ModelWrapper(model, _EnergyLoss())

pred, _, _ = wrapper(self.coord, self.atype, skip_loss=True)
pred["energy"].sum().backward()

self.assertEqual(model.last_requires_grad, (True, True))
self.assertIsNotNone(model.linear.weight.grad)
self.assertIsNotNone(model.scale.grad)


if __name__ == "__main__":
unittest.main()
Loading
Loading