diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 783ee9e766..d59d518cab 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -271,8 +271,8 @@ def forward_atomic( """ nframes, nloc, nnei = nlist.shape atype = extended_atype[:, :nloc] - if self.do_grad_r() or self.do_grad_c(): - extended_coord.requires_grad_(True) + if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad: + extended_coord = extended_coord.clone().requires_grad_(True) # Handle default chg_spin if descriptor supports it if self.add_chg_spin_ebd and charge_spin is None: diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 5c0f616634..6c620f6f5b 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -258,8 +258,8 @@ def forward_atomic( the result dict, defined by the fitting net output def. """ nframes, nloc, nnei = nlist.shape - if self.do_grad_r() or self.do_grad_c(): - extended_coord.requires_grad_(True) + if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad: + extended_coord = extended_coord.clone().requires_grad_(True) extended_coord = extended_coord.view(nframes, -1, 3) sorted_rcuts, sorted_sels = self._sort_rcuts_sels() nlists = build_multiple_neighbor_list( diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 5750f7cfd1..eb70cc0e78 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -274,9 +274,9 @@ def forward_atomic( charge_spin: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: nframes, nloc, nnei = nlist.shape + if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad: + extended_coord = extended_coord.clone().requires_grad_(True) extended_coord = extended_coord.view(nframes, -1, 3) - if self.do_grad_r() or self.do_grad_c(): - extended_coord.requires_grad_(True) # this will mask all -1 in the nlist mask = nlist >= 0 diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 713eab3d8c..78705b153c 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -306,8 +306,12 @@ def forward_common_lower( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam + force_coord = cc_ext + if self.atomic_model.do_grad_r() or self.atomic_model.do_grad_c(): + if not force_coord.requires_grad: + force_coord = force_coord.clone().requires_grad_(True) atomic_ret = self.atomic_model.forward_common_atomic( - cc_ext, + force_coord, extended_atype, nlist, mapping=mapping, @@ -319,7 +323,7 @@ def forward_common_lower( model_predict = fit_output_to_model_output( atomic_ret, self.atomic_output_def(), - cc_ext, + force_coord, do_atomic_virial=do_atomic_virial, create_graph=self.training, mask=atomic_ret["mask"] if "mask" in atomic_ret else None, diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index 6d6a22f357..aa6640992e 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -74,6 +74,67 @@ def test_self_consistency(self) -> None: to_numpy_array(ret1["energy"]), ) + def test_forward_common_atomic_accepts_leaf_view_input(self) -> None: + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE) + + coord = to_torch_tensor(self.coord_ext) + coord_view = coord.view(self.nf, self.nall, 3) + coord_view_before = coord_view.detach().clone() + self.assertTrue(coord.is_leaf) + self.assertTrue(coord_view._is_view()) + args = [ + coord_view, + to_torch_tensor(self.atype_ext), + to_torch_tensor(self.nlist), + ] + ret = md0.forward_common_atomic(*args) + + self.assertFalse(coord_view.requires_grad) + torch.testing.assert_close(coord_view, coord_view_before) + self.assertIn("energy", ret) + self.assertEqual(ret["energy"].shape, (self.nf, self.nloc, 1)) + self.assertTrue(torch.isfinite(ret["energy"]).all()) + + def test_forward_common_atomic_preserves_grad_enabled_input(self) -> None: + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE) + + coord = to_torch_tensor(self.coord_ext) + coord_view = coord.view(self.nf, self.nall, 3).clone().requires_grad_(True) + args = [ + coord_view, + to_torch_tensor(self.atype_ext), + to_torch_tensor(self.nlist), + ] + ret = md0.forward_common_atomic(*args) + ret["energy"].sum().backward() + + self.assertTrue(coord_view.requires_grad) + self.assertIsNotNone(coord_view.grad) + def test_dp_consistency(self) -> None: nf, nloc, nnei = self.nlist.shape ds = DPDescrptSeA( diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index f4e350869a..21335b5bec 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -114,6 +114,51 @@ def test_self_consistency(self) -> None: atol=self.atol, ) + def test_forward_lower_accepts_leaf_view_input(self) -> None: + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = EnergyFittingNet( + self.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE) + + coord_ext, atype_ext, _ = extend_coord_with_ghosts( + to_torch_tensor(self.coord), + to_torch_tensor(self.atype), + to_torch_tensor(self.cell), + self.rcut, + ) + nlist = build_neighbor_list( + coord_ext, + atype_ext, + self.nloc, + self.rcut, + self.sel, + distinguish_types=(not md0.mixed_types()), + ) + coord_view = coord_ext.view(self.nf, -1, 3) + + ret = md0.forward_lower(coord_view, atype_ext, nlist, do_atomic_virial=True) + + self.assertFalse(coord_view.requires_grad) + self.assertIn("extended_force", ret) + self.assertIn("virial", ret) + + coord_view_grad = coord_ext.view(self.nf, -1, 3).clone().requires_grad_(True) + ret = md0.forward_lower( + coord_view_grad, atype_ext, nlist, do_atomic_virial=True + ) + + self.assertTrue(coord_view_grad.requires_grad) + self.assertIn("extended_force", ret) + self.assertIn("virial", ret) + def test_dp_consistency(self) -> None: nf, nloc = self.atype.shape nfp, nap = 2, 3 diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py index 038bb04a5a..e693e8b90e 100644 --- a/source/tests/pt/model/test_linear_atomic_model.py +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -181,6 +181,33 @@ def test_self_consistency(self) -> None: to_numpy_array(ret0["energy"]), ret2["energy"], atol=0.001, rtol=0.001 ) + def test_forward_atomic_accepts_leaf_view_input(self) -> None: + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + coord = args[0] + coord_view = coord.view(self.nf, self.nall, 3) + coord_view_before = coord_view.detach().clone() + self.assertTrue(coord.is_leaf) + self.assertTrue(coord_view._is_view()) + args[0] = coord_view + ret = self.md0.forward_atomic(*args) + + self.assertFalse(coord_view.requires_grad) + torch.testing.assert_close(coord_view, coord_view_before) + self.assertIn("energy", ret) + + def test_forward_atomic_preserves_grad_enabled_input(self) -> None: + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + args[0] = args[0].view(self.nf, self.nall, 3).clone().requires_grad_(True) + ret = self.md0.forward_atomic(*args) + ret["energy"].sum().backward() + + self.assertTrue(args[0].requires_grad) + self.assertIsNotNone(args[0].grad) + def test_jit(self) -> None: md1 = torch.jit.script(self.md1) # atomic model no more export methods diff --git a/source/tests/pt/model/test_pairtab_atomic_model.py b/source/tests/pt/model/test_pairtab_atomic_model.py index 0f324cbf51..01927b4967 100644 --- a/source/tests/pt/model/test_pairtab_atomic_model.py +++ b/source/tests/pt/model/test_pairtab_atomic_model.py @@ -96,6 +96,26 @@ def test_with_mask(self) -> None: result["energy"], expected_result, rtol=0.0001, atol=0.0001 ) + def test_forward_common_atomic_accepts_leaf_view_input(self) -> None: + coord = self.extended_coord.clone() + coord_view = coord.view(2, 4, 3) + coord_view_before = coord_view.detach().clone() + self.assertTrue(coord.is_leaf) + self.assertTrue(coord_view._is_view()) + ret = self.model.forward_atomic(coord_view, self.extended_atype, self.nlist) + + self.assertFalse(coord_view.requires_grad) + torch.testing.assert_close(coord_view, coord_view_before) + self.assertIn("energy", ret) + + def test_forward_common_atomic_preserves_grad_enabled_input(self) -> None: + coord = self.extended_coord.clone().requires_grad_(True) + ret = self.model.forward_atomic(coord, self.extended_atype, self.nlist) + ret["energy"].sum().backward() + + self.assertTrue(coord.requires_grad) + self.assertIsNotNone(coord.grad) + def test_jit(self) -> None: model = torch.jit.script(self.model) # atomic model no more export methods