diff --git a/deepmd/pt/optimizer/KFWrapper.py b/deepmd/pt/optimizer/KFWrapper.py index 3ab7ffe7a9..d95b8a3712 100644 --- a/deepmd/pt/optimizer/KFWrapper.py +++ b/deepmd/pt/optimizer/KFWrapper.py @@ -28,7 +28,7 @@ def __init__( def update_energy( self, inputs: dict, Etot_label: torch.Tensor, update_prefactor: float = 1 ) -> None: - model_pred, _, _ = self.model(**inputs, inference_only=True) + model_pred, _, _ = self.model(**inputs, skip_loss=True) Etot_predict = model_pred["energy"] natoms_sum = int(inputs["atype"].shape[-1]) self.optimizer.set_grad_prefactor(natoms_sum) @@ -66,7 +66,7 @@ def update_force( for i in range(index.shape[0]): self.optimizer.zero_grad() - model_pred, _, _ = self.model(**inputs, inference_only=True) + model_pred, _, _ = self.model(**inputs, skip_loss=True) Etot_predict = model_pred["energy"] natoms_sum = int(inputs["atype"].shape[-1]) force_predict = model_pred["force"] @@ -105,7 +105,7 @@ def update_denoise_coord( for i in range(index.shape[0]): self.optimizer.zero_grad() - model_pred, _, _ = self.model(**inputs, inference_only=True) + model_pred, _, _ = self.model(**inputs, skip_loss=True) updated_coord = model_pred["updated_coord"] natoms_sum = int(inputs["atype"].shape[-1]) error_tmp = clean_coord[:, index[i]] - updated_coord[:, index[i]] diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 1d741dd534..da710f4fdf 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -1,5 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from collections.abc import ( + Generator, +) +from contextlib import ( + contextmanager, +) from typing import ( Any, ) @@ -161,7 +167,7 @@ def forward( cur_lr: torch.Tensor | None = None, label: torch.Tensor | None = None, task_key: torch.Tensor | None = None, - inference_only: bool = False, + skip_loss: bool = False, do_atomic_virial: bool = False, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, @@ -188,23 +194,64 @@ def forward( if has_spin: input_dict["spin"] = spin - if self.inference_only or inference_only: - model_pred = self.model[task_key](**input_dict) - if self.modifier is not None: - modifier_pred = self.modifier(**input_dict) - for k, v in modifier_pred.items(): - model_pred[k] = model_pred[k] + v + # A loss-free wrapper is a pure inference object, so parameters can be + # treated as constants while coordinate gradients remain enabled. + if self.inference_only: + with self._frozen_parameter_context(): + model_pred = self._forward_without_loss(task_key, input_dict) return model_pred, None, None - else: - natoms = atype.shape[-1] - model_pred, loss, more_loss = self.loss[task_key]( - input_dict, - self.model[task_key], - label, - natoms=natoms, - learning_rate=cur_lr, - ) - return model_pred, loss, more_loss + # Training wrappers may request predictions without loss construction + # and still backpropagate those predictions into model parameters + # (for example, KFWrapper updates). + if skip_loss: + model_pred = self._forward_without_loss(task_key, input_dict) + return model_pred, None, None + + natoms = atype.shape[-1] + model_pred, loss, more_loss = self.loss[task_key]( + input_dict, + self.model[task_key], + label, + natoms=natoms, + learning_rate=cur_lr, + ) + return model_pred, loss, more_loss + + @contextmanager + def _frozen_parameter_context(self) -> Generator[None, None, None]: + """ + Freeze model parameters during pure inference. + + Conservative inference still differentiates model outputs with respect + to coordinates to obtain forces and virials. Parameter gradients are not + part of that contract, so disabling them trims the autograd graph while + leaving the coordinate-gradient path intact. + """ + params = tuple(self.parameters()) + requires_grad = tuple(param.requires_grad for param in params) + if not any(requires_grad): + yield + return + for param in params: + param.requires_grad_(False) + try: + yield + finally: + for param, flag in zip(params, requires_grad, strict=True): + param.requires_grad_(flag) + + def _forward_without_loss( + self, + task_key: str, + input_dict: dict[str, Any], + ) -> Any: + """Return predictions without constructing a loss.""" + model_pred = self.model[task_key](**input_dict) + if self.modifier is not None: + modifier_pred = self.modifier(**input_dict) + for key, value in modifier_pred.items(): + model_pred[key] = model_pred[key] + value + return model_pred def set_extra_state(self, state: dict) -> None: self.model_params = state["model_params"] diff --git a/deepmd/pt_expt/train/wrapper.py b/deepmd/pt_expt/train/wrapper.py index 6fd68b8edc..46cff69515 100644 --- a/deepmd/pt_expt/train/wrapper.py +++ b/deepmd/pt_expt/train/wrapper.py @@ -1,5 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +from collections.abc import ( + Generator, +) +from contextlib import ( + contextmanager, +) from typing import ( Any, ) @@ -193,19 +199,54 @@ def forward( "charge_spin": charge_spin, } - model_pred = self.model[task_key](**input_dict) + if self.inference_only: + with self._frozen_parameter_context(): + model_pred = self._forward_without_loss(task_key, input_dict) + return model_pred, None, None - if self.inference_only or label is None: + model_pred = self._forward_without_loss(task_key, input_dict) + if label is None: return model_pred, None, None - else: - natoms = atype.shape[-1] - loss, more_loss = self.loss[task_key]( - cur_lr, - natoms, - model_pred, - label, - ) - return model_pred, loss, more_loss + + natoms = atype.shape[-1] + loss, more_loss = self.loss[task_key]( + cur_lr, + natoms, + model_pred, + label, + ) + return model_pred, loss, more_loss + + @contextmanager + def _frozen_parameter_context(self) -> Generator[None, None, None]: + """ + Freeze model parameters during pure inference. + + Inference still differentiates model outputs with respect to + coordinates for force and virial evaluation. Parameter gradients are not + needed in that path, so disabling them keeps the autograd graph smaller + without changing coordinate derivatives. + """ + params = tuple(self.parameters()) + requires_grad = tuple(param.requires_grad for param in params) + if not any(requires_grad): + yield + return + for param in params: + param.requires_grad_(False) + try: + yield + finally: + for param, flag in zip(params, requires_grad, strict=True): + param.requires_grad_(flag) + + def _forward_without_loss( + self, + task_key: str, + input_dict: dict[str, Any], + ) -> dict[str, torch.Tensor]: + """Return model predictions without constructing a loss.""" + return self.model[task_key](**input_dict) def set_extra_state(self, state: dict) -> None: self.model_params = state.get("model_params", {}) diff --git a/source/tests/pt/test_wrapper.py b/source/tests/pt/test_wrapper.py new file mode 100644 index 0000000000..9fd64bdb06 --- /dev/null +++ b/source/tests/pt/test_wrapper.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for PyTorch model wrapper behavior.""" + +from __future__ import ( + annotations, +) + +import unittest + +import torch + +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) + + +class _LinearToyModel(torch.nn.Module): + def __init__(self, *, fail_forward: bool = False) -> None: + super().__init__() + self.linear = torch.nn.Linear(3, 1, bias=False, device="cpu") + self.scale = torch.nn.Parameter(torch.ones((), device="cpu")) + self.fail_forward = fail_forward + self.last_requires_grad: tuple[bool, ...] | None = None + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + do_atomic_virial: bool = False, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + del atype, box, do_atomic_virial, fparam, aparam, charge_spin + self.last_requires_grad = tuple( + param.requires_grad for param in self.parameters() + ) + if self.fail_forward: + raise RuntimeError("intentional toy failure") + coord_req = coord.clone().requires_grad_(True) + atom_energy = self.scale * self.linear(coord_req).sum(dim=1, keepdim=True) + energy = atom_energy.sum(dim=1) + force = -torch.autograd.grad(energy.sum(), coord_req, create_graph=True)[0] + return { + "atom_energy": atom_energy, + "energy": energy, + "force": force, + } + + +class _EnergyLoss(torch.nn.Module): + def forward( + self, + input_dict: dict[str, torch.Tensor], + model: torch.nn.Module, + label: torch.Tensor | None, + natoms: int, + learning_rate: torch.Tensor | None = None, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict]: + del label, natoms, learning_rate + pred = model(**input_dict) + loss = pred["energy"].sum() + return pred, loss, {} + + +class TestModelWrapper(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(20240611) + self.coord = torch.randn(2, 5, 3, device="cpu") + self.atype = torch.zeros(2, 5, dtype=torch.long, device="cpu") + + def test_inference_wrapper_freezes_parameters_without_changing_predictions( + self, + ) -> None: + model = _LinearToyModel() + reference_model = _LinearToyModel() + reference_model.load_state_dict(model.state_dict()) + wrapper = ModelWrapper(model) + reference_wrapper = ModelWrapper(reference_model, _EnergyLoss()) + + ref, _, _ = reference_wrapper(self.coord, self.atype, skip_loss=True) + out, _, _ = wrapper(self.coord, self.atype) + + self.assertEqual(model.last_requires_grad, (False, False)) + self.assertEqual(reference_model.last_requires_grad, (True, True)) + self.assertTrue(all(param.requires_grad for param in wrapper.parameters())) + for key in ("atom_energy", "energy", "force"): + torch.testing.assert_close(out[key], ref[key]) + + def test_inference_wrapper_restores_mixed_parameter_flags(self) -> None: + model = _LinearToyModel() + model.linear.weight.requires_grad_(False) + wrapper = ModelWrapper(model) + + wrapper(self.coord, self.atype) + + self.assertEqual(model.last_requires_grad, (False, False)) + self.assertFalse(model.linear.weight.requires_grad) + self.assertTrue(model.scale.requires_grad) + + def test_inference_wrapper_restores_parameters_after_exception(self) -> None: + model = _LinearToyModel(fail_forward=True) + wrapper = ModelWrapper(model) + + with self.assertRaisesRegex(RuntimeError, "intentional toy failure"): + wrapper(self.coord, self.atype) + + self.assertEqual(model.last_requires_grad, (False, False)) + self.assertTrue(all(param.requires_grad for param in wrapper.parameters())) + + def test_multitask_inference_wrapper_freezes_selected_head(self) -> None: + model_a = _LinearToyModel() + model_b = _LinearToyModel() + wrapper = ModelWrapper({"a": model_a, "b": model_b}) + + wrapper(self.coord, self.atype, task_key="b") + + self.assertIsNone(model_a.last_requires_grad) + self.assertEqual(model_b.last_requires_grad, (False, False)) + self.assertTrue(all(param.requires_grad for param in wrapper.parameters())) + + def test_skip_loss_keeps_training_gradients(self) -> None: + model = _LinearToyModel() + wrapper = ModelWrapper(model, _EnergyLoss()) + + pred, _, _ = wrapper(self.coord, self.atype, skip_loss=True) + pred["energy"].sum().backward() + + self.assertEqual(model.last_requires_grad, (True, True)) + self.assertIsNotNone(model.linear.weight.grad) + self.assertIsNotNone(model.scale.grad) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/test_wrapper.py b/source/tests/pt_expt/test_wrapper.py new file mode 100644 index 0000000000..594067cd39 --- /dev/null +++ b/source/tests/pt_expt/test_wrapper.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for experimental PyTorch model wrapper behavior.""" + +from __future__ import ( + annotations, +) + +import unittest + +import torch + +from deepmd.pt_expt.train.wrapper import ( + ModelWrapper, +) + + +class _LinearToyModel(torch.nn.Module): + def __init__(self, *, fail_forward: bool = False) -> None: + super().__init__() + self.linear = torch.nn.Linear(3, 1, bias=False, device="cpu") + self.scale = torch.nn.Parameter(torch.ones((), device="cpu")) + self.fail_forward = fail_forward + self.last_requires_grad: tuple[bool, ...] | None = None + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + do_atomic_virial: bool = False, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + del atype, box, do_atomic_virial, fparam, aparam, charge_spin + self.last_requires_grad = tuple( + param.requires_grad for param in self.parameters() + ) + if self.fail_forward: + raise RuntimeError("intentional toy failure") + coord_req = coord.clone().requires_grad_(True) + atom_energy = self.scale * self.linear(coord_req).sum(dim=1, keepdim=True) + energy = atom_energy.sum(dim=1) + force = -torch.autograd.grad(energy.sum(), coord_req, create_graph=True)[0] + return { + "atom_energy": atom_energy, + "energy": energy, + "force": force, + } + + +class _EnergyLoss: + def __call__( + self, + cur_lr: float | torch.Tensor | None, + natoms: int, + model_pred: dict[str, torch.Tensor], + label: dict[str, torch.Tensor] | None, + ) -> tuple[torch.Tensor, dict]: + del cur_lr, natoms, label + loss = model_pred["energy"].sum() + return loss, {} + + +class TestModelWrapper(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(20240611) + self.coord = torch.randn(2, 5, 3, device="cpu") + self.atype = torch.zeros(2, 5, dtype=torch.long, device="cpu") + + def test_inference_wrapper_freezes_parameters_without_changing_predictions( + self, + ) -> None: + model = _LinearToyModel() + reference_model = _LinearToyModel() + reference_model.load_state_dict(model.state_dict()) + wrapper = ModelWrapper(model) + reference_wrapper = ModelWrapper(reference_model, _EnergyLoss()) + + ref, _, _ = reference_wrapper(self.coord, self.atype) + out, _, _ = wrapper(self.coord, self.atype) + + self.assertEqual(model.last_requires_grad, (False, False)) + self.assertEqual(reference_model.last_requires_grad, (True, True)) + self.assertTrue(all(param.requires_grad for param in wrapper.parameters())) + for key in ("atom_energy", "energy", "force"): + torch.testing.assert_close(out[key], ref[key]) + + def test_inference_wrapper_restores_mixed_parameter_flags(self) -> None: + model = _LinearToyModel() + model.linear.weight.requires_grad_(False) + wrapper = ModelWrapper(model) + + wrapper(self.coord, self.atype) + + self.assertEqual(model.last_requires_grad, (False, False)) + self.assertFalse(model.linear.weight.requires_grad) + self.assertTrue(model.scale.requires_grad) + + def test_inference_wrapper_restores_parameters_after_exception(self) -> None: + model = _LinearToyModel(fail_forward=True) + wrapper = ModelWrapper(model) + + with self.assertRaisesRegex(RuntimeError, "intentional toy failure"): + wrapper(self.coord, self.atype) + + self.assertEqual(model.last_requires_grad, (False, False)) + self.assertTrue(all(param.requires_grad for param in wrapper.parameters())) + + def test_multitask_inference_wrapper_freezes_selected_head(self) -> None: + model_a = _LinearToyModel() + model_b = _LinearToyModel() + wrapper = ModelWrapper({"a": model_a, "b": model_b}) + + wrapper(self.coord, self.atype, task_key="b") + + self.assertIsNone(model_a.last_requires_grad) + self.assertEqual(model_b.last_requires_grad, (False, False)) + self.assertTrue(all(param.requires_grad for param in wrapper.parameters())) + + def test_training_wrapper_without_label_keeps_parameter_gradients(self) -> None: + model = _LinearToyModel() + wrapper = ModelWrapper(model, _EnergyLoss()) + + pred, _, _ = wrapper(self.coord, self.atype) + pred["energy"].sum().backward() + + self.assertEqual(model.last_requires_grad, (True, True)) + self.assertIsNotNone(model.linear.weight.grad) + self.assertIsNotNone(model.scale.grad) + + +if __name__ == "__main__": + unittest.main()