Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__(
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting
if hasattr(self.fitting, "reinit_exclude"):
self.fitting.reinit_exclude(self.atom_exclude_types)
self.type_map = type_map
super().init_out_stat()

Expand Down Expand Up @@ -191,7 +193,7 @@ def change_type_map(
if model_with_new_type_stat is not None
else None,
)
self.fitting_net.change_type_map(type_map=type_map)
self.fitting.change_type_map(type_map=type_map)

def serialize(self) -> dict:
dd = super().serialize()
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
if hasattr(self.fitting_net, "reinit_exclude"):
self.fitting_net.reinit_exclude(self.atom_exclude_types)
super().init_out_stat()
self.enable_eval_descriptor_hook = False
self.enable_eval_fitting_last_layer_hook = False
Expand Down Expand Up @@ -151,6 +153,9 @@ def change_type_map(
else None,
)
self.fitting_net.change_type_map(type_map=type_map)
# Reinitialize fitting to get correct sel_type
if hasattr(self.fitting_net, "reinit_exclude"):
self.fitting_net.reinit_exclude(self.atom_exclude_types)

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
Expand Down
13 changes: 10 additions & 3 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,16 +1003,23 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
check_version_compatibility(data.pop("@version", 2), 2, 1)
descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
# bias_atom_e and out_bias are now completely independent - no conversion needed
fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
fitting_dict = data.pop("fitting", {})
atom_exclude_types = data.pop("atom_exclude_types", [])
if len(atom_exclude_types) > 0:
# get sel_type from complement of atom_exclude_types
full_type_list = np.arange(len(data["type_map"]), dtype=int)
sel_type = np.setdiff1d(
full_type_list, atom_exclude_types, assume_unique=True
)
fitting_dict["sel_type"] = sel_type.tolist()
fitting = Fitting.deserialize(fitting_dict, suffix=suffix)
# pass descriptor type embedding to model
if descriptor.explicit_ntypes:
type_embedding = descriptor.type_embedding
fitting.dim_descrpt -= type_embedding.neuron[-1]
else:
type_embedding = None
# BEGINE not supported keys
if len(data.pop("atom_exclude_types")) > 0:
raise NotImplementedError("atom_exclude_types is not supported")
if len(data.pop("pair_exclude_types")) > 0:
raise NotImplementedError("pair_exclude_types is not supported")
data.pop("rcond", None)
Expand Down
12 changes: 12 additions & 0 deletions source/tests/consistent/model/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,15 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
ret[1].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")

def test_atom_exclude_types(self):
if self.skip_pt:
self.skipTest("Unsupported backend")
if self.skip_tf:
self.skipTest("Unsupported backend")
_ret, data = self.get_reference_ret_serialization(self.RefBackend.PT)
data["atom_exclude_types"] = [1]
self.reset_unique_id()
tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id)
pt_obj = self.pt_class.deserialize(data)
self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type())
12 changes: 12 additions & 0 deletions source/tests/consistent/model/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,15 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
ret[1].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")

def test_atom_exclude_types(self):
if self.skip_pt:
self.skipTest("Unsupported backend")
if self.skip_tf:
self.skipTest("Unsupported backend")
_ret, data = self.get_reference_ret_serialization(self.RefBackend.PT)
data["atom_exclude_types"] = [1]
self.reset_unique_id()
tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id)
pt_obj = self.pt_class.deserialize(data)
self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type())
6 changes: 6 additions & 0 deletions source/tests/pt/model/test_get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def test_model_attr(self) -> None:
]
},
)
full_type_list = np.arange(len(atomic_model.type_map), dtype=int)
atom_exclude_types = np.setdiff1d(
full_type_list,
self.model.get_sel_type(),
).tolist()
self.assertEqual(atom_exclude_types, [1])
self.assertEqual(atomic_model.atom_exclude_types, [1])
self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]])

Expand Down
79 changes: 64 additions & 15 deletions source/tests/universal/dpmodel/atomc_model/test_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
DPZBLLinearEnergyAtomicModel,
Expand Down Expand Up @@ -72,6 +74,13 @@
)


def make_sel_type_from_atom_exclude_types(type_map, atom_exclude_types):
"""Get sel_type from complement of atom_exclude_types."""
full_type_list = np.arange(len(type_map), dtype=int)
sel_type = np.setdiff1d(full_type_list, atom_exclude_types, assume_unique=True)
return sel_type.tolist()


@parameterized(
des_parameterized=(
(
Expand All @@ -85,6 +94,7 @@
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
), # descrpt_class_param & class
((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
Expand All @@ -97,6 +107,7 @@
(
*[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],
), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
)
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
Expand Down Expand Up @@ -128,16 +139,22 @@ def setUpClass(cls) -> None:
**cls.input_dict_ft,
)
cls.module = DPAtomicModel(
ds,
ft,
type_map=cls.expected_type_map,
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
)
cls.output_def = cls.module.atomic_output_def().get_data()
cls.expected_has_message_passing = ds.has_message_passing()
cls.expected_sel_type = ft.get_sel_type()
cls.expected_dim_fparam = ft.get_dim_fparam()
cls.expected_dim_aparam = ft.get_dim_aparam()

def test_sel_type_from_atom_exclude_types(self):
self.assertEqual(
make_sel_type_from_atom_exclude_types(
self.expected_type_map, self.param[2]
),
self.expected_sel_type,
)


@parameterized(
des_parameterized=(
Expand All @@ -152,6 +169,7 @@ def setUpClass(cls) -> None:
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
), # descrpt_class_param & class
((FittingParamDos, DOSFittingNet),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
Expand All @@ -164,6 +182,7 @@ def setUpClass(cls) -> None:
(
*[(param_func, DOSFittingNet) for param_func in FittingParamDosList],
), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
)
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
Expand Down Expand Up @@ -195,16 +214,22 @@ def setUpClass(cls) -> None:
**cls.input_dict_ft,
)
cls.module = DPAtomicModel(
ds,
ft,
type_map=cls.expected_type_map,
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
)
cls.output_def = cls.module.atomic_output_def().get_data()
cls.expected_has_message_passing = ds.has_message_passing()
cls.expected_sel_type = ft.get_sel_type()
cls.expected_dim_fparam = ft.get_dim_fparam()
cls.expected_dim_aparam = ft.get_dim_aparam()

def test_sel_type_from_atom_exclude_types(self):
self.assertEqual(
make_sel_type_from_atom_exclude_types(
self.expected_type_map, self.param[2]
),
self.expected_sel_type,
)


@parameterized(
des_parameterized=(
Expand All @@ -216,6 +241,7 @@ def setUpClass(cls) -> None:
(DescriptorParamHybridMixed, DescrptHybrid),
), # descrpt_class_param & class
((FittingParamDipole, DipoleFitting),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
Expand All @@ -226,6 +252,7 @@ def setUpClass(cls) -> None:
(
*[(param_func, DipoleFitting) for param_func in FittingParamDipoleList],
), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
)
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
Expand Down Expand Up @@ -258,16 +285,22 @@ def setUpClass(cls) -> None:
**cls.input_dict_ft,
)
cls.module = DPAtomicModel(
ds,
ft,
type_map=cls.expected_type_map,
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
)
cls.output_def = cls.module.atomic_output_def().get_data()
cls.expected_has_message_passing = ds.has_message_passing()
cls.expected_sel_type = ft.get_sel_type()
cls.expected_dim_fparam = ft.get_dim_fparam()
cls.expected_dim_aparam = ft.get_dim_aparam()

def test_sel_type_from_atom_exclude_types(self):
self.assertEqual(
make_sel_type_from_atom_exclude_types(
self.expected_type_map, self.param[2]
),
self.expected_sel_type,
)


@parameterized(
des_parameterized=(
Expand All @@ -279,6 +312,7 @@ def setUpClass(cls) -> None:
(DescriptorParamHybridMixed, DescrptHybrid),
), # descrpt_class_param & class
((FittingParamPolar, PolarFitting),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
Expand All @@ -289,6 +323,7 @@ def setUpClass(cls) -> None:
(
*[(param_func, PolarFitting) for param_func in FittingParamPolarList],
), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
)
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
Expand Down Expand Up @@ -321,16 +356,22 @@ def setUpClass(cls) -> None:
**cls.input_dict_ft,
)
cls.module = DPAtomicModel(
ds,
ft,
type_map=cls.expected_type_map,
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
)
cls.output_def = cls.module.atomic_output_def().get_data()
cls.expected_has_message_passing = ds.has_message_passing()
cls.expected_sel_type = ft.get_sel_type()
cls.expected_dim_fparam = ft.get_dim_fparam()
cls.expected_dim_aparam = ft.get_dim_aparam()

def test_sel_type_from_atom_exclude_types(self):
self.assertEqual(
make_sel_type_from_atom_exclude_types(
self.expected_type_map, self.param[2]
),
self.expected_sel_type,
)


@parameterized(
des_parameterized=(
Expand Down Expand Up @@ -415,6 +456,7 @@ def setUpClass(cls) -> None:
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
), # descrpt_class_param & class
((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
fit_parameterized=(
(
Expand All @@ -428,6 +470,7 @@ def setUpClass(cls) -> None:
for param_func in FittingParamPropertyList
],
), # fitting_class_param & class
([], [0]), # atom_exclude_types
),
)
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
Expand Down Expand Up @@ -460,12 +503,18 @@ def setUpClass(cls) -> None:
**cls.input_dict_ft,
)
cls.module = DPAtomicModel(
ds,
ft,
type_map=cls.expected_type_map,
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
)
cls.output_def = cls.module.atomic_output_def().get_data()
cls.expected_has_message_passing = ds.has_message_passing()
cls.expected_sel_type = ft.get_sel_type()
cls.expected_dim_fparam = ft.get_dim_fparam()
cls.expected_dim_aparam = ft.get_dim_aparam()

def test_sel_type_from_atom_exclude_types(self):
self.assertEqual(
make_sel_type_from_atom_exclude_types(
self.expected_type_map, self.param[2]
),
self.expected_sel_type,
)