Skip the post-training output-bias change when no training step runs.

Both the Paddle (pd) and PyTorch (pt) trainers now require
self.num_steps > self.start_step, in addition to
change_bias_after_training and the rank-0 check, before entering the
branch that calls model_change_out_bias.

A regression test is added to each backend's test_training.py. It sets
numb_steps = 0 and change_bias_after_training = True, runs the trainer,
and asserts that:
  * the "<save_ckpt>-0.pd" / "<save_ckpt>-0.pt" file exists and equals
    trainer.latest_model;
  * the "checkpoint" file contains that same path;
  * the stored train_infos has step == 0 and lr == 0.0.

NOTE(review): this patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed. Hunk line counts match the hunk
headers, but the indentation of context and added lines is reconstructed
from Python nesting; confirm with "git apply --check" against the target
tree (or apply with --ignore-whitespace).
NOTE(review): the pd test passes a pathlib.Path to paddle.load, whose
documented path type is str | BytesIO; confirm a Path is accepted or
wrap it in str().

diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py
index 7d91218468..3208903b95 100644
--- a/deepmd/pd/train/training.py
+++ b/deepmd/pd/train/training.py
@@ -1038,7 +1038,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
             if JIT:
                 break
 
-        if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
+        if (
+            self.change_bias_after_training
+            and self.num_steps > self.start_step
+            and (self.rank == 0 or dist.get_rank() == 0)
+        ):
             if not self.multi_task:
                 self.model = model_change_out_bias(
                     self.model,
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index a5b799dbdc..8d66a86a9c 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -1745,7 +1745,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
             if JIT:
                 break
 
-        if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
+        if (
+            self.change_bias_after_training
+            and self.num_steps > self.start_step
+            and (self.rank == 0 or dist.get_rank() == 0)
+        ):
             if not self.multi_task:
                 self.model = model_change_out_bias(
                     self.model,
diff --git a/source/tests/pd/test_training.py b/source/tests/pd/test_training.py
index 21b7b3b854..d7cdf96577 100644
--- a/source/tests/pd/test_training.py
+++ b/source/tests/pd/test_training.py
@@ -11,6 +11,7 @@
 )
 
 import numpy as np
+import paddle
 
 from deepmd.pd.entrypoints.main import (
     get_trainer,
@@ -163,6 +164,25 @@ def setUp(self) -> None:
         self.config["training"]["save_freq"] = 1
         enable_prim(True)
 
+    def test_zero_step_with_change_bias_saves_initial_checkpoint(self) -> None:
+        config = deepcopy(self.config)
+        config["training"]["numb_steps"] = 0
+        config["training"]["change_bias_after_training"] = True
+        trainer = get_trainer(config)
+        trainer.run()
+
+        expected_model = Path(trainer.save_ckpt + "-0.pd")
+        self.assertEqual(expected_model, trainer.latest_model)
+        self.assertTrue(expected_model.exists())
+        self.assertEqual(
+            expected_model,
+            Path(Path("checkpoint").read_text().strip()),
+        )
+        checkpoint = paddle.load(expected_model)
+        train_infos = checkpoint["model"]["_extra_state"]["train_infos"]
+        self.assertEqual(0, train_infos["step"])
+        self.assertEqual(0.0, train_infos["lr"])
+
     def tearDown(self) -> None:
         DPTrainTest.tearDown(self)
 
diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py
index c4e58c0368..e776074f5e 100644
--- a/source/tests/pt/test_training.py
+++ b/source/tests/pt/test_training.py
@@ -263,6 +263,25 @@ def test_yaml_input(self) -> None:
         )
         self.assertTrue(Path("out.json").exists())
 
+    def test_zero_step_with_change_bias_saves_initial_checkpoint(self) -> None:
+        config = deepcopy(self.config)
+        config["training"]["numb_steps"] = 0
+        config["training"]["change_bias_after_training"] = True
+        trainer = get_trainer(config)
+        trainer.run()
+
+        expected_model = Path(trainer.save_ckpt + "-0.pt")
+        self.assertEqual(expected_model, trainer.latest_model)
+        self.assertTrue(expected_model.exists())
+        self.assertEqual(
+            expected_model,
+            Path(Path("checkpoint").read_text().strip()),
+        )
+        checkpoint = torch.load(expected_model, map_location="cpu", weights_only=True)
+        train_infos = checkpoint["model"]["_extra_state"]["train_infos"]
+        self.assertEqual(0, train_infos["step"])
+        self.assertEqual(0.0, train_infos["lr"])
+
     def tearDown(self) -> None:
         DPTrainTest.tearDown(self)
         for ff in ["out.json", "input.yaml"]: