fix(train): allow zero-step training with bias adjustment #5477
base: master
Changes from all commits
ef84d6c
631039c
5da1412
d27334c
3d7168f
c8b454d
1273e6e
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -263,6 +263,33 @@ def test_yaml_input(self) -> None: | |
| ) | ||
| self.assertTrue(Path("out.json").exists()) | ||
| @patch("deepmd.pt.train.training.model_change_out_bias") | ||
| def test_zero_step_with_change_bias_saves_initial_checkpoint( | ||
| self, mocked_change_out_bias | ||
| ) -> None: | ||
| def keep_model(model, *_args, **_kwargs): | ||
| return model | ||
| mocked_change_out_bias.side_effect = keep_model | ||
| config = deepcopy(self.config) | ||
| config["training"]["numb_steps"] = 0 | ||
| config["training"]["change_bias_after_training"] = True | ||
| trainer = get_trainer(config) | ||
| trainer.run() | ||
| expected_model = Path(trainer.save_ckpt + "-0.pt") | ||
| self.assertEqual(expected_model, trainer.latest_model) | ||
| self.assertTrue(expected_model.exists()) | ||
| self.assertEqual( | ||
| expected_model, | ||
| Path(Path("checkpoint").read_text().strip()), | ||
| ) | ||
| checkpoint = torch.load(expected_model, map_location="cpu", weights_only=True) | ||
| train_infos = checkpoint["model"]["_extra_state"]["train_infos"] | ||
| self.assertEqual(0, train_infos["step"]) | ||
| self.assertEqual(0.0, train_infos["lr"]) | ||
Comment on lines +289 to +290
Collaborator
This regression test passes on the unfixed code, so it does not guard the bug it targets. With … The only behavior the fix actually changes is whether … Suggest mirroring …
Member
Fixed in 1273e6e. The PT zero-step regression test now patches … Validation: …
| mocked_change_out_bias.assert_not_called() | ||
| def tearDown(self) -> None: | ||
| DPTrainTest.tearDown(self) | ||
| for ff in ["out.json", "input.yaml"]: | ||
Same issue as the PT test: this passes on the unfixed code. With `numb_steps=0`, the pre-existing `if self.num_steps == 0:` block re-saves `<ckpt>-0.pd` with `step=0, lr=0` and rewrites the `checkpoint` pointer after the bias block, so the assertions here (existence, `latest_model`, pointer, `step==0`, `lr==0.0`) are satisfied regardless of the `num_steps > start_step` guard. The only thing the fix changes, whether `model_change_out_bias` runs, is never asserted.

Suggest patching `model_change_out_bias` and asserting it is not called for `numb_steps=0`. Note the PD suite also has no test exercising the true branch (bias adjustment running for `numb_steps>0`), unlike PT's `test_ema_checkpoint_keeps_changed_out_bias`.
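To illustrate the point above, here is a minimal sketch of why checkpoint-only assertions cannot tell the fixed and unfixed code apart. `finish_run` and its `guarded` flag are hypothetical stand-ins invented for this illustration, not DeePMD-kit code; the sketch only assumes that the zero-step block rewrites the step-0 checkpoint after the bias block, as described in the comment.

```python
from unittest.mock import Mock


def finish_run(num_steps, start_step, change_bias, guarded):
    """Toy end-of-training logic (hypothetical, not the real trainer)."""
    # Bias block: the fix adds the `num_steps > start_step` guard.
    if (not guarded) or num_steps > start_step:
        change_bias()
    checkpoint = None
    # Zero-step block: always rewrites the step-0 checkpoint afterwards.
    if num_steps == 0:
        checkpoint = {"step": 0, "lr": 0.0}
    return checkpoint


unguarded_hook = Mock()
guarded_hook = Mock()
unguarded_ckpt = finish_run(0, 0, unguarded_hook, guarded=False)
guarded_ckpt = finish_run(0, 0, guarded_hook, guarded=True)

# The checkpoint contents are identical with and without the guard ...
assert unguarded_ckpt == guarded_ckpt == {"step": 0, "lr": 0.0}
# ... so only the hook's call count distinguishes the two versions.
assert unguarded_hook.call_count == 1
assert guarded_hook.call_count == 0
```

In this toy model, any assertion on the saved checkpoint passes for both variants, which is why the review asks for an assertion on the bias hook itself.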
Fixed in 1273e6e. The Paddle zero-step regression test now patches `deepmd.pd.train.training.model_change_out_bias`, returns the original model if it is called, and asserts `assert_not_called()` after `trainer.run()`. This directly checks the behavior changed by the guard instead of only checking the checkpoint rewrite path.

Validation:
- `uvx ruff check .`
- `uvx ruff format --check .`

I could not run the Paddle test locally because this environment is missing `paddle` (`ModuleNotFoundError: No module named paddle`).
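For readers without Paddle installed, the pattern described in this thread can be sketched in a backend-free form. Everything below is an assumption made for illustration: `training` is a `SimpleNamespace` standing in for a training module, and `ToyTrainer` is a hypothetical class that reproduces only the `num_steps > start_step` guard. It covers both branches, including the "true" branch the review notes is untested in the PD suite.

```python
from types import SimpleNamespace
from unittest.mock import patch

# Hypothetical stand-in for a training module exposing a bias-adjustment hook.
training = SimpleNamespace(model_change_out_bias=lambda model, *a, **k: model)


class ToyTrainer:
    """Hypothetical trainer that keeps only the guard under test."""

    def __init__(self, num_steps, start_step=0):
        self.num_steps = num_steps
        self.start_step = start_step
        self.model = object()

    def run(self):
        # Adjust the bias only if at least one step was actually trained.
        if self.num_steps > self.start_step:
            self.model = training.model_change_out_bias(self.model)


def keep_model(model, *_args, **_kwargs):
    # Pass-through side effect so the mocked hook still returns a usable model.
    return model


def hook_call_count(num_steps):
    with patch.object(
        training, "model_change_out_bias", side_effect=keep_model
    ) as hook:
        ToyTrainer(num_steps).run()
    return hook.call_count


zero_step_calls = hook_call_count(0)  # guard is false: hook must not run
one_step_calls = hook_call_count(1)   # guard is true: hook runs once
```

The sketch uses `patch.object` on a local namespace instead of a dotted-path `patch(...)` so that it runs without importing any real backend; a real test would patch the module path named in the comment above.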