From fbf69c2bd0b0c767b59d99d5adf7151d945a9b7e Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 17 Dec 2025 19:11:04 +0800 Subject: [PATCH 01/16] add data modifier for pt backend (training) --- deepmd/pt/entrypoints/main.py | 23 +++++++++++ deepmd/pt/modifier/__init__.py | 8 ++++ deepmd/pt/modifier/base_modifier.py | 60 +++++++++++++++++++++++++++++ deepmd/pt/train/training.py | 6 +++ deepmd/pt/utils/dataloader.py | 13 +++++++ deepmd/pt/utils/dataset.py | 22 ++++++++++- deepmd/utils/data.py | 49 +++++++++++++++++++++++ 7 files changed, 179 insertions(+), 2 deletions(-) create mode 100644 deepmd/pt/modifier/__init__.py create mode 100644 deepmd/pt/modifier/base_modifier.py diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index d349d519ba..08346b8027 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -47,6 +47,9 @@ from deepmd.pt.model.model import ( BaseModel, ) +from deepmd.pt.modifier import ( + BaseModifier, +) from deepmd.pt.train import ( training, ) @@ -105,12 +108,30 @@ def get_trainer( ) -> training.Trainer: multi_task = "model_dict" in config.get("model", {}) + def get_data_modifier(_modifier_params: dict[str, Any]) -> BaseModifier: + modifier_params = copy.deepcopy(_modifier_params) + try: + modifier_type = modifier_params.pop("type") + except KeyError: + raise ValueError("Data modifier type not specified!") from None + return ( + BaseModifier.get_class_by_type(modifier_type) + .get_modifier(modifier_params) + .to(DEVICE) + ) + def prepare_trainer_input_single( model_params_single: dict[str, Any], data_dict_single: dict[str, Any], rank: int = 0, seed: int | None = None, ) -> tuple[DpLoaderSet, DpLoaderSet | None, DPPath | None]: + # get data modifier + modifier = None + modifier_params = model_params_single.get("modifier", None) + if modifier_params is not None: + modifier = get_data_modifier(modifier_params) + training_dataset_params = data_dict_single["training_data"] validation_dataset_params = data_dict_single.get("validation_data", None) validation_systems = ( @@ -145,6 +166,7 @@ def prepare_trainer_input_single( validation_dataset_params["batch_size"], model_params_single["type_map"], seed=rank_seed, + modifier=modifier, ) if validation_systems else None @@ -154,6 +176,7 @@ def prepare_trainer_input_single( training_dataset_params["batch_size"], model_params_single["type_map"], seed=rank_seed, + modifier=modifier, ) return ( train_data_single, diff --git a/deepmd/pt/modifier/__init__.py b/deepmd/pt/modifier/__init__.py new file mode 100644 index 0000000000..bfa1540ce9 --- /dev/null +++ b/deepmd/pt/modifier/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .base_modifier import ( + BaseModifier, +) + +__all__ = [ + "BaseModifier", +] diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py new file mode 100644 index 0000000000..02c393ca86 --- /dev/null +++ b/deepmd/pt/modifier/base_modifier.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd.dpmodel.modifier.base_modifier import ( + make_base_modifier, +) +from deepmd.utils.data import ( + DeepmdData, +) + + +class BaseModifier(torch.nn.Module, make_base_modifier()): + def __init__(self) -> None: + """Construct a basic model for different tasks.""" + torch.nn.Module.__init__(self) + + def modify_data(self, data: dict, data_sys: DeepmdData) -> None: + """Modify data. + + Parameters + ---------- + data + Internal data of DeepmdData. + Be a dict, has the following keys + - coord coordinates + - box simulation box + - atype atom types + - find_energy tells if data has energy + - find_force tells if data has force + - find_virial tells if data has virial + - energy energy + - force force + - virial virial + """ + if ( + "find_energy" not in data + and "find_force" not in data + and "find_virial" not in data + ): + return + + get_nframes = None + coord = data["coord"][:get_nframes, :] + if data["box"] is None: + box = None + else: + box = data["box"][:get_nframes, :] + atype = data["atype"][:get_nframes, :] + atype = atype[0] + # nframes = coord.shape[0] + + # implement data modification method in forwrad + tot_e, tot_f, tot_v = self.forward(coord, atype, box, False, None, None) + + if "find_energy" in data and data["find_energy"] == 1.0: + data["energy"] -= tot_e.reshape(data["energy"].shape) + if "find_force" in data and data["find_force"] == 1.0: + data["force"] -= tot_f.reshape(data["force"].shape) + if "find_virial" in data and data["find_virial"] == 1.0: + data["virial"] -= tot_v.reshape(data["virial"].shape) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 24440e19de..9e1025260a 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -350,6 +350,9 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: if self.finetune_links is not None else False, ) + training_data.preload_and_modify_all_data() + if validation_data is not None: + validation_data.preload_and_modify_all_data() ( self.training_dataloader, self.training_data, @@ -388,6 +391,9 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: if self.finetune_links is not None else False, ) + training_data[model_key].preload_and_modify_all_data() + if validation_data[model_key] is not None: + validation_data[model_key].preload_and_modify_all_data() ( self.training_dataloader[model_key], self.training_data[model_key], diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index c991e59daa..b9a061c29c 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -25,6 +25,9 @@ DistributedSampler, ) +from deepmd.pt.modifier import ( + BaseModifier, +) from deepmd.pt.utils import ( dp_random, env, @@ -83,6 +86,7 @@ def __init__( type_map: list[str] | None, seed: int | None = None, shuffle: bool = True, + modifier: BaseModifier | None = None, ) -> None: if seed is not None: setup_seed(seed) @@ -94,6 +98,7 @@ def construct_dataset(system: str) -> DeepmdDataSetForLoader: return DeepmdDataSetForLoader( system=system, type_map=type_map, + modifier=modifier, ) self.systems: list[DeepmdDataSetForLoader] = [] @@ -233,6 +238,14 @@ def print_summary( [ss._data_system.pbc for ss in self.systems], ) + def preload_and_modify_all_data(self) -> None: + for system in self.systems: + system.preload_and_modify_all_data() + + # def clear_modified_frame_cache(self) -> None: + # for system in self.systems: + # system.clear_modified_frame_cache() + def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: example = batch[0] diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 481fa04497..995245e462 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -9,6 +9,9 @@ Dataset, ) +from deepmd.pt.modifier import ( + BaseModifier, +) from deepmd.utils.data import ( DataRequirementItem, DeepmdData, @@ -16,16 +19,25 @@ class DeepmdDataSetForLoader(Dataset): - def __init__(self, system: str, type_map: list[str] | None = None) -> None: + def __init__( + self, + system: str, + type_map: list[str] | None = None, + modifier: BaseModifier | None = None, + ) -> None: """Construct DeePMD-style dataset containing frames cross different systems. Args: - systems: Paths to systems. - type_map: Atom types. + - modifier: Data modifier. """ self.system = system self._type_map = type_map - self._data_system = DeepmdData(sys_path=system, type_map=self._type_map) + self.modifier = modifier + self._data_system = DeepmdData( + sys_path=system, type_map=self._type_map, modifier=self.modifier + ) self.mixed_type = self._data_system.mixed_type self._ntypes = self._data_system.get_ntypes() self._natoms = self._data_system.get_natoms() @@ -55,3 +67,9 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N dtype=data_item["dtype"], output_natoms_for_type_sel=data_item["output_natoms_for_type_sel"], ) + + def preload_and_modify_all_data(self) -> None: + self._data_system.preload_and_modify_all_data() + + # def clear_modified_frame_cache(self) -> None: + # self._data_system.clear_modified_frame_cache() diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 287107a7ff..1b22b3d7cf 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -139,6 +139,13 @@ def __init__( # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method self.prefix_sum = np.cumsum(frames_list).tolist() + self.apply_modifier_at_load = True + if self.modifier is not None: + if hasattr(self.modifier, "apply_modifier_at_load"): + self.apply_modifier_at_load = self.modifier.apply_modifier_at_load + # Cache for modified frames when apply_modifier_at_load is True + self._modified_frame_cache = {} + def add( self, key: str, @@ -377,6 +384,14 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray: def get_single_frame(self, index: int) -> dict: """Orchestrates loading a single frame efficiently using memmap.""" + # Check if we have a cached modified frame and apply_modifier_at_load is True + if ( + self.apply_modifier_at_load + and self.modifier is not None + and index in self._modified_frame_cache + ): + return self._modified_frame_cache[index] + if index < 0 or index >= self.nframes: raise IndexError(f"Frame index {index} out of range [0, {self.nframes})") # 1. Find the correct set directory and local frame index @@ -470,8 +485,42 @@ def get_single_frame(self, index: int) -> dict: frame_data["box"] = None frame_data["fid"] = index + + if self.modifier is not None: + # Apply modifier if it exists + self.modifier.modify_data(frame_data, self) + if self.apply_modifier_at_load: + # Cache the modified frame to avoid recomputation + self._modified_frame_cache[index] = frame_data.copy() + return frame_data + def preload_and_modify_all_data(self) -> None: + """Preload all frames and apply modifier to cache them. + + This method is useful when apply_modifier_at_load is True and you want to + avoid applying the modifier repeatedly during training. + """ + if not self.apply_modifier_at_load or self.modifier is None: + return + + log.info("Preloading and modifying all data frames...") + for i in range(self.nframes): + if i not in self._modified_frame_cache: + self.get_single_frame(i) + if (i + 1) % 100 == 0: + log.info(f"Processed {i + 1}/{self.nframes} frames") + log.info("All frames preloaded and modified.") + + # def clear_modified_frame_cache(self) -> None: + # """Clear the modified frame cache. + + # This method is useful when you want to free up memory or force + # recomputation of modified frames. + # """ + # self._modified_frame_cache.clear() + # log.info("Modified frame cache cleared.") + def avg(self, key: str) -> float: """Return the average value of an item.""" if key not in self.data_dict.keys(): From e75e35d492bab194c3f2c134ebaa8eb7ebf8e0fc Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 17 Dec 2025 19:12:27 +0800 Subject: [PATCH 02/16] add UT for data modifier in pt model training --- source/tests/pt/test_data_modifier.py | 232 ++++++++++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 source/tests/pt/test_data_modifier.py diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py new file mode 100644 index 0000000000..4fc4d1825b --- /dev/null +++ b/source/tests/pt/test_data_modifier.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import unittest +from pathlib import ( + Path, +) + +import numpy as np + +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.modifier.base_modifier import ( + BaseModifier, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.argcheck import ( + modifier_args_plugin, +) +from deepmd.utils.data import ( + DeepmdData, +) + + +@modifier_args_plugin.register("random_tester") +def modifier_random_tester() -> list: + return [] + + +@modifier_args_plugin.register("zero_tester") +def modifier_zero_tester() -> list: + return [] + + +@BaseModifier.register("random_tester") +class ModifierRandomTester(BaseModifier): + def __new__(cls) -> BaseModifier: + return super().__new__(cls) + + def __init__(self) -> None: + """Construct a basic model for different tasks.""" + super().__init__() + self.modifier_type = "tester" + + def serialize(self) -> dict: + """Serialize the modifier. + + Returns + ------- + dict + The serialized data + """ + data = { + "@class": "Modifier", + "type": self.modifier_type, + "@version": 3, + } + return data + + @classmethod + def deserialize(cls, data: dict) -> "BaseModifier": + """Deserialize the modifier. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModifier + The deserialized modifier + """ + data = data.copy() + modifier = cls(**data) + return modifier + + def modify_data(self, data: dict, data_sys: DeepmdData) -> None: + """Multiply by a random factor.""" + if ( + "find_energy" not in data + and "find_force" not in data + and "find_virial" not in data + ): + return + + if "find_energy" in data and data["find_energy"] == 1.0: + data["energy"] = data["energy"] * np.random.Generator() + if "find_force" in data and data["find_force"] == 1.0: + data["force"] = data["force"] * np.random.Generator() + if "find_virial" in data and data["find_virial"] == 1.0: + data["virial"] = data["virial"] * np.random.Generator() + + +@BaseModifier.register("zero_tester") +class ModifierZeroTester(BaseModifier): + def __new__(cls) -> BaseModifier: + return super().__new__(cls) + + def __init__(self) -> None: + """Construct a basic model for different tasks.""" + super().__init__() + self.modifier_type = "tester" + + def serialize(self) -> dict: + """Serialize the modifier. + + Returns + ------- + dict + The serialized data + """ + data = { + "@class": "Modifier", + "type": self.modifier_type, + "@version": 3, + } + return data + + @classmethod + def deserialize(cls, data: dict) -> "BaseModifier": + """Deserialize the modifier. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModifier + The deserialized modifier + """ + data = data.copy() + modifier = cls(**data) + return modifier + + def modify_data(self, data: dict, data_sys: DeepmdData) -> None: + """Multiply by a random factor.""" + if ( + "find_energy" not in data + and "find_force" not in data + and "find_virial" not in data + ): + return + + if "find_energy" in data and data["find_energy"] == 1.0: + data["energy"] -= data["energy"] + if "find_force" in data and data["find_force"] == 1.0: + data["force"] -= data["force"] + if "find_virial" in data and data["find_virial"] == 1.0: + data["virial"] -= data["virial"] + + +class TestDataModifier(unittest.TestCase): + def setUp(self) -> None: + """Set up test fixtures.""" + input_json = str(Path(__file__).parent / "water/se_e2_a.json") + with open(input_json, encoding="utf-8") as f: + config = json.load(f) + config["training"]["numb_steps"] = 10 + config["training"]["save_freq"] = 1 + config["learning_rate"]["start_lr"] = 1.0 + config["training"]["training_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water/data/single") + ] + self.config = config + + def test_init_modify_data(self): + """Ensure modify_data applied.""" + tmp_config = self.config.copy() + # add tester data modifier + tmp_config["model"]["modifier"] = {"type": "zero_tester"} + + # data modification is finished in __init__ + trainer = get_trainer(tmp_config) + + # training data + training_data = trainer.get_data(is_train=True) + # validation data + validation_data = trainer.get_data(is_train=False) + + for dataset in [training_data, validation_data]: + for kw in ["energy", "force"]: + data = to_numpy_array(dataset[1][kw]) + np.testing.assert_allclose(data, np.zeros_like(data)) + + def test_full_modify_data(self): + """Ensure modify_data only applied once.""" + tmp_config = self.config.copy() + # add tester data modifier + tmp_config["model"]["modifier"] = {"type": "random_tester"} + + # data modification is finished in __init__ + trainer = get_trainer(tmp_config) + + # training data + training_data_before = trainer.get_data(is_train=True) + # validation data + validation_data_before = trainer.get_data(is_train=False) + + trainer.run() + + # training data + training_data_after = trainer.get_data(is_train=True) + # validation data + validation_data_after = trainer.get_data(is_train=False) + + for kw in ["energy", "force"]: + np.testing.assert_allclose( + to_numpy_array(training_data_before[1][kw]), + to_numpy_array(training_data_after[1][kw]), + ) + np.testing.assert_allclose( + to_numpy_array(validation_data_before[1][kw]), + to_numpy_array(validation_data_after[1][kw]), + ) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("frozen_model") and f.endswith(".pth"): + os.remove(f) + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "checkpoint"]: + os.remove(f) From 89ba1142866008c79be5b342334bd68d9eda3e31 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 17 Dec 2025 20:50:37 +0800 Subject: [PATCH 03/16] fix minor bugs --- deepmd/pt/modifier/base_modifier.py | 5 ++--- deepmd/pt/utils/dataloader.py | 4 ---- deepmd/pt/utils/dataset.py | 3 --- source/tests/pt/test_data_modifier.py | 6 +++--- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index 02c393ca86..e0ac1d52c1 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -15,6 +15,7 @@ def __init__(self) -> None: torch.nn.Module.__init__(self) def modify_data(self, data: dict, data_sys: DeepmdData) -> None: + # TODO: data_sys parameter is currently unused but may be needed by subclasses in the future """Modify data. Parameters @@ -46,10 +47,8 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None: else: box = data["box"][:get_nframes, :] atype = data["atype"][:get_nframes, :] - atype = atype[0] - # nframes = coord.shape[0] - # implement data modification method in forwrad + # implement data modification method in forward tot_e, tot_f, tot_v = self.forward(coord, atype, box, False, None, None) if "find_energy" in data and data["find_energy"] == 1.0: diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index b9a061c29c..d74d6c5c00 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -242,10 +242,6 @@ def preload_and_modify_all_data(self) -> None: for system in self.systems: system.preload_and_modify_all_data() - # def clear_modified_frame_cache(self) -> None: - # for system in self.systems: - # system.clear_modified_frame_cache() - def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: example = batch[0] diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 995245e462..f99de70f55 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -70,6 +70,3 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N def preload_and_modify_all_data(self) -> None: self._data_system.preload_and_modify_all_data() - - # def clear_modified_frame_cache(self) -> None: - # self._data_system.clear_modified_frame_cache() diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index 4fc4d1825b..98f60580ad 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -88,11 +88,11 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None: return if "find_energy" in data and data["find_energy"] == 1.0: - data["energy"] = data["energy"] * np.random.Generator() + data["energy"] = data["energy"] * np.random.default_rng().random() if "find_force" in data and data["find_force"] == 1.0: - data["force"] = data["force"] * np.random.Generator() + data["force"] = data["force"] * np.random.default_rng().random() if "find_virial" in data and data["find_virial"] == 1.0: - data["virial"] = data["virial"] * np.random.Generator() + data["virial"] = data["virial"] * np.random.default_rng().random() @BaseModifier.register("zero_tester") From 3bdb9f59beb3b38425f7b9a281fe3472b117015d Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 17 Dec 2025 23:50:53 +0800 Subject: [PATCH 04/16] fix bug and add type annotation for BaseModifier --- deepmd/pt/modifier/base_modifier.py | 89 +++++++++++++++++++--- deepmd/utils/data.py | 9 --- source/tests/pt/test_data_modifier.py | 105 ++++++++------------------ 3 files changed, 108 insertions(+), 95 deletions(-) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index e0ac1d52c1..d8b465a661 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -1,9 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + abstractmethod, +) + import torch from deepmd.dpmodel.modifier.base_modifier import ( make_base_modifier, ) +from deepmd.dpmodel.array_api import ( + Array, +) +from deepmd.pt.utils.utils import to_torch_tensor, to_numpy_array from deepmd.utils.data import ( DeepmdData, ) @@ -13,9 +21,55 @@ class BaseModifier(torch.nn.Module, make_base_modifier()): def __init__(self) -> None: """Construct a basic model for different tasks.""" torch.nn.Module.__init__(self) + self.modifier_type = "base" + + def serialize(self) -> dict: + """Serialize the modifier. + + Returns + ------- + dict + The serialized data + """ + data = { + "@class": "Modifier", + "type": self.modifier_type, + "@version": 3, + } + return data + + @classmethod + def deserialize(cls, data: dict) -> "BaseModifier": + """Deserialize the modifier. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModifier + The deserialized modifier + """ + data = data.copy() + modifier = cls(**data) + return modifier - def modify_data(self, data: dict, data_sys: DeepmdData) -> None: - # TODO: data_sys parameter is currently unused but may be needed by subclasses in the future + @abstractmethod + @torch.jit.export + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Compute energy, force, and virial corrections.""" + + @torch.jit.unused + def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: """Modify data. Parameters @@ -25,7 +79,9 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None: Be a dict, has the following keys - coord coordinates - box simulation box - - atype atom types + - atype atom types + - fparam frame parameter + - aparam atom parameter - find_energy tells if data has energy - find_force tells if data has force - find_virial tells if data has virial @@ -41,19 +97,28 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None: return get_nframes = None - coord = data["coord"][:get_nframes, :] + t_coord = to_torch_tensor(data["coord"][:get_nframes, :]) + t_atype = to_torch_tensor(data["atype"][:get_nframes, :]) if data["box"] is None: - box = None + t_box = None else: - box = data["box"][:get_nframes, :] - atype = data["atype"][:get_nframes, :] - + t_box = to_torch_tensor(data["box"][:get_nframes, :]) + if data["fparam"] is None: + t_fparam = None + else: + t_fparam = to_torch_tensor(data["fparam"][:get_nframes, :]) + if data["aparam"] is None: + t_aparam = None + else: + t_aparam = to_torch_tensor(data["aparam"][:get_nframes, :]) + # + # implement data modification method in forward - tot_e, tot_f, tot_v = self.forward(coord, atype, box, False, None, None) + modifier_data = self.forward(t_coord, t_atype, t_box, t_fparam, t_aparam) if "find_energy" in data and data["find_energy"] == 1.0: - data["energy"] -= tot_e.reshape(data["energy"].shape) + data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape(data["energy"].shape) if "find_force" in data and data["find_force"] == 1.0: - data["force"] -= tot_f.reshape(data["force"].shape) + data["force"] -= to_numpy_array(modifier_data["force"]).reshape(data["force"].shape) if "find_virial" in data and data["find_virial"] == 1.0: - data["virial"] -= tot_v.reshape(data["virial"].shape) + data["virial"] -= to_numpy_array(modifier_data["virial"]).reshape(data["virial"].shape) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 1b22b3d7cf..7c56b131c2 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -512,15 +512,6 @@ def preload_and_modify_all_data(self) -> None: log.info(f"Processed {i + 1}/{self.nframes} frames") log.info("All frames preloaded and modified.") - # def clear_modified_frame_cache(self) -> None: - # """Clear the modified frame cache. - - # This method is useful when you want to free up memory or force - # recomputation of modified frames. - # """ - # self._modified_frame_cache.clear() - # log.info("Modified frame cache cleared.") - def avg(self, key: str) -> float: """Return the average value of an item.""" if key not in self.data_dict.keys(): diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index 98f60580ad..9ebdb9aacd 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -7,6 +7,7 @@ ) import numpy as np +import torch from deepmd.pt.entrypoints.main import ( get_trainer, @@ -23,7 +24,9 @@ from deepmd.utils.data import ( DeepmdData, ) - +from deepmd.dpmodel.array_api import ( + Array, +) @modifier_args_plugin.register("random_tester") def modifier_random_tester() -> list: @@ -43,42 +46,19 @@ def __new__(cls) -> BaseModifier: def __init__(self) -> None: """Construct a basic model for different tasks.""" super().__init__() - self.modifier_type = "tester" - - def serialize(self) -> dict: - """Serialize the modifier. - - Returns - ------- - dict - The serialized data - """ - data = { - "@class": "Modifier", - "type": self.modifier_type, - "@version": 3, - } - return data - - @classmethod - def deserialize(cls, data: dict) -> "BaseModifier": - """Deserialize the modifier. - - Parameters - ---------- - data : dict - The serialized data - - Returns - ------- - BaseModifier - The deserialized modifier - """ - data = data.copy() - modifier = cls(**data) - return modifier - - def modify_data(self, data: dict, data_sys: DeepmdData) -> None: + self.modifier_type = "random_tester" + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return {"coord": coord} + + def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: """Multiply by a random factor.""" if ( "find_energy" not in data @@ -86,7 +66,7 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None: and "find_virial" not in data ): return - + if "find_energy" in data and data["find_energy"] == 1.0: data["energy"] = data["energy"] * np.random.default_rng().random() if "find_force" in data and data["find_force"] == 1.0: @@ -103,42 +83,19 @@ def __new__(cls) -> BaseModifier: def __init__(self) -> None: """Construct a basic model for different tasks.""" super().__init__() - self.modifier_type = "tester" - - def serialize(self) -> dict: - """Serialize the modifier. - - Returns - ------- - dict - The serialized data - """ - data = { - "@class": "Modifier", - "type": self.modifier_type, - "@version": 3, - } - return data - - @classmethod - def deserialize(cls, data: dict) -> "BaseModifier": - """Deserialize the modifier. - - Parameters - ---------- - data : dict - The serialized data - - Returns - ------- - BaseModifier - The deserialized modifier - """ - data = data.copy() - modifier = cls(**data) - return modifier - - def modify_data(self, data: dict, data_sys: DeepmdData) -> None: + self.modifier_type = "zero_tester" + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return {"coord": coord} + + def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: """Multiply by a random factor.""" if ( "find_energy" not in data From b093750975b0bade4c5dec0081557413aedc4891 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 17 Dec 2025 23:55:32 +0800 Subject: [PATCH 05/16] minor revision based on coderabbit --- deepmd/pt/modifier/base_modifier.py | 25 +++++++++++++++++-------- deepmd/utils/data.py | 3 ++- source/tests/pt/test_data_modifier.py | 9 +++++---- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index d8b465a661..b093956161 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -5,13 +5,16 @@ import torch +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.modifier.base_modifier import ( make_base_modifier, ) -from deepmd.dpmodel.array_api import ( - Array, +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, ) -from deepmd.pt.utils.utils import to_torch_tensor, to_numpy_array from deepmd.utils.data import ( DeepmdData, ) @@ -111,14 +114,20 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N t_aparam = None else: t_aparam = to_torch_tensor(data["aparam"][:get_nframes, :]) - # - + # + # implement data modification method in forward modifier_data = self.forward(t_coord, t_atype, t_box, t_fparam, t_aparam) if "find_energy" in data and data["find_energy"] == 1.0: - data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape(data["energy"].shape) + data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape( + data["energy"].shape + ) if "find_force" in data and data["find_force"] == 1.0: - data["force"] -= to_numpy_array(modifier_data["force"]).reshape(data["force"].shape) + data["force"] -= to_numpy_array(modifier_data["force"]).reshape( + data["force"].shape + ) if "find_virial" in data and data["find_virial"] == 1.0: - data["virial"] -= to_numpy_array(modifier_data["virial"]).reshape(data["virial"].shape) + data["virial"] -= to_numpy_array(modifier_data["virial"]).reshape( + data["virial"].shape + ) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 7c56b131c2..f5a03a400d 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import bisect +import copy import functools import logging from concurrent.futures import ( @@ -491,7 +492,7 @@ def get_single_frame(self, index: int) -> dict: self.modifier.modify_data(frame_data, self) if self.apply_modifier_at_load: # Cache the modified frame to avoid recomputation - self._modified_frame_cache[index] = frame_data.copy() + self._modified_frame_cache[index] = copy.deepcopy(frame_data) return frame_data diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index 9ebdb9aacd..20fc5a4036 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -9,6 +9,9 @@ import numpy as np import torch +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.pt.entrypoints.main import ( get_trainer, ) @@ -24,9 +27,7 @@ from deepmd.utils.data import ( DeepmdData, ) -from deepmd.dpmodel.array_api import ( - Array, -) + @modifier_args_plugin.register("random_tester") def modifier_random_tester() -> list: @@ -66,7 +67,7 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N and "find_virial" not in data ): return - + if "find_energy" in data and data["find_energy"] == 1.0: data["energy"] = data["energy"] * np.random.default_rng().random() if "find_force" in data and data["find_force"] == 1.0: From b60a1e69e820a0c07a3e4457c2e796c0179226ac Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 18 Dec 2025 11:57:26 +0800 Subject: [PATCH 06/16] fix bug in UT --- deepmd/pt/modifier/base_modifier.py | 2 +- source/tests/pt/test_data_modifier.py | 217 +++++++++++++++++++++++--- 2 files changed, 194 insertions(+), 25 deletions(-) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index b093956161..341bcbdf9c 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -101,7 +101,7 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N get_nframes = None t_coord = to_torch_tensor(data["coord"][:get_nframes, :]) - t_atype = to_torch_tensor(data["atype"][:get_nframes, :]) + t_atype = to_torch_tensor(data["atype"]) if data["box"] is None: t_box = None else: diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index 20fc5a4036..c837484480 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -1,4 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Test data modification functionality in DeepMD. + +This module tests the data modification capabilities of DeepMD, specifically +testing the BaseModifier implementations and their effects on training and +validation data. It includes: + +1. Test modifier implementations (random_tester and zero_tester) +2. Tests to verify data modification is applied correctly +3. Tests to ensure data modification is only applied once + +The tests use parameterized testing with different batch sizes for training +and validation data. +""" + import json import os import unittest @@ -28,26 +42,54 @@ DeepmdData, ) +from ..consistent.common import ( + parameterized, +) + @modifier_args_plugin.register("random_tester") def modifier_random_tester() -> list: + """Return empty argument list for random_tester modifier. + + This function registers the argument schema for the random_tester modifier. + + Returns + ------- + list: Empty list indicating no additional arguments required + """ return [] @modifier_args_plugin.register("zero_tester") def modifier_zero_tester() -> list: + """Return empty argument list for zero_tester modifier. + + This function registers the argument schema for the zero_tester modifier. + + Returns + ------- + list: Empty list indicating no additional arguments required + """ return [] @BaseModifier.register("random_tester") class ModifierRandomTester(BaseModifier): def __new__(cls) -> BaseModifier: + """Create a new instance of ModifierRandomTester. + + Returns + ------- + BaseModifier: New instance of the modifier + """ return super().__new__(cls) def __init__(self) -> None: """Construct a basic model for different tasks.""" super().__init__() self.modifier_type = "random_tester" + # Use a fixed seed for deterministic behavior + self.rng = np.random.default_rng(12345) # Fixed seed for reproducibility def forward( self, @@ -57,10 +99,11 @@ def forward( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: + """Implementation of abstractmethod.""" return {"coord": coord} def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: - """Multiply by a random factor.""" + """Multiply by a deterministic factor for testing.""" if ( "find_energy" not in data and "find_force" not in data @@ -69,16 +112,22 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N return if "find_energy" in data and data["find_energy"] == 1.0: - data["energy"] = data["energy"] * np.random.default_rng().random() + data["energy"] = data["energy"] * self.rng.random() if "find_force" in data and data["find_force"] == 1.0: - data["force"] = data["force"] * np.random.default_rng().random() + data["force"] = data["force"] * self.rng.random() if "find_virial" in data and data["find_virial"] == 1.0: - data["virial"] = data["virial"] * np.random.default_rng().random() + data["virial"] = data["virial"] * self.rng.random() @BaseModifier.register("zero_tester") class ModifierZeroTester(BaseModifier): def __new__(cls) -> BaseModifier: + """Create a new instance of ModifierZeroTester. + + Returns + ------- + BaseModifier: New instance of the modifier + """ return super().__new__(cls) def __init__(self) -> None: @@ -94,10 +143,11 @@ def forward( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: + """Implementation of abstractmethod.""" return {"coord": coord} def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: - """Multiply by a random factor.""" + """Zero out energy, force, and virial data.""" if ( "find_energy" not in data and "find_force" not in data @@ -113,23 +163,134 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N data["virial"] -= data["virial"] +@parameterized( + (1, 2), # training data batch_size + (1, 2), # validation data batch_size +) class TestDataModifier(unittest.TestCase): def setUp(self) -> None: """Set up test fixtures.""" input_json = str(Path(__file__).parent / "water/se_e2_a.json") + training_data = [ + str(Path(__file__).parent / "water/data/data_0"), + str(Path(__file__).parent / "water/data/data_1"), + ] + validation_data = [str(Path(__file__).parent / "water/data/data_1")] with open(input_json, encoding="utf-8") as f: config = json.load(f) - config["training"]["numb_steps"] = 10 + config["training"]["numb_steps"] = 1 config["training"]["save_freq"] = 1 config["learning_rate"]["start_lr"] = 1.0 - config["training"]["training_data"]["systems"] = [ - str(Path(__file__).parent / "water/data/single") - ] - config["training"]["validation_data"]["systems"] = [ - str(Path(__file__).parent / "water/data/single") - ] + config["training"]["training_data"]["systems"] = training_data + config["training"]["training_data"]["batch_size"] = self.param[0] + config["training"]["validation_data"]["systems"] = validation_data + config["training"]["validation_data"]["batch_size"] = self.param[1] self.config = config + self.training_nframes = self.get_dataset_nframes(training_data) + self.validation_nframes = self.get_dataset_nframes(validation_data) + + @staticmethod + def get_dataset_nframes(dataset: list[str]) -> int: + """Calculate total number of frames in a dataset. + + Args: + dataset: List of dataset paths + + Returns + ------- + int: Total number of frames across all datasets + """ + nframes = 0 + for _data in dataset: + _dpdata = DeepmdData(_data) + nframes += _dpdata.nframes + return nframes + + @staticmethod + def get_sampled_data(trainer, nbatch: int, is_train: bool): + """ + Collect all data from trainer and organize by IDs for easy comparison. + + Args: + trainer: The trainer object + nbatch: Number of batches to iterate through + is_train: Whether to get training data (True) or validation data (False) + + Returns + ------- + dict: A nested dictionary organized by system_id and frame_id + Format: {system_id: {frame_id: label_dict}} + """ + output = {} + # Keep track of all unique frames we've collected + collected_frames = set() + + # Continue collecting data until we've gone through all batches + for _ in range(nbatch): + input_dict, label_dict, log_dict = trainer.get_data(is_train=is_train) + + system_id = log_dict["sid"] + frame_ids = log_dict["fid"] + + # Initialize system entry if not exists + if system_id not in output: + output[system_id] = {} + + # Store label data for each frame ID + for idx, frame_id in enumerate(frame_ids): + # Skip if we already have this frame + frame_key = (system_id, frame_id) + if frame_key in collected_frames: + continue + + # Create a copy of label_dict for this specific frame + frame_data = {} + for key, value in label_dict.items(): + # If value is a tensor/array with batch dimension, extract the specific frame + if ( + hasattr(value, "shape") + and len(value.shape) > 0 + and value.shape[0] == len(frame_ids) + ): + # Handle batched data - extract the data for this frame + frame_data[key] = value[idx] + else: + # For scalar values or non-batched data, just copy as is + frame_data[key] = value + + output[system_id][frame_id] = frame_data + collected_frames.add(frame_key) + + return output + + @staticmethod + def check_sampled_data( + ref_data: dict[int, dict], test_data: dict[int, dict], label_kw: str + ): + """Compare sampled data between reference and test datasets. + + Args: + ref_data: Reference data dictionary organized by system and frame IDs + test_data: Test data dictionary organized by system and frame IDs + label_kw: Key of the label to compare (e.g., "energy", "force") + + Raises + ------ + AssertionError: If the data doesn't match between reference and test + """ + for sid in ref_data.keys(): + ref_sys = ref_data[sid] + test_sys = test_data[sid] + for fid in ref_sys.keys(): + # compare common elements + try: + ref_label = to_numpy_array(ref_sys[fid][label_kw]) + test_label = to_numpy_array(test_sys[fid][label_kw]) + except KeyError: + continue + np.testing.assert_allclose(ref_label, test_label) + def test_init_modify_data(self): """Ensure modify_data applied.""" tmp_config = self.config.copy() @@ -159,28 +320,36 @@ def test_full_modify_data(self): trainer = get_trainer(tmp_config) # training data - training_data_before = trainer.get_data(is_train=True) + training_data_before = self.get_sampled_data( + trainer, self.training_nframes, True + ) # validation data - validation_data_before = trainer.get_data(is_train=False) + validation_data_before = self.get_sampled_data( + trainer, self.validation_nframes, False + ) trainer.run() # training data - training_data_after = trainer.get_data(is_train=True) + training_data_after = self.get_sampled_data( + trainer, self.training_nframes, True + ) # validation data - validation_data_after = trainer.get_data(is_train=False) + validation_data_after = self.get_sampled_data( + trainer, self.validation_nframes, False + ) - for kw in ["energy", "force"]: - np.testing.assert_allclose( - to_numpy_array(training_data_before[1][kw]), - to_numpy_array(training_data_after[1][kw]), - ) - np.testing.assert_allclose( - to_numpy_array(validation_data_before[1][kw]), - to_numpy_array(validation_data_after[1][kw]), + for label_kw in ["energy", "force"]: + self.check_sampled_data(training_data_before, training_data_after, label_kw) + self.check_sampled_data( + validation_data_before, validation_data_after, label_kw ) def tearDown(self) -> None: + """Clean up test artifacts after each test. + + Removes model files and other artifacts created during testing. + """ for f in os.listdir("."): if f.startswith("frozen_model") and f.endswith(".pth"): os.remove(f) From ddb689bcdfeab67d79b4514f3f8ebb775bd282cc Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 18 Dec 2025 18:25:14 +0800 Subject: [PATCH 07/16] fix bug in BaseModifier.modify_data --- deepmd/pt/entrypoints/main.py | 16 +---- deepmd/pt/modifier/__init__.py | 15 +++++ deepmd/pt/modifier/base_modifier.py | 84 +++++++++++++++++++-------- source/tests/pt/test_data_modifier.py | 12 ++-- 4 files changed, 85 insertions(+), 42 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 08346b8027..138ef8daa0 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -48,7 +48,7 @@ BaseModel, ) from deepmd.pt.modifier import ( - BaseModifier, + get_data_modifier, ) from deepmd.pt.train import ( training, @@ -108,18 +108,6 @@ def get_trainer( ) -> training.Trainer: multi_task = "model_dict" in config.get("model", {}) - def get_data_modifier(_modifier_params: dict[str, Any]) -> BaseModifier: - modifier_params = copy.deepcopy(_modifier_params) - try: - modifier_type = modifier_params.pop("type") - except KeyError: - raise ValueError("Data modifier type not specified!") from None - return ( - BaseModifier.get_class_by_type(modifier_type) - .get_modifier(modifier_params) - .to(DEVICE) - ) - def prepare_trainer_input_single( model_params_single: dict[str, Any], data_dict_single: dict[str, Any], @@ -130,7 +118,7 @@ def prepare_trainer_input_single( modifier = None modifier_params = model_params_single.get("modifier", None) if modifier_params is not None: - modifier = get_data_modifier(modifier_params) + modifier = get_data_modifier(modifier_params).to(DEVICE) training_dataset_params = data_dict_single["training_data"] validation_dataset_params = data_dict_single.get("validation_data", None) diff --git a/deepmd/pt/modifier/__init__.py b/deepmd/pt/modifier/__init__.py index bfa1540ce9..71d847bcbc 100644 --- a/deepmd/pt/modifier/__init__.py +++ b/deepmd/pt/modifier/__init__.py @@ -1,8 +1,23 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy +from typing import ( + Any, +) + from .base_modifier import ( BaseModifier, ) __all__ = [ "BaseModifier", + "get_data_modifier", ] + + +def get_data_modifier(_modifier_params: dict[str, Any]) -> BaseModifier: + modifier_params = copy.deepcopy(_modifier_params) + try: + modifier_type = modifier_params.pop("type") + except KeyError: + raise ValueError("Data modifier type not specified!") from None + return BaseModifier.get_class_by_type(modifier_type).get_modifier(modifier_params) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index 341bcbdf9c..e57f6d934b 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -3,14 +3,21 @@ abstractmethod, ) +import numpy as np import torch from deepmd.dpmodel.array_api import ( Array, ) +from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT from deepmd.dpmodel.modifier.base_modifier import ( make_base_modifier, ) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, + RESERVED_PRECISION_DICT, +) from deepmd.pt.utils.utils import ( to_numpy_array, to_torch_tensor, @@ -25,6 +32,7 @@ def __init__(self) -> None: """Construct a basic model for different tasks.""" torch.nn.Module.__init__(self) self.modifier_type = "base" + self.jitable = True def serialize(self) -> dict: """Serialize the modifier. @@ -56,6 +64,10 @@ def deserialize(cls, data: dict) -> "BaseModifier": The deserialized modifier """ data = data.copy() + # Remove serialization metadata before passing to constructor + data.pop("@class", None) + data.pop("type", None) + data.pop("@version", None) modifier = cls(**data) return modifier @@ -68,29 +80,30 @@ def forward( box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: """Compute energy, force, and virial corrections.""" @torch.jit.unused def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: - """Modify data. + """Modify data of single frame. Parameters ---------- data Internal data of DeepmdData. Be a dict, has the following keys - - coord coordinates - - box simulation box - - atype atom types - - fparam frame parameter - - aparam atom parameter + - coord coordinates (nat, 3) + - box simulation box (9,) + - atype atom types (nat,) + - fparam frame parameter (nfp,) + - aparam atom parameter (nat, nap) - find_energy tells if data has energy - find_force tells if data has force - find_virial tells if data has virial - - energy energy - - force force - - virial virial + - energy energy (1,) + - force force (nat, 3) + - virial virial (9,) """ if ( "find_energy" not in data @@ -99,25 +112,50 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N ): return - get_nframes = None - t_coord = to_torch_tensor(data["coord"][:get_nframes, :]) - t_atype = to_torch_tensor(data["atype"]) - if data["box"] is None: - t_box = None + # model = self.dp.to(DEVICE) + prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]] + + nframes = 1 + natoms = len(data["atype"]) + atom_types = np.tile(data["atype"], nframes).reshape(nframes, -1) + + coord_input = torch.tensor( + data["coord"].reshape([nframes, natoms, 3]).astype(prec), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + type_input = torch.tensor( + atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISION_DICT[torch.long]]), + dtype=torch.long, + device=DEVICE, + ) + if data["box"] is not None: + box_input = torch.tensor( + data["box"].reshape([nframes, 3, 3]).astype(prec), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) else: - t_box = to_torch_tensor(data["box"][:get_nframes, :]) - if data["fparam"] is None: - t_fparam = None + box_input = None + if "fparam" in data: + fparam_input = to_torch_tensor(data["fparam"].reshape(nframes, -1)) else: - t_fparam = to_torch_tensor(data["fparam"][:get_nframes, :]) - if data["aparam"] is None: - t_aparam = None + fparam_input = None + if "aparam" in data: + aparam_input = to_torch_tensor(data["aparam"].reshape(nframes, natoms, -1)) else: - t_aparam = to_torch_tensor(data["aparam"][:get_nframes, :]) - # + aparam_input = None + do_atomic_virial = False # implement data modification method in forward - modifier_data = self.forward(t_coord, t_atype, t_box, t_fparam, t_aparam) + modifier_data = self.forward( + coord_input, + type_input, + box_input, + fparam_input, + aparam_input, + do_atomic_virial, + ) if "find_energy" in data and data["find_energy"] == 1.0: data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape( diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index c837484480..bf0c8390b9 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Test data modification functionality in DeepMD. +"""Test data modification functionality. -This module tests the data modification capabilities of DeepMD, specifically +This module tests the data modification functionality, specifically testing the BaseModifier implementations and their effects on training and validation data. It includes: @@ -98,9 +98,10 @@ def forward( box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: """Implementation of abstractmethod.""" - return {"coord": coord} + return {} def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: """Multiply by a deterministic factor for testing.""" @@ -142,9 +143,10 @@ def forward( box: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: """Implementation of abstractmethod.""" - return {"coord": coord} + return {} def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: """Zero out energy, force, and virial data.""" @@ -228,7 +230,7 @@ def get_sampled_data(trainer, nbatch: int, is_train: bool): # Continue collecting data until we've gone through all batches for _ in range(nbatch): - input_dict, label_dict, log_dict = trainer.get_data(is_train=is_train) + _, label_dict, log_dict = trainer.get_data(is_train=is_train) system_id = log_dict["sid"] frame_ids = log_dict["fid"] From f74752ce980e4120d2bce0c3c06c4b579767210c Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Fri, 19 Dec 2025 16:39:54 +0800 Subject: [PATCH 08/16] feat(pt): support data modifier in inference and frozen models - Add data modifier support in model inference pipeline - Enable saving and loading data modifiers with frozen models - Add ModifierScalingTester for scaling model predictions as data modification - Update test cases to verify data modifier functionality in inference - Enhance modifier argument registration with documentation This allows data modifiers to be applied during model inference and preserves them when saving frozen models for consistent behavior across training and inference stages. --- deepmd/pt/entrypoints/main.py | 19 +- deepmd/pt/infer/deep_eval.py | 22 +- deepmd/pt/infer/inference.py | 9 + deepmd/pt/train/wrapper.py | 7 + source/tests/pt/test_data_modifier.py | 317 ++++++++++++++++++-------- 5 files changed, 269 insertions(+), 105 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 138ef8daa0..f3abb3d4d2 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -4,6 +4,7 @@ import json import logging import os +import tempfile from pathlib import ( Path, ) @@ -383,10 +384,24 @@ def freeze( output: str = "frozen_model.pth", head: str | None = None, ) -> None: - model = inference.Tester(model, head=head).model + tester = inference.Tester(model, head=head) + model = tester.model model.eval() model = torch.jit.script(model) - extra_files = {} + + dm_output = "data_modifier.pth" + extra_files = {dm_output: ""} + if tester.modifier is not None: + dm = tester.modifier + dm.eval() + with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp_file: + torch.jit.save( + torch.jit.script(dm), + tmp_file, + ) + with open(tmp_file.name, "rb") as f: + extra_files = {dm_output: f.read()} + os.unlink(tmp_file.name) # Clean up the temporary file torch.jit.save( model, output, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 2726b61152..e718da9352 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json import logging +import os +import tempfile from collections.abc import ( Callable, ) @@ -171,8 +173,24 @@ def __init__( self.dp = ModelWrapper(model) self.dp.load_state_dict(state_dict) elif str(self.model_path).endswith(".pth"): - model = torch.jit.load(model_file, map_location=env.DEVICE) - self.dp = ModelWrapper(model) + extra_files = {"data_modifier.pth": ""} + model = torch.jit.load( + model_file, map_location=env.DEVICE, _extra_files=extra_files + ) + modifier = None + # Load modifier if it exists in extra_files + if len(extra_files["data_modifier.pth"]) > 0: + # Save the extra file content to a temporary file + with tempfile.NamedTemporaryFile( + suffix=".pth", delete=False + ) as tmp_file: + tmp_file.write(extra_files["data_modifier.pth"]) + tmp_file_path = tmp_file.name + # Load the modifier from the temporary file + modifier = torch.jit.load(tmp_file_path, map_location=env.DEVICE) + os.unlink(tmp_file_path) # Clean up the temporary file + self.dp = ModelWrapper(model, modifier=modifier) + self.modifier = modifier model_def_script = self.dp.model["Default"].get_model_def_script() if model_def_script: self.model_def_script = json.loads(model_def_script) diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index 4c49abeef8..424a79272c 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -9,6 +9,9 @@ from deepmd.pt.model.model import ( get_model, ) +from deepmd.pt.modifier import ( + get_data_modifier, +) from deepmd.pt.train.wrapper import ( ModelWrapper, ) @@ -60,9 +63,15 @@ def __init__( ) # wrapper Hessian to Energy model due to JIT limit self.model_params = deepcopy(model_params) self.model = get_model(model_params).to(DEVICE) + self.modifier = None + if "modifier" in model_params: + modifier = get_data_modifier(model_params["modifier"]).to(DEVICE) + if modifier.jitable: + self.modifier = modifier # Model Wrapper self.wrapper = ModelWrapper(self.model) # inference only + # , modifier=self.modifier if JIT: self.wrapper = torch.jit.script(self.wrapper) self.wrapper.load_state_dict(state_dict) diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 2669e3d832..a2539c6ff6 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -20,6 +20,7 @@ def __init__( loss: torch.nn.Module | dict = None, model_params: dict[str, Any] | None = None, shared_links: dict[str, Any] | None = None, + modifier: torch.nn.Module | None = None, ) -> None: """Construct a DeePMD model wrapper. @@ -57,6 +58,8 @@ def __init__( ) self.loss[task_key] = loss[task_key] self.inference_only = self.loss is None + # Modifier + self.modifier = modifier def share_params( self, @@ -185,6 +188,10 @@ def forward( if self.inference_only or inference_only: model_pred = self.model[task_key](**input_dict) + if self.modifier is not None: + modifier_pred = self.modifier(**input_dict) + for k, v in modifier_pred.items(): + model_pred[k] = model_pred[k] - v return model_pred, None, None else: natoms = atype.shape[-1] diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index bf0c8390b9..9282aefc5f 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -22,16 +22,26 @@ import numpy as np import torch +from dargs import ( + Argument, +) from deepmd.dpmodel.array_api import ( Array, ) +from deepmd.infer import ( + DeepEval, +) from deepmd.pt.entrypoints.main import ( + freeze, get_trainer, ) from deepmd.pt.modifier.base_modifier import ( BaseModifier, ) +from deepmd.pt.utils import ( + env, +) from deepmd.pt.utils.utils import ( to_numpy_array, ) @@ -46,50 +56,48 @@ parameterized, ) +doc_random_tester = "A test modifier that multiplies energy, force, and virial data by a deterministic random factor." +doc_zero_tester = "A test modifier that zeros out energy, force, and virial data by subtracting their original values." +doc_scaling_tester = "A test modifier that applies scaled model predictions as data modification using a frozen model." -@modifier_args_plugin.register("random_tester") -def modifier_random_tester() -> list: - """Return empty argument list for random_tester modifier. - - This function registers the argument schema for the random_tester modifier. - Returns - ------- - list: Empty list indicating no additional arguments required - """ - return [] +@modifier_args_plugin.register("random_tester", doc=doc_random_tester) +def modifier_random_tester() -> list: + doc_seed = "Random seed used as the scaling factor." + return [ + Argument("seed", int, optional=True, doc=doc_seed), + ] -@modifier_args_plugin.register("zero_tester") +@modifier_args_plugin.register("zero_tester", doc=doc_zero_tester) def modifier_zero_tester() -> list: - """Return empty argument list for zero_tester modifier. + return [] - This function registers the argument schema for the zero_tester modifier. - Returns - ------- - list: Empty list indicating no additional arguments required - """ - return [] +@modifier_args_plugin.register("scaling_tester", doc=doc_scaling_tester) +def modifier_scaling_tester() -> list[Argument]: + doc_model_name = "The name of the frozen energy model file." + doc_sfactor = "The scaling factor for correction." + return [ + Argument("model_name", str, optional=False, doc=doc_model_name), + Argument("sfactor", float, optional=False, doc=doc_sfactor), + ] @BaseModifier.register("random_tester") class ModifierRandomTester(BaseModifier): - def __new__(cls) -> BaseModifier: - """Create a new instance of ModifierRandomTester. - - Returns - ------- - BaseModifier: New instance of the modifier - """ + def __new__(cls, *args, **kwargs): return super().__new__(cls) - def __init__(self) -> None: + def __init__( + self, + seed: int = 1, + ) -> None: """Construct a basic model for different tasks.""" super().__init__() self.modifier_type = "random_tester" # Use a fixed seed for deterministic behavior - self.rng = np.random.default_rng(12345) # Fixed seed for reproducibility + self.rng = np.random.default_rng(seed) def forward( self, @@ -122,13 +130,7 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N @BaseModifier.register("zero_tester") class ModifierZeroTester(BaseModifier): - def __new__(cls) -> BaseModifier: - """Create a new instance of ModifierZeroTester. - - Returns - ------- - BaseModifier: New instance of the modifier - """ + def __new__(cls, *args, **kwargs): return super().__new__(cls) def __init__(self) -> None: @@ -165,6 +167,48 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N data["virial"] -= data["virial"] +@BaseModifier.register("scaling_tester") +class ModifierScalingTester(BaseModifier): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__( + self, + model_name: str, + sfactor: float = 1.0, + ) -> None: + """Construct a basic model for different tasks.""" + super().__init__() + self.modifier_type = "scaling_tester" + self.model_name = model_name + self.sfactor = sfactor + self.model = torch.jit.load(model_name, map_location=env.DEVICE) + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + """Take scaled model prediction as data modification.""" + model_pred = self.model( + coord=coord, + atype=atype, + box=box, + do_atomic_virial=do_atomic_virial, + fparam=fparam, + aparam=aparam, + ) + if isinstance(model_pred, tuple): + model_pred = model_pred[0] + for k in ["energy", "force", "virial"]: + model_pred[k] = model_pred[k] * self.sfactor + return model_pred + + @parameterized( (1, 2), # training data batch_size (1, 2), # validation data batch_size @@ -192,6 +236,144 @@ def setUp(self) -> None: self.training_nframes = self.get_dataset_nframes(training_data) self.validation_nframes = self.get_dataset_nframes(validation_data) + def test_init_modify_data(self): + """Ensure modify_data applied.""" + tmp_config = self.config.copy() + # add tester data modifier + tmp_config["model"]["modifier"] = {"type": "zero_tester"} + + # data modification is finished in __init__ + trainer = get_trainer(tmp_config) + + # training data + training_data = trainer.get_data(is_train=True) + # validation data + validation_data = trainer.get_data(is_train=False) + + for dataset in [training_data, validation_data]: + for kw in ["energy", "force"]: + data = to_numpy_array(dataset[1][kw]) + np.testing.assert_allclose(data, np.zeros_like(data)) + + def test_full_modify_data(self): + """Ensure modify_data only applied once.""" + tmp_config = self.config.copy() + # add tester data modifier + tmp_config["model"]["modifier"] = { + "type": "random_tester", + "seed": 1024, + } + + # data modification is finished in __init__ + trainer = get_trainer(tmp_config) + + # training data + training_data_before = self.get_sampled_data( + trainer, self.training_nframes, True + ) + # validation data + validation_data_before = self.get_sampled_data( + trainer, self.validation_nframes, False + ) + + trainer.run() + + # training data + training_data_after = self.get_sampled_data( + trainer, self.training_nframes, True + ) + # validation data + validation_data_after = self.get_sampled_data( + trainer, self.validation_nframes, False + ) + + for label_kw in ["energy", "force"]: + self.check_sampled_data(training_data_before, training_data_after, label_kw) + self.check_sampled_data( + validation_data_before, validation_data_after, label_kw + ) + + def test_inference(self): + """Test the inference with data modification by verifying scaled model predictions are properly applied.""" + # Generate frozen energy model used in modifier + trainer = get_trainer(self.config) + trainer.run() + freeze("model.ckpt.pt", "frozen_model_dm.pth") + + tmp_config = self.config.copy() + sfactor = np.random.default_rng(1).random() + # add tester data modifier + tmp_config["model"]["modifier"] = { + "type": "scaling_tester", + "model_name": "frozen_model_dm.pth", + "sfactor": sfactor, + } + + trainer = get_trainer(tmp_config) + trainer.run() + freeze("model.ckpt.pt", "frozen_model.pth") + + cell = np.array( + [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + ).reshape(1, 3, 3) + coord = np.array( + [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + ).reshape(1, -1, 3) + atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1) + + model = DeepEval("frozen_model.pth") + modifier = DeepEval("frozen_model_dm.pth") + # model inference without modifier + model_ref = DeepEval("model.ckpt.pt") + + model_pred = model.eval(coord, cell, atype) + modifier_pred = modifier.eval(coord, cell, atype) + model_pred_ref = model_ref.eval(coord, cell, atype) + # expected: output_model - sfactor * output_modifier + for ii in range(3): + np.testing.assert_allclose( + model_pred[ii], model_pred_ref[ii] - sfactor * modifier_pred[ii] + ) + + def tearDown(self) -> None: + """Clean up test artifacts after each test. + + Removes model files and other artifacts created during testing. + """ + for f in os.listdir("."): + if f.startswith("frozen_model") and f.endswith(".pth"): + os.remove(f) + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "checkpoint"]: + os.remove(f) + @staticmethod def get_dataset_nframes(dataset: list[str]) -> int: """Calculate total number of frames in a dataset. @@ -292,70 +474,3 @@ def check_sampled_data( except KeyError: continue np.testing.assert_allclose(ref_label, test_label) - - def test_init_modify_data(self): - """Ensure modify_data applied.""" - tmp_config = self.config.copy() - # add tester data modifier - tmp_config["model"]["modifier"] = {"type": "zero_tester"} - - # data modification is finished in __init__ - trainer = get_trainer(tmp_config) - - # training data - training_data = trainer.get_data(is_train=True) - # validation data - validation_data = trainer.get_data(is_train=False) - - for dataset in [training_data, validation_data]: - for kw in ["energy", "force"]: - data = to_numpy_array(dataset[1][kw]) - np.testing.assert_allclose(data, np.zeros_like(data)) - - def test_full_modify_data(self): - """Ensure modify_data only applied once.""" - tmp_config = self.config.copy() - # add tester data modifier - tmp_config["model"]["modifier"] = {"type": "random_tester"} - - # data modification is finished in __init__ - trainer = get_trainer(tmp_config) - - # training data - training_data_before = self.get_sampled_data( - trainer, self.training_nframes, True - ) - # validation data - validation_data_before = self.get_sampled_data( - trainer, self.validation_nframes, False - ) - - trainer.run() - - # training data - training_data_after = self.get_sampled_data( - trainer, self.training_nframes, True - ) - # validation data - validation_data_after = self.get_sampled_data( - trainer, self.validation_nframes, False - ) - - for label_kw in ["energy", "force"]: - self.check_sampled_data(training_data_before, training_data_after, label_kw) - self.check_sampled_data( - validation_data_before, validation_data_after, label_kw - ) - - def tearDown(self) -> None: - """Clean up test artifacts after each test. - - Removes model files and other artifacts created during testing. - """ - for f in os.listdir("."): - if f.startswith("frozen_model") and f.endswith(".pth"): - os.remove(f) - if f.startswith("model") and f.endswith(".pt"): - os.remove(f) - if f in ["lcurve.out", "checkpoint"]: - os.remove(f) From a8932c154943ee50bbe6bd7e4dc3644871216970 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 24 Dec 2025 11:24:17 +0800 Subject: [PATCH 09/16] Apply suggestion from @Copilot and @njzjz --- deepmd/pt/entrypoints/main.py | 16 +++++----- deepmd/pt/infer/deep_eval.py | 18 +++++------ deepmd/pt/infer/inference.py | 1 - deepmd/pt/modifier/base_modifier.py | 23 +++++++++++--- deepmd/pt/train/training.py | 43 +++++++++++++++++---------- source/tests/pt/test_data_modifier.py | 24 ++++++++------- 6 files changed, 74 insertions(+), 51 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index f3abb3d4d2..75efdd8c9f 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import argparse import copy +import io import json import logging import os -import tempfile from pathlib import ( Path, ) @@ -394,14 +394,12 @@ def freeze( if tester.modifier is not None: dm = tester.modifier dm.eval() - with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp_file: - torch.jit.save( - torch.jit.script(dm), - tmp_file, - ) - with open(tmp_file.name, "rb") as f: - extra_files = {dm_output: f.read()} - os.unlink(tmp_file.name) # Clean up the temporary file + buffer = io.BytesIO() + torch.jit.save( + torch.jit.script(dm), + buffer, + ) + extra_files = {dm_output: buffer.getvalue()} torch.jit.save( model, output, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index e718da9352..6e63ecb2fc 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import io import json import logging -import os -import tempfile from collections.abc import ( Callable, ) @@ -180,15 +179,12 @@ def __init__( modifier = None # Load modifier if it exists in extra_files if len(extra_files["data_modifier.pth"]) > 0: - # Save the extra file content to a temporary file - with tempfile.NamedTemporaryFile( - suffix=".pth", delete=False - ) as tmp_file: - tmp_file.write(extra_files["data_modifier.pth"]) - tmp_file_path = tmp_file.name - # Load the modifier from the temporary file - modifier = torch.jit.load(tmp_file_path, map_location=env.DEVICE) - os.unlink(tmp_file_path) # Clean up the temporary file + # Create a file-like object from the in-memory data + modifier_data = extra_files["data_modifier.pth"] + if isinstance(modifier_data, bytes): + modifier_data = io.BytesIO(modifier_data) + # Load the modifier directly from the file-like object + modifier = torch.jit.load(modifier_data, map_location=env.DEVICE) self.dp = ModelWrapper(model, modifier=modifier) self.modifier = modifier model_def_script = self.dp.model["Default"].get_model_def_script() diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index 424a79272c..b026cd54c5 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -71,7 +71,6 @@ def __init__( # Model Wrapper self.wrapper = ModelWrapper(self.model) # inference only - # , modifier=self.modifier if JIT: self.wrapper = torch.jit.script(self.wrapper) self.wrapper.load_state_dict(state_dict) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index e57f6d934b..f61fbd5d29 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -29,7 +29,7 @@ class BaseModifier(torch.nn.Module, make_base_modifier()): def __init__(self) -> None: - """Construct a basic model for different tasks.""" + """Construct a base modifier for data modification tasks.""" torch.nn.Module.__init__(self) self.modifier_type = "base" self.jitable = True @@ -157,15 +157,30 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N do_atomic_virial, ) - if "find_energy" in data and data["find_energy"] == 1.0: + if data.get("find_energy") == 1.0: + if "energy" not in modifier_data: + raise KeyError( + f"Modifier {self.__class__.__name__} did not provide 'energy' " + "in its output while 'find_energy' is set." + ) data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape( data["energy"].shape ) - if "find_force" in data and data["find_force"] == 1.0: + if data.get("find_force") == 1.0: + if "force" not in modifier_data: + raise KeyError( + f"Modifier {self.__class__.__name__} did not provide 'force' " + "in its output while 'find_force' is set." + ) data["force"] -= to_numpy_array(modifier_data["force"]).reshape( data["force"].shape ) - if "find_virial" in data and data["find_virial"] == 1.0: + if data.get("find_virial") == 1.0: + if "virial" not in modifier_data: + raise KeyError( + f"Modifier {self.__class__.__name__} did not provide 'virial' " + "in its output while 'find_virial' is set." + ) data["virial"] -= to_numpy_array(modifier_data["virial"]).reshape( data["virial"].shape ) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 9e1025260a..6f0cff86a1 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -245,15 +245,13 @@ def single_model_stat( _model: Any, _data_stat_nbatch: int, _training_data: DpLoaderSet, - _validation_data: DpLoaderSet | None, _stat_file_path: str | None, - _data_requirement: list[DataRequirementItem], finetune_has_new_type: bool = False, ) -> Callable[[], Any]: - _data_requirement += get_additional_data_requirement(_model) - _training_data.add_data_requirement(_data_requirement) - if _validation_data is not None: - _validation_data.add_data_requirement(_data_requirement) + # _data_requirement += get_additional_data_requirement(_model) + # _training_data.add_data_requirement(_data_requirement) + # if _validation_data is not None: + # _validation_data.add_data_requirement(_data_requirement) @functools.lru_cache def get_sample() -> Any: @@ -339,20 +337,25 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: # Data if not self.multi_task: + # add data requirement for labels + data_requirement = self.loss.label_requirement + data_requirement += get_additional_data_requirement(self.model) + training_data.add_data_requirement(data_requirement) + if validation_data is not None: + validation_data.add_data_requirement(data_requirement) + # data modification + training_data.preload_and_modify_all_data() + if validation_data is not None: + validation_data.preload_and_modify_all_data() self.get_sample_func = single_model_stat( self.model, model_params.get("data_stat_nbatch", 10), training_data, - validation_data, stat_file_path, - self.loss.label_requirement, finetune_has_new_type=self.finetune_links["Default"].get_has_new_type() if self.finetune_links is not None else False, ) - training_data.preload_and_modify_all_data() - if validation_data is not None: - validation_data.preload_and_modify_all_data() ( self.training_dataloader, self.training_data, @@ -378,22 +381,30 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: self.get_sample_func, ) = {}, {}, {}, {}, {}, {} for model_key in self.model_keys: + # add data requirement for labels + data_requirement = self.loss[model_key].label_requirement + data_requirement += get_additional_data_requirement( + self.model[model_key] + ) + training_data[model_key].add_data_requirement(data_requirement) + if validation_data[model_key] is not None: + validation_data[model_key].add_data_requirement(data_requirement) + # data modification + training_data[model_key].preload_and_modify_all_data() + if validation_data[model_key] is not None: + validation_data[model_key].preload_and_modify_all_data() self.get_sample_func[model_key] = single_model_stat( self.model[model_key], model_params["model_dict"][model_key].get("data_stat_nbatch", 10), training_data[model_key], - validation_data[model_key], stat_file_path[model_key], - self.loss[model_key].label_requirement, finetune_has_new_type=self.finetune_links[ model_key ].get_has_new_type() if self.finetune_links is not None else False, ) - training_data[model_key].preload_and_modify_all_data() - if validation_data[model_key] is not None: - validation_data[model_key].preload_and_modify_all_data() + ( self.training_dataloader[model_key], self.training_data[model_key], diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index 9282aefc5f..67cc4ba16d 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -63,7 +63,7 @@ @modifier_args_plugin.register("random_tester", doc=doc_random_tester) def modifier_random_tester() -> list: - doc_seed = "Random seed used as the scaling factor." + doc_seed = "Random seed used to initialize the random number generator for deterministic scaling factors." return [ Argument("seed", int, optional=True, doc=doc_seed), ] @@ -93,7 +93,7 @@ def __init__( self, seed: int = 1, ) -> None: - """Construct a basic model for different tasks.""" + """Construct a random_tester modifier that scales data by deterministic random factors for testing.""" super().__init__() self.modifier_type = "random_tester" # Use a fixed seed for deterministic behavior @@ -134,7 +134,7 @@ def __new__(cls, *args, **kwargs): return super().__new__(cls) def __init__(self) -> None: - """Construct a basic model for different tasks.""" + """Construct a modifier that zeros out data for testing.""" super().__init__() self.modifier_type = "zero_tester" @@ -177,7 +177,7 @@ def __init__( model_name: str, sfactor: float = 1.0, ) -> None: - """Construct a basic model for different tasks.""" + """Initialize a test modifier that applies scaled model predictions using a frozen model.""" super().__init__() self.modifier_type = "scaling_tester" self.model_name = model_name @@ -367,12 +367,16 @@ def tearDown(self) -> None: Removes model files and other artifacts created during testing. """ for f in os.listdir("."): - if f.startswith("frozen_model") and f.endswith(".pth"): - os.remove(f) - if f.startswith("model") and f.endswith(".pt"): - os.remove(f) - if f in ["lcurve.out", "checkpoint"]: - os.remove(f) + try: + if f.startswith("frozen_model") and f.endswith(".pth"): + os.remove(f) + elif f.startswith("model") and f.endswith(".pt"): + os.remove(f) + elif f in ["lcurve.out", "checkpoint"]: + os.remove(f) + except OSError: + # Ignore failures during cleanup to allow remaining files to be processed + pass @staticmethod def get_dataset_nframes(dataset: list[str]) -> int: From 50712547f9dfbedaaa7671d2c28f78b7de1392bc Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 24 Dec 2025 12:05:04 +0800 Subject: [PATCH 10/16] remove comment-out codes --- deepmd/pt/modifier/base_modifier.py | 1 - deepmd/pt/train/training.py | 5 ----- 2 files changed, 6 deletions(-) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index f61fbd5d29..79fc8ea3fe 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -112,7 +112,6 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N ): return - # model = self.dp.to(DEVICE) prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]] nframes = 1 diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 6f0cff86a1..5f53f04777 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -248,11 +248,6 @@ def single_model_stat( _stat_file_path: str | None, finetune_has_new_type: bool = False, ) -> Callable[[], Any]: - # _data_requirement += get_additional_data_requirement(_model) - # _training_data.add_data_requirement(_data_requirement) - # if _validation_data is not None: - # _validation_data.add_data_requirement(_data_requirement) - @functools.lru_cache def get_sample() -> Any: sampled = make_stat_input( From 1fdcfa33bb7aac579419fc8dce28c8033112e205 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 24 Dec 2025 12:10:18 +0800 Subject: [PATCH 11/16] resolve nitpick comments --- deepmd/pt/train/training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 5f53f04777..34375672e7 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -338,7 +338,7 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: training_data.add_data_requirement(data_requirement) if validation_data is not None: validation_data.add_data_requirement(data_requirement) - # data modification + # Preload and apply modifiers to all data before computing statistics training_data.preload_and_modify_all_data() if validation_data is not None: validation_data.preload_and_modify_all_data() @@ -384,7 +384,7 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: training_data[model_key].add_data_requirement(data_requirement) if validation_data[model_key] is not None: validation_data[model_key].add_data_requirement(data_requirement) - # data modification + # Preload and apply modifiers to all data before computing statistics training_data[model_key].preload_and_modify_all_data() if validation_data[model_key] is not None: validation_data[model_key].preload_and_modify_all_data() From f28a6a8829c9245eb9e6f8d461a5ed33e73f6475 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 25 Dec 2025 16:38:03 +0800 Subject: [PATCH 12/16] add `use_cache` as modifier var, so that the user can choose whether to save the data modification before training or to perform modification on-the-fly. --- deepmd/pt/modifier/base_modifier.py | 4 ++- deepmd/utils/data.py | 19 +++++++------- source/tests/pt/test_data_modifier.py | 37 ++++++++++++++++++++------- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py index 79fc8ea3fe..5a8c6538b0 100644 --- a/deepmd/pt/modifier/base_modifier.py +++ b/deepmd/pt/modifier/base_modifier.py @@ -28,12 +28,14 @@ class BaseModifier(torch.nn.Module, make_base_modifier()): - def __init__(self) -> None: + def __init__(self, use_cache: bool = True) -> None: """Construct a base modifier for data modification tasks.""" torch.nn.Module.__init__(self) self.modifier_type = "base" self.jitable = True + self.use_cache = use_cache + def serialize(self) -> dict: """Serialize the modifier. diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index f5a03a400d..1034ccabc4 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -140,11 +140,11 @@ def __init__( # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method self.prefix_sum = np.cumsum(frames_list).tolist() - self.apply_modifier_at_load = True + self.use_modifier_cache = True if self.modifier is not None: - if hasattr(self.modifier, "apply_modifier_at_load"): - self.apply_modifier_at_load = self.modifier.apply_modifier_at_load - # Cache for modified frames when apply_modifier_at_load is True + if hasattr(self.modifier, "use_cache"): + self.use_modifier_cache = self.modifier.use_cache + # Cache for modified frames when use_modifier_cache is True self._modified_frame_cache = {} def add( @@ -385,9 +385,9 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray: def get_single_frame(self, index: int) -> dict: """Orchestrates loading a single frame efficiently using memmap.""" - # Check if we have a cached modified frame and apply_modifier_at_load is True + # Check if we have a cached modified frame and use_modifier_cache is True if ( - self.apply_modifier_at_load + self.use_modifier_cache and self.modifier is not None and index in self._modified_frame_cache ): @@ -490,19 +490,18 @@ def get_single_frame(self, index: int) -> dict: if self.modifier is not None: # Apply modifier if it exists self.modifier.modify_data(frame_data, self) - if self.apply_modifier_at_load: + if self.use_modifier_cache: # Cache the modified frame to avoid recomputation self._modified_frame_cache[index] = copy.deepcopy(frame_data) - return frame_data def preload_and_modify_all_data(self) -> None: """Preload all frames and apply modifier to cache them. - This method is useful when apply_modifier_at_load is True and you want to + This method is useful when use_modifier_cache is True and you want to avoid applying the modifier repeatedly during training. """ - if not self.apply_modifier_at_load or self.modifier is None: + if not self.use_modifier_cache or self.modifier is None: return log.info("Preloading and modifying all data frames...") diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index 67cc4ba16d..dfbeba2c77 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -64,23 +64,30 @@ @modifier_args_plugin.register("random_tester", doc=doc_random_tester) def modifier_random_tester() -> list: doc_seed = "Random seed used to initialize the random number generator for deterministic scaling factors." + doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." return [ Argument("seed", int, optional=True, doc=doc_seed), + Argument("use_cache", bool, optional=True, doc=doc_use_cache), ] @modifier_args_plugin.register("zero_tester", doc=doc_zero_tester) def modifier_zero_tester() -> list: - return [] + doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." + return [ + Argument("use_cache", bool, optional=True, doc=doc_use_cache), + ] @modifier_args_plugin.register("scaling_tester", doc=doc_scaling_tester) def modifier_scaling_tester() -> list[Argument]: doc_model_name = "The name of the frozen energy model file." doc_sfactor = "The scaling factor for correction." + doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." return [ Argument("model_name", str, optional=False, doc=doc_model_name), Argument("sfactor", float, optional=False, doc=doc_sfactor), + Argument("use_cache", bool, optional=True, doc=doc_use_cache), ] @@ -92,12 +99,14 @@ def __new__(cls, *args, **kwargs): def __init__( self, seed: int = 1, + use_cache: bool = True, ) -> None: """Construct a random_tester modifier that scales data by deterministic random factors for testing.""" - super().__init__() + super().__init__(use_cache) self.modifier_type = "random_tester" # Use a fixed seed for deterministic behavior self.rng = np.random.default_rng(seed) + self.sfactor = self.rng.random() def forward( self, @@ -121,11 +130,11 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N return if "find_energy" in data and data["find_energy"] == 1.0: - data["energy"] = data["energy"] * self.rng.random() + data["energy"] = data["energy"] * self.sfactor if "find_force" in data and data["find_force"] == 1.0: - data["force"] = data["force"] * self.rng.random() + data["force"] = data["force"] * self.sfactor if "find_virial" in data and data["find_virial"] == 1.0: - data["virial"] = data["virial"] * self.rng.random() + data["virial"] = data["virial"] * self.sfactor @BaseModifier.register("zero_tester") @@ -133,9 +142,12 @@ class ModifierZeroTester(BaseModifier): def __new__(cls, *args, **kwargs): return super().__new__(cls) - def __init__(self) -> None: + def __init__( + self, + use_cache: bool = True, + ) -> None: """Construct a modifier that zeros out data for testing.""" - super().__init__() + super().__init__(use_cache) self.modifier_type = "zero_tester" def forward( @@ -176,9 +188,10 @@ def __init__( self, model_name: str, sfactor: float = 1.0, + use_cache: bool = True, ) -> None: """Initialize a test modifier that applies scaled model predictions using a frozen model.""" - super().__init__() + super().__init__(use_cache) self.modifier_type = "scaling_tester" self.model_name = model_name self.sfactor = sfactor @@ -212,6 +225,7 @@ def forward( @parameterized( (1, 2), # training data batch_size (1, 2), # validation data batch_size + (True, False), # use_cache ) class TestDataModifier(unittest.TestCase): def setUp(self) -> None: @@ -240,7 +254,10 @@ def test_init_modify_data(self): """Ensure modify_data applied.""" tmp_config = self.config.copy() # add tester data modifier - tmp_config["model"]["modifier"] = {"type": "zero_tester"} + tmp_config["model"]["modifier"] = { + "type": "zero_tester", + "use_cache": self.param[2], + } # data modification is finished in __init__ trainer = get_trainer(tmp_config) @@ -262,6 +279,7 @@ def test_full_modify_data(self): tmp_config["model"]["modifier"] = { "type": "random_tester", "seed": 1024, + "use_cache": self.param[2], } # data modification is finished in __init__ @@ -307,6 +325,7 @@ def test_inference(self): "type": "scaling_tester", "model_name": "frozen_model_dm.pth", "sfactor": sfactor, + "use_cache": True, } trainer = get_trainer(tmp_config) From a0f8ccbce12b19d2834e14dfd34ffe5731f43d74 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 25 Dec 2025 17:55:12 +0800 Subject: [PATCH 13/16] use ThreadPoolExecutor to eliminate CUDA re-initialization in data modification during training. --- deepmd/pd/utils/dataset.py | 5 +++- deepmd/pt/train/training.py | 8 +++--- deepmd/pt/utils/dataloader.py | 4 +-- deepmd/pt/utils/dataset.py | 9 ++++--- deepmd/utils/data.py | 35 ++++++++++++++++++++------- source/tests/pt/test_data_modifier.py | 2 +- 6 files changed, 43 insertions(+), 20 deletions(-) diff --git a/deepmd/pd/utils/dataset.py b/deepmd/pd/utils/dataset.py index fa9106044c..671345e5af 100644 --- a/deepmd/pd/utils/dataset.py +++ b/deepmd/pd/utils/dataset.py @@ -5,6 +5,9 @@ Dataset, ) +from deepmd.pd.utils.env import ( + NUM_WORKERS, +) from deepmd.utils.data import ( DataRequirementItem, DeepmdData, @@ -32,7 +35,7 @@ def __len__(self): def __getitem__(self, index): """Get a frame from the selected system.""" - b_data = self._data_system.get_item_paddle(index) + b_data = self._data_system.get_item_paddle(index, NUM_WORKERS) b_data["natoms"] = self._natoms_vec return b_data diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 34375672e7..d98b23d25c 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -339,9 +339,9 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: if validation_data is not None: validation_data.add_data_requirement(data_requirement) # Preload and apply modifiers to all data before computing statistics - training_data.preload_and_modify_all_data() + training_data.preload_and_modify_all_data_torch() if validation_data is not None: - validation_data.preload_and_modify_all_data() + validation_data.preload_and_modify_all_data_torch() self.get_sample_func = single_model_stat( self.model, model_params.get("data_stat_nbatch", 10), @@ -385,9 +385,9 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: if validation_data[model_key] is not None: validation_data[model_key].add_data_requirement(data_requirement) # Preload and apply modifiers to all data before computing statistics - training_data[model_key].preload_and_modify_all_data() + training_data[model_key].preload_and_modify_all_data_torch() if validation_data[model_key] is not None: - validation_data[model_key].preload_and_modify_all_data() + validation_data[model_key].preload_and_modify_all_data_torch() self.get_sample_func[model_key] = single_model_stat( self.model[model_key], model_params["model_dict"][model_key].get("data_stat_nbatch", 10), diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index d74d6c5c00..807c1f5ba1 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -238,9 +238,9 @@ def print_summary( [ss._data_system.pbc for ss in self.systems], ) - def preload_and_modify_all_data(self) -> None: + def preload_and_modify_all_data_torch(self) -> None: for system in self.systems: - system.preload_and_modify_all_data() + system.preload_and_modify_all_data_torch() def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index f99de70f55..81d456d74e 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -12,6 +12,9 @@ from deepmd.pt.modifier import ( BaseModifier, ) +from deepmd.pt.utils.env import ( + NUM_WORKERS, +) from deepmd.utils.data import ( DataRequirementItem, DeepmdData, @@ -48,7 +51,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> dict[str, Any]: """Get a frame from the selected system.""" - b_data = self._data_system.get_item_torch(index) + b_data = self._data_system.get_item_torch(index, NUM_WORKERS) b_data["natoms"] = self._natoms_vec return b_data @@ -68,5 +71,5 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N output_natoms_for_type_sel=data_item["output_natoms_for_type_sel"], ) - def preload_and_modify_all_data(self) -> None: - self._data_system.preload_and_modify_all_data() + def preload_and_modify_all_data_torch(self) -> None: + self._data_system.preload_and_modify_all_data_torch(NUM_WORKERS) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 1034ccabc4..bda976d682 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -253,17 +253,27 @@ def check_test_size(self, test_size: int) -> bool: """Check if the system can get a test dataset with `test_size` frames.""" return self.check_batch_size(test_size) - def get_item_torch(self, index: int) -> dict: + def get_item_torch( + self, + index: int, + num_worker: int = 1, + ) -> dict: """Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets. Parameters ---------- index index of the frame + num_worker + number of workers for parallel data modification """ - return self.get_single_frame(index) + return self.get_single_frame(index, num_worker) - def get_item_paddle(self, index: int) -> dict: + def get_item_paddle( + self, + index: int, + num_worker: int = 1, + ) -> dict: """Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets. Same with PyTorch backend. @@ -271,8 +281,10 @@ def get_item_paddle(self, index: int) -> dict: ---------- index index of the frame + num_worker + number of workers for parallel data modification """ - return self.get_single_frame(index) + return self.get_single_frame(index, num_worker) def get_batch(self, batch_size: int) -> dict: """Get a batch of data with `batch_size` frames. The frames are randomly picked from the data system. @@ -383,7 +395,7 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray: tmp = np.append(tmp, natoms_vec) return tmp.astype(np.int32) - def get_single_frame(self, index: int) -> dict: + def get_single_frame(self, index: int, num_worker: int) -> dict: """Orchestrates loading a single frame efficiently using memmap.""" # Check if we have a cached modified frame and use_modifier_cache is True if ( @@ -488,14 +500,19 @@ def get_single_frame(self, index: int) -> dict: frame_data["fid"] = index if self.modifier is not None: - # Apply modifier if it exists - self.modifier.modify_data(frame_data, self) + with ThreadPoolExecutor(max_workers=num_worker) as executor: + # Apply modifier if it exists + executor.submit( + self.modifier.modify_data, + frame_data, + self, + ) if self.use_modifier_cache: # Cache the modified frame to avoid recomputation self._modified_frame_cache[index] = copy.deepcopy(frame_data) return frame_data - def preload_and_modify_all_data(self) -> None: + def preload_and_modify_all_data_torch(self, num_worker: int) -> None: """Preload all frames and apply modifier to cache them. This method is useful when use_modifier_cache is True and you want to @@ -507,7 +524,7 @@ def preload_and_modify_all_data(self) -> None: log.info("Preloading and modifying all data frames...") for i in range(self.nframes): if i not in self._modified_frame_cache: - self.get_single_frame(i) + self.get_single_frame(i, num_worker) if (i + 1) % 100 == 0: log.info(f"Processed {i + 1}/{self.nframes} frames") log.info("All frames preloaded and modified.") diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index dfbeba2c77..de9e7fd34d 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -325,7 +325,7 @@ def test_inference(self): "type": "scaling_tester", "model_name": "frozen_model_dm.pth", "sfactor": sfactor, - "use_cache": True, + "use_cache": self.param[2], } trainer = get_trainer(tmp_config) From fc894dd795654fb02b0141e079242c833c332031 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Thu, 25 Dec 2025 18:02:53 +0800 Subject: [PATCH 14/16] resolve nitpick comments --- deepmd/utils/data.py | 4 +++- source/tests/pt/test_data_modifier.py | 9 ++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index bda976d682..729e706819 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -502,11 +502,13 @@ def get_single_frame(self, index: int, num_worker: int) -> dict: if self.modifier is not None: with ThreadPoolExecutor(max_workers=num_worker) as executor: # Apply modifier if it exists - executor.submit( + future = executor.submit( self.modifier.modify_data, frame_data, self, ) + # Wait for completion and propagate any exceptions + future.result() if self.use_modifier_cache: # Cache the modified frame to avoid recomputation self._modified_frame_cache[index] = copy.deepcopy(frame_data) diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index de9e7fd34d..fc168e739a 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -62,7 +62,7 @@ @modifier_args_plugin.register("random_tester", doc=doc_random_tester) -def modifier_random_tester() -> list: +def modifier_random_tester() -> list[Argument]: doc_seed = "Random seed used to initialize the random number generator for deterministic scaling factors." doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." return [ @@ -72,7 +72,7 @@ def modifier_random_tester() -> list: @modifier_args_plugin.register("zero_tester", doc=doc_zero_tester) -def modifier_zero_tester() -> list: +def modifier_zero_tester() -> list[Argument]: doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." return [ Argument("use_cache", bool, optional=True, doc=doc_use_cache), @@ -377,7 +377,10 @@ def test_inference(self): # expected: output_model - sfactor * output_modifier for ii in range(3): np.testing.assert_allclose( - model_pred[ii], model_pred_ref[ii] - sfactor * modifier_pred[ii] + model_pred[ii], + model_pred_ref[ii] - sfactor * modifier_pred[ii], + rtol=1e-5, + atol=1e-8, ) def tearDown(self) -> None: From 1acc58e6c84924214bb4a0f472d7b885a7e8947f Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 31 Dec 2025 16:57:11 +0800 Subject: [PATCH 15/16] fix bug about max_workers in ThreadPoolExecutor --- deepmd/pd/utils/dataset.py | 2 +- deepmd/pt/utils/dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/pd/utils/dataset.py b/deepmd/pd/utils/dataset.py index 671345e5af..5accd0315b 100644 --- a/deepmd/pd/utils/dataset.py +++ b/deepmd/pd/utils/dataset.py @@ -35,7 +35,7 @@ def __len__(self): def __getitem__(self, index): """Get a frame from the selected system.""" - b_data = self._data_system.get_item_paddle(index, NUM_WORKERS) + b_data = self._data_system.get_item_paddle(index, max(1, NUM_WORKERS)) b_data["natoms"] = self._natoms_vec return b_data diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 81d456d74e..ce9a6c52c6 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -51,7 +51,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> dict[str, Any]: """Get a frame from the selected system.""" - b_data = self._data_system.get_item_torch(index, NUM_WORKERS) + b_data = self._data_system.get_item_torch(index, max(1, NUM_WORKERS)) b_data["natoms"] = self._natoms_vec return b_data @@ -72,4 +72,4 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N ) def preload_and_modify_all_data_torch(self) -> None: - self._data_system.preload_and_modify_all_data_torch(NUM_WORKERS) + self._data_system.preload_and_modify_all_data_torch(max(1, NUM_WORKERS)) From 352c149e4f09b3f1c08e7ed45abdb6211925fbe5 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Tue, 6 Jan 2026 19:49:45 +0800 Subject: [PATCH 16/16] fix typo in wrapper and UT --- deepmd/pt/train/wrapper.py | 2 +- deepmd/utils/data.py | 2 -- source/tests/pt/test_data_modifier.py | 4 ++-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index a2539c6ff6..ddb4a4323d 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -191,7 +191,7 @@ def forward( if self.modifier is not None: modifier_pred = self.modifier(**input_dict) for k, v in modifier_pred.items(): - model_pred[k] = model_pred[k] - v + model_pred[k] = model_pred[k] + v return model_pred, None, None else: natoms = atype.shape[-1] diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 729e706819..b8c0692231 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -507,8 +507,6 @@ def get_single_frame(self, index: int, num_worker: int) -> dict: frame_data, self, ) - # Wait for completion and propagate any exceptions - future.result() if self.use_modifier_cache: # Cache the modified frame to avoid recomputation self._modified_frame_cache[index] = copy.deepcopy(frame_data) diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py index fc168e739a..18d66ef2ff 100644 --- a/source/tests/pt/test_data_modifier.py +++ b/source/tests/pt/test_data_modifier.py @@ -374,11 +374,11 @@ def test_inference(self): model_pred = model.eval(coord, cell, atype) modifier_pred = modifier.eval(coord, cell, atype) model_pred_ref = model_ref.eval(coord, cell, atype) - # expected: output_model - sfactor * output_modifier + # expected: output_model = output_model_ref + sfactor * output_modifier for ii in range(3): np.testing.assert_allclose( model_pred[ii], - model_pred_ref[ii] - sfactor * modifier_pred[ii], + model_pred_ref[ii] + sfactor * modifier_pred[ii], rtol=1e-5, atol=1e-8, )