diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 56fc2b42c7..07a02ad56b 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -51,6 +51,8 @@ def __init__( self.type_map = type_map self.descriptor = descriptor self.fitting = fitting + if hasattr(self.fitting, "reinit_exclude"): + self.fitting.reinit_exclude(self.atom_exclude_types) self.type_map = type_map super().init_out_stat() @@ -191,7 +193,7 @@ def change_type_map( if model_with_new_type_stat is not None else None, ) - self.fitting_net.change_type_map(type_map=type_map) + self.fitting.change_type_map(type_map=type_map) def serialize(self) -> dict: dd = super().serialize() diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 4ebef0106c..af2e8954df 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -64,6 +64,8 @@ def __init__( self.rcut = self.descriptor.get_rcut() self.sel = self.descriptor.get_sel() self.fitting_net = fitting + if hasattr(self.fitting_net, "reinit_exclude"): + self.fitting_net.reinit_exclude(self.atom_exclude_types) super().init_out_stat() self.enable_eval_descriptor_hook = False self.enable_eval_fitting_last_layer_hook = False @@ -151,6 +153,9 @@ def change_type_map( else None, ) self.fitting_net.change_type_map(type_map=type_map) + # Reinitialize fitting to get correct sel_type + if hasattr(self.fitting_net, "reinit_exclude"): + self.fitting_net.reinit_exclude(self.atom_exclude_types) def has_message_passing(self) -> bool: """Returns whether the atomic model has message passing.""" diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index da99cf5e60..26a3f7bf73 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -1003,7 +1003,16 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": check_version_compatibility(data.pop("@version", 2), 2, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) # bias_atom_e and out_bias are now completely independent - no conversion needed - fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) + fitting_dict = data.pop("fitting", {}) + atom_exclude_types = data.pop("atom_exclude_types", []) + if len(atom_exclude_types) > 0: + # get sel_type from complement of atom_exclude_types + full_type_list = np.arange(len(data["type_map"]), dtype=int) + sel_type = np.setdiff1d( + full_type_list, atom_exclude_types, assume_unique=True + ) + fitting_dict["sel_type"] = sel_type.tolist() + fitting = Fitting.deserialize(fitting_dict, suffix=suffix) # pass descriptor type embedding to model if descriptor.explicit_ntypes: type_embedding = descriptor.type_embedding @@ -1011,8 +1020,6 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": else: type_embedding = None # BEGINE not supported keys - if len(data.pop("atom_exclude_types")) > 0: - raise NotImplementedError("atom_exclude_types is not supported") if len(data.pop("pair_exclude_types")) > 0: raise NotImplementedError("pair_exclude_types is not supported") data.pop("rcond", None) diff --git a/source/tests/consistent/model/test_dipole.py b/source/tests/consistent/model/test_dipole.py index 78146a4974..339dcae7c3 100644 --- a/source/tests/consistent/model/test_dipole.py +++ b/source/tests/consistent/model/test_dipole.py @@ -204,3 +204,15 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret[1].ravel(), ) raise ValueError(f"Unknown backend: {backend}") + + def test_atom_exclude_types(self): + if self.skip_pt: + self.skipTest("Unsupported backend") + if self.skip_tf: + self.skipTest("Unsupported backend") + _ret, data = self.get_reference_ret_serialization(self.RefBackend.PT) + data["atom_exclude_types"] = [1] + self.reset_unique_id() + tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id) + pt_obj = self.pt_class.deserialize(data) + self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type()) diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py index 62e84a27c4..1405814f03 100644 --- a/source/tests/consistent/model/test_polar.py +++ b/source/tests/consistent/model/test_polar.py @@ -198,3 +198,15 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret[1].ravel(), ) raise ValueError(f"Unknown backend: {backend}") + + def test_atom_exclude_types(self): + if self.skip_pt: + self.skipTest("Unsupported backend") + if self.skip_tf: + self.skipTest("Unsupported backend") + _ret, data = self.get_reference_ret_serialization(self.RefBackend.PT) + data["atom_exclude_types"] = [1] + self.reset_unique_id() + tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id) + pt_obj = self.pt_class.deserialize(data) + self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type()) diff --git a/source/tests/pt/model/test_get_model.py b/source/tests/pt/model/test_get_model.py index e323c95ce0..2db6370059 100644 --- a/source/tests/pt/model/test_get_model.py +++ b/source/tests/pt/model/test_get_model.py @@ -60,6 +60,12 @@ def test_model_attr(self) -> None: ] }, ) + full_type_list = np.arange(len(atomic_model.type_map), dtype=int) + atom_exclude_types = np.setdiff1d( + full_type_list, + self.model.get_sel_type(), + ).tolist() + self.assertEqual(atom_exclude_types, [1]) self.assertEqual(atomic_model.atom_exclude_types, [1]) self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]]) diff --git a/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py b/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py index 7b579ae82c..7aa94e33ee 100644 --- a/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py +++ b/source/tests/universal/dpmodel/atomc_model/test_atomic_model.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest +import numpy as np + from deepmd.dpmodel.atomic_model import ( DPAtomicModel, DPZBLLinearEnergyAtomicModel, @@ -72,6 +74,13 @@ ) +def make_sel_type_from_atom_exclude_types(type_map, atom_exclude_types): + """Get sel_type from complement of atom_exclude_types.""" + full_type_list = np.arange(len(type_map), dtype=int) + sel_type = np.setdiff1d(full_type_list, atom_exclude_types, assume_unique=True) + return sel_type.tolist() + + @parameterized( des_parameterized=( ( @@ -85,6 +94,7 @@ (DescriptorParamHybridMixedTTebd, DescrptHybrid), ), # descrpt_class_param & class ((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), fit_parameterized=( ( @@ -97,6 +107,7 @@ ( *[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList], ), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), ) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") @@ -128,9 +139,7 @@ def setUpClass(cls) -> None: **cls.input_dict_ft, ) cls.module = DPAtomicModel( - ds, - ft, - type_map=cls.expected_type_map, + ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2] ) cls.output_def = cls.module.atomic_output_def().get_data() cls.expected_has_message_passing = ds.has_message_passing() @@ -138,6 +147,14 @@ def setUpClass(cls) -> None: cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + def test_sel_type_from_atom_exclude_types(self): + self.assertEqual( + make_sel_type_from_atom_exclude_types( + self.expected_type_map, self.param[2] + ), + self.expected_sel_type, + ) + @parameterized( des_parameterized=( @@ -152,6 +169,7 @@ def setUpClass(cls) -> None: (DescriptorParamHybridMixedTTebd, DescrptHybrid), ), # descrpt_class_param & class ((FittingParamDos, DOSFittingNet),), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), fit_parameterized=( ( @@ -164,6 +182,7 @@ def setUpClass(cls) -> None: ( *[(param_func, DOSFittingNet) for param_func in FittingParamDosList], ), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), ) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") @@ -195,9 +214,7 @@ def setUpClass(cls) -> None: **cls.input_dict_ft, ) cls.module = DPAtomicModel( - ds, - ft, - type_map=cls.expected_type_map, + ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2] ) cls.output_def = cls.module.atomic_output_def().get_data() cls.expected_has_message_passing = ds.has_message_passing() @@ -205,6 +222,14 @@ def setUpClass(cls) -> None: cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + def test_sel_type_from_atom_exclude_types(self): + self.assertEqual( + make_sel_type_from_atom_exclude_types( + self.expected_type_map, self.param[2] + ), + self.expected_sel_type, + ) + @parameterized( des_parameterized=( @@ -216,6 +241,7 @@ def setUpClass(cls) -> None: (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class ((FittingParamDipole, DipoleFitting),), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), fit_parameterized=( ( @@ -226,6 +252,7 @@ def setUpClass(cls) -> None: ( *[(param_func, DipoleFitting) for param_func in FittingParamDipoleList], ), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), ) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") @@ -258,9 +285,7 @@ def setUpClass(cls) -> None: **cls.input_dict_ft, ) cls.module = DPAtomicModel( - ds, - ft, - type_map=cls.expected_type_map, + ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2] ) cls.output_def = cls.module.atomic_output_def().get_data() cls.expected_has_message_passing = ds.has_message_passing() @@ -268,6 +293,14 @@ def setUpClass(cls) -> None: cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + def test_sel_type_from_atom_exclude_types(self): + self.assertEqual( + make_sel_type_from_atom_exclude_types( + self.expected_type_map, self.param[2] + ), + self.expected_sel_type, + ) + @parameterized( des_parameterized=( @@ -279,6 +312,7 @@ def setUpClass(cls) -> None: (DescriptorParamHybridMixed, DescrptHybrid), ), # descrpt_class_param & class ((FittingParamPolar, PolarFitting),), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), fit_parameterized=( ( @@ -289,6 +323,7 @@ def setUpClass(cls) -> None: ( *[(param_func, PolarFitting) for param_func in FittingParamPolarList], ), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), ) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") @@ -321,9 +356,7 @@ def setUpClass(cls) -> None: **cls.input_dict_ft, ) cls.module = DPAtomicModel( - ds, - ft, - type_map=cls.expected_type_map, + ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2] ) cls.output_def = cls.module.atomic_output_def().get_data() cls.expected_has_message_passing = ds.has_message_passing() @@ -331,6 +364,14 @@ def setUpClass(cls) -> None: cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + def test_sel_type_from_atom_exclude_types(self): + self.assertEqual( + make_sel_type_from_atom_exclude_types( + self.expected_type_map, self.param[2] + ), + self.expected_sel_type, + ) + @parameterized( des_parameterized=( @@ -415,6 +456,7 @@ def setUpClass(cls) -> None: (DescriptorParamHybridMixedTTebd, DescrptHybrid), ), # descrpt_class_param & class ((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), fit_parameterized=( ( @@ -428,6 +470,7 @@ def setUpClass(cls) -> None: for param_func in FittingParamPropertyList ], ), # fitting_class_param & class + ([], [0]), # atom_exclude_types ), ) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") @@ -460,12 +503,18 @@ def setUpClass(cls) -> None: **cls.input_dict_ft, ) cls.module = DPAtomicModel( - ds, - ft, - type_map=cls.expected_type_map, + ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2] ) cls.output_def = cls.module.atomic_output_def().get_data() cls.expected_has_message_passing = ds.has_message_passing() cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + + def test_sel_type_from_atom_exclude_types(self): + self.assertEqual( + make_sel_type_from_atom_exclude_types( + self.expected_type_map, self.param[2] + ), + self.expected_sel_type, + )