Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def forward_atomic(
"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
Comment thread
wanghan-iapcm marked this conversation as resolved.
extended_coord = extended_coord.clone().requires_grad_(True)

# Handle default chg_spin if descriptor supports it
if self.add_chg_spin_ebd and charge_spin is None:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def forward_atomic(
the result dict, defined by the fitting net output def.
"""
nframes, nloc, nnei = nlist.shape
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
extended_coord = extended_coord.clone().requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ def forward_atomic(
charge_spin: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
nframes, nloc, nnei = nlist.shape
if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad:
Comment thread
wanghan-iapcm marked this conversation as resolved.
extended_coord = extended_coord.clone().requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

# this will mask all -1 in the nlist
mask = nlist >= 0
Expand Down
8 changes: 6 additions & 2 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,12 @@ def forward_common_lower(
extended_coord, fparam=fparam, aparam=aparam
)
del extended_coord, fparam, aparam
force_coord = cc_ext
if self.atomic_model.do_grad_r() or self.atomic_model.do_grad_c():
if not force_coord.requires_grad:
force_coord = force_coord.clone().requires_grad_(True)
atomic_ret = self.atomic_model.forward_common_atomic(
cc_ext,
force_coord,
extended_atype,
nlist,
mapping=mapping,
Expand All @@ -319,7 +323,7 @@ def forward_common_lower(
model_predict = fit_output_to_model_output(
atomic_ret,
self.atomic_output_def(),
cc_ext,
force_coord,
do_atomic_virial=do_atomic_virial,
create_graph=self.training,
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
Expand Down
61 changes: 61 additions & 0 deletions source/tests/pt/model/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,67 @@ def test_self_consistency(self) -> None:
to_numpy_array(ret1["energy"]),
)

def test_forward_common_atomic_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE)

coord = to_torch_tensor(self.coord_ext)
coord_view = coord.view(self.nf, self.nall, 3)
coord_view_before = coord_view.detach().clone()
self.assertTrue(coord.is_leaf)
self.assertTrue(coord_view._is_view())
args = [
coord_view,
to_torch_tensor(self.atype_ext),
to_torch_tensor(self.nlist),
]
ret = md0.forward_common_atomic(*args)

self.assertFalse(coord_view.requires_grad)
torch.testing.assert_close(coord_view, coord_view_before)
self.assertIn("energy", ret)
Comment thread
wanghan-iapcm marked this conversation as resolved.
self.assertEqual(ret["energy"].shape, (self.nf, self.nloc, 1))
self.assertTrue(torch.isfinite(ret["energy"]).all())

def test_forward_common_atomic_preserves_grad_enabled_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = InvarFitting(
"energy",
self.nt,
ds.get_dim_out(),
1,
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
md0 = DPAtomicModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE)

coord = to_torch_tensor(self.coord_ext)
coord_view = coord.view(self.nf, self.nall, 3).clone().requires_grad_(True)
args = [
coord_view,
to_torch_tensor(self.atype_ext),
to_torch_tensor(self.nlist),
]
ret = md0.forward_common_atomic(*args)
ret["energy"].sum().backward()

self.assertTrue(coord_view.requires_grad)
self.assertIsNotNone(coord_view.grad)

def test_dp_consistency(self) -> None:
nf, nloc, nnei = self.nlist.shape
ds = DPDescrptSeA(
Expand Down
45 changes: 45 additions & 0 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,51 @@ def test_self_consistency(self) -> None:
atol=self.atol,
)

def test_forward_lower_accepts_leaf_view_input(self) -> None:
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
).to(env.DEVICE)
ft = EnergyFittingNet(
self.nt,
ds.get_dim_out(),
mixed_types=ds.mixed_types(),
).to(env.DEVICE)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map).to(env.DEVICE)

coord_ext, atype_ext, _ = extend_coord_with_ghosts(
to_torch_tensor(self.coord),
to_torch_tensor(self.atype),
to_torch_tensor(self.cell),
self.rcut,
)
nlist = build_neighbor_list(
coord_ext,
atype_ext,
self.nloc,
self.rcut,
self.sel,
distinguish_types=(not md0.mixed_types()),
)
coord_view = coord_ext.view(self.nf, -1, 3)

ret = md0.forward_lower(coord_view, atype_ext, nlist, do_atomic_virial=True)

self.assertFalse(coord_view.requires_grad)
self.assertIn("extended_force", ret)
self.assertIn("virial", ret)

coord_view_grad = coord_ext.view(self.nf, -1, 3).clone().requires_grad_(True)
ret = md0.forward_lower(
coord_view_grad, atype_ext, nlist, do_atomic_virial=True
)

self.assertTrue(coord_view_grad.requires_grad)
self.assertIn("extended_force", ret)
self.assertIn("virial", ret)

def test_dp_consistency(self) -> None:
nf, nloc = self.atype.shape
nfp, nap = 2, 3
Expand Down
27 changes: 27 additions & 0 deletions source/tests/pt/model/test_linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,33 @@ def test_self_consistency(self) -> None:
to_numpy_array(ret0["energy"]), ret2["energy"], atol=0.001, rtol=0.001
)

def test_forward_atomic_accepts_leaf_view_input(self) -> None:
args = [
to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
]
coord = args[0]
coord_view = coord.view(self.nf, self.nall, 3)
coord_view_before = coord_view.detach().clone()
self.assertTrue(coord.is_leaf)
self.assertTrue(coord_view._is_view())
args[0] = coord_view
ret = self.md0.forward_atomic(*args)

self.assertFalse(coord_view.requires_grad)
torch.testing.assert_close(coord_view, coord_view_before)
self.assertIn("energy", ret)

def test_forward_atomic_preserves_grad_enabled_input(self) -> None:
args = [
to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
]
args[0] = args[0].view(self.nf, self.nall, 3).clone().requires_grad_(True)
ret = self.md0.forward_atomic(*args)
ret["energy"].sum().backward()

self.assertTrue(args[0].requires_grad)
self.assertIsNotNone(args[0].grad)

def test_jit(self) -> None:
md1 = torch.jit.script(self.md1)
# atomic model no more export methods
Expand Down
20 changes: 20 additions & 0 deletions source/tests/pt/model/test_pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,26 @@ def test_with_mask(self) -> None:
result["energy"], expected_result, rtol=0.0001, atol=0.0001
)

def test_forward_common_atomic_accepts_leaf_view_input(self) -> None:
coord = self.extended_coord.clone()
coord_view = coord.view(2, 4, 3)
coord_view_before = coord_view.detach().clone()
self.assertTrue(coord.is_leaf)
self.assertTrue(coord_view._is_view())
ret = self.model.forward_atomic(coord_view, self.extended_atype, self.nlist)

self.assertFalse(coord_view.requires_grad)
torch.testing.assert_close(coord_view, coord_view_before)
self.assertIn("energy", ret)

def test_forward_common_atomic_preserves_grad_enabled_input(self) -> None:
coord = self.extended_coord.clone().requires_grad_(True)
ret = self.model.forward_atomic(coord, self.extended_atype, self.nlist)
ret["energy"].sum().backward()

self.assertTrue(coord.requires_grad)
self.assertIsNotNone(coord.grad)

def test_jit(self) -> None:
model = torch.jit.script(self.model)
# atomic model no more export methods
Expand Down
Loading