diff --git a/deepmd/pd/utils/dataset.py b/deepmd/pd/utils/dataset.py index fa9106044c..5accd0315b 100644 --- a/deepmd/pd/utils/dataset.py +++ b/deepmd/pd/utils/dataset.py @@ -5,6 +5,9 @@ Dataset, ) +from deepmd.pd.utils.env import ( + NUM_WORKERS, +) from deepmd.utils.data import ( DataRequirementItem, DeepmdData, @@ -32,7 +35,7 @@ def __len__(self): def __getitem__(self, index): """Get a frame from the selected system.""" - b_data = self._data_system.get_item_paddle(index) + b_data = self._data_system.get_item_paddle(index, max(1, NUM_WORKERS)) b_data["natoms"] = self._natoms_vec return b_data diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index d349d519ba..75efdd8c9f 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import argparse import copy +import io import json import logging import os @@ -47,6 +48,9 @@ from deepmd.pt.model.model import ( BaseModel, ) +from deepmd.pt.modifier import ( + get_data_modifier, +) from deepmd.pt.train import ( training, ) @@ -111,6 +115,12 @@ def prepare_trainer_input_single( rank: int = 0, seed: int | None = None, ) -> tuple[DpLoaderSet, DpLoaderSet | None, DPPath | None]: + # get data modifier + modifier = None + modifier_params = model_params_single.get("modifier", None) + if modifier_params is not None: + modifier = get_data_modifier(modifier_params).to(DEVICE) + training_dataset_params = data_dict_single["training_data"] validation_dataset_params = data_dict_single.get("validation_data", None) validation_systems = ( @@ -145,6 +155,7 @@ def prepare_trainer_input_single( validation_dataset_params["batch_size"], model_params_single["type_map"], seed=rank_seed, + modifier=modifier, ) if validation_systems else None @@ -154,6 +165,7 @@ def prepare_trainer_input_single( training_dataset_params["batch_size"], model_params_single["type_map"], seed=rank_seed, + modifier=modifier, ) return ( train_data_single, @@ -372,10 +384,22 @@ def freeze( output: str = "frozen_model.pth", head: str | None = None, ) -> None: - model = inference.Tester(model, head=head).model + tester = inference.Tester(model, head=head) + model = tester.model model.eval() model = torch.jit.script(model) - extra_files = {} + + dm_output = "data_modifier.pth" + extra_files = {dm_output: ""} + if tester.modifier is not None: + dm = tester.modifier + dm.eval() + buffer = io.BytesIO() + torch.jit.save( + torch.jit.script(dm), + buffer, + ) + extra_files = {dm_output: buffer.getvalue()} torch.jit.save( model, output, diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 2726b61152..6e63ecb2fc 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import io import json import logging from collections.abc import ( @@ -171,8 +172,21 @@ def __init__( self.dp = ModelWrapper(model) self.dp.load_state_dict(state_dict) elif str(self.model_path).endswith(".pth"): - model = torch.jit.load(model_file, map_location=env.DEVICE) - self.dp = ModelWrapper(model) + extra_files = {"data_modifier.pth": ""} + model = torch.jit.load( + model_file, map_location=env.DEVICE, _extra_files=extra_files + ) + modifier = None + # Load modifier if it exists in extra_files + if len(extra_files["data_modifier.pth"]) > 0: + # Create a file-like object from the in-memory data + modifier_data = extra_files["data_modifier.pth"] + if isinstance(modifier_data, bytes): + modifier_data = 
io.BytesIO(modifier_data) + # Load the modifier directly from the file-like object + modifier = torch.jit.load(modifier_data, map_location=env.DEVICE) + self.dp = ModelWrapper(model, modifier=modifier) + self.modifier = modifier model_def_script = self.dp.model["Default"].get_model_def_script() if model_def_script: self.model_def_script = json.loads(model_def_script) diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index 4c49abeef8..b026cd54c5 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -9,6 +9,9 @@ from deepmd.pt.model.model import ( get_model, ) +from deepmd.pt.modifier import ( + get_data_modifier, +) from deepmd.pt.train.wrapper import ( ModelWrapper, ) @@ -60,6 +63,11 @@ def __init__( ) # wrapper Hessian to Energy model due to JIT limit self.model_params = deepcopy(model_params) self.model = get_model(model_params).to(DEVICE) + self.modifier = None + if "modifier" in model_params: + modifier = get_data_modifier(model_params["modifier"]).to(DEVICE) + if modifier.jitable: + self.modifier = modifier # Model Wrapper self.wrapper = ModelWrapper(self.model) # inference only diff --git a/deepmd/pt/modifier/__init__.py b/deepmd/pt/modifier/__init__.py new file mode 100644 index 0000000000..71d847bcbc --- /dev/null +++ b/deepmd/pt/modifier/__init__.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +from typing import ( + Any, +) + +from .base_modifier import ( + BaseModifier, +) + +__all__ = [ + "BaseModifier", + "get_data_modifier", +] + + +def get_data_modifier(_modifier_params: dict[str, Any]) -> BaseModifier: + modifier_params = copy.deepcopy(_modifier_params) + try: + modifier_type = modifier_params.pop("type") + except KeyError: + raise ValueError("Data modifier type not specified!") from None + return BaseModifier.get_class_by_type(modifier_type).get_modifier(modifier_params) diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py new file mode 100644 index 0000000000..5a8c6538b0 --- /dev/null +++ b/deepmd/pt/modifier/base_modifier.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + abstractmethod, +) + +import numpy as np +import torch + +from deepmd.dpmodel.array_api import ( + Array, +) +from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT +from deepmd.dpmodel.modifier.base_modifier import ( + make_base_modifier, +) +from deepmd.pt.utils.env import ( + DEVICE, + GLOBAL_PT_FLOAT_PRECISION, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.data import ( + DeepmdData, +) + + +class BaseModifier(torch.nn.Module, make_base_modifier()): + def __init__(self, use_cache: bool = True) -> None: + """Construct a base modifier for data modification tasks.""" + torch.nn.Module.__init__(self) + self.modifier_type = "base" + self.jitable = True + + self.use_cache = use_cache + + def serialize(self) -> dict: + """Serialize the modifier. + + Returns + ------- + dict + The serialized data + """ + data = { + "@class": "Modifier", + "type": self.modifier_type, + "@version": 3, + } + return data + + @classmethod + def deserialize(cls, data: dict) -> "BaseModifier": + """Deserialize the modifier. 
+
+        Parameters
+        ----------
+        data : dict
+            The serialized data
+
+        Returns
+        -------
+        BaseModifier
+            The deserialized modifier
+        """
+        data = data.copy()
+        # Remove serialization metadata before passing to constructor
+        data.pop("@class", None)
+        data.pop("type", None)
+        data.pop("@version", None)
+        modifier = cls(**data)
+        return modifier
+
+    @abstractmethod
+    @torch.jit.export
+    def forward(
+        self,
+        coord: torch.Tensor,
+        atype: torch.Tensor,
+        box: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        """Compute energy, force, and virial corrections."""
+
+    @torch.jit.unused
+    def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None:
+        """Modify the data of a single frame.
+
+        Parameters
+        ----------
+        data
+            Internal data of DeepmdData.
+            A dict with the following keys:
+            - coord: coordinates (nat, 3)
+            - box: simulation box (9,)
+            - atype: atom types (nat,)
+            - fparam: frame parameter (nfp,)
+            - aparam: atom parameter (nat, nap)
+            - find_energy: tells if data has energy
+            - find_force: tells if data has force
+            - find_virial: tells if data has virial
+            - energy: energy (1,)
+            - force: force (nat, 3)
+            - virial: virial (9,)
+        """
+        if (
+            "find_energy" not in data
+            and "find_force" not in data
+            and "find_virial" not in data
+        ):
+            return
+
+        prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]]
+
+        nframes = 1
+        natoms = len(data["atype"])
+        atom_types = np.tile(data["atype"], nframes).reshape(nframes, -1)
+
+        coord_input = torch.tensor(
+            data["coord"].reshape([nframes, natoms, 3]).astype(prec),
+            dtype=GLOBAL_PT_FLOAT_PRECISION,
+            device=DEVICE,
+        )
+        type_input = torch.tensor(
+            atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISION_DICT[torch.long]]),
+            dtype=torch.long,
+            device=DEVICE,
+        )
+        if data["box"] is not None:
+            box_input = torch.tensor(
+                data["box"].reshape([nframes, 3, 3]).astype(prec),
+                dtype=GLOBAL_PT_FLOAT_PRECISION,
+                device=DEVICE,
+            )
+        else:
+            box_input = None
+        if "fparam" in data:
+            fparam_input = to_torch_tensor(data["fparam"].reshape(nframes, -1))
+        else:
+            fparam_input = None
+        if "aparam" in data:
+            aparam_input = to_torch_tensor(data["aparam"].reshape(nframes, natoms, -1))
+        else:
+            aparam_input = None
+        do_atomic_virial = False
+
+        # The concrete correction is computed by the subclass's forward()
+        modifier_data = self.forward(
+            coord_input,
+            type_input,
+            box_input,
+            fparam_input,
+            aparam_input,
+            do_atomic_virial,
+        )
+
+        if data.get("find_energy") == 1.0:
+            if "energy" not in modifier_data:
+                raise KeyError(
+                    f"Modifier {self.__class__.__name__} did not provide 'energy' "
+                    "in its output while 'find_energy' is set."
+                )
+            data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape(
+                data["energy"].shape
+            )
+        if data.get("find_force") == 1.0:
+            if "force" not in modifier_data:
+                raise KeyError(
+                    f"Modifier {self.__class__.__name__} did not provide 'force' "
+                    "in its output while 'find_force' is set."
+                )
+            data["force"] -= to_numpy_array(modifier_data["force"]).reshape(
+                data["force"].shape
+            )
+        if data.get("find_virial") == 1.0:
+            if "virial" not in modifier_data:
+                raise KeyError(
+                    f"Modifier {self.__class__.__name__} did not provide 'virial' "
+                    "in its output while 'find_virial' is set."
+ ) + data["virial"] -= to_numpy_array(modifier_data["virial"]).reshape( + data["virial"].shape + ) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 24440e19de..d98b23d25c 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -245,16 +245,9 @@ def single_model_stat( _model: Any, _data_stat_nbatch: int, _training_data: DpLoaderSet, - _validation_data: DpLoaderSet | None, _stat_file_path: str | None, - _data_requirement: list[DataRequirementItem], finetune_has_new_type: bool = False, ) -> Callable[[], Any]: - _data_requirement += get_additional_data_requirement(_model) - _training_data.add_data_requirement(_data_requirement) - if _validation_data is not None: - _validation_data.add_data_requirement(_data_requirement) - @functools.lru_cache def get_sample() -> Any: sampled = make_stat_input( @@ -339,13 +332,21 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: # Data if not self.multi_task: + # add data requirement for labels + data_requirement = self.loss.label_requirement + data_requirement += get_additional_data_requirement(self.model) + training_data.add_data_requirement(data_requirement) + if validation_data is not None: + validation_data.add_data_requirement(data_requirement) + # Preload and apply modifiers to all data before computing statistics + training_data.preload_and_modify_all_data_torch() + if validation_data is not None: + validation_data.preload_and_modify_all_data_torch() self.get_sample_func = single_model_stat( self.model, model_params.get("data_stat_nbatch", 10), training_data, - validation_data, stat_file_path, - self.loss.label_requirement, finetune_has_new_type=self.finetune_links["Default"].get_has_new_type() if self.finetune_links is not None else False, @@ -375,19 +376,30 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp: self.get_sample_func, ) = {}, {}, {}, {}, {}, {} for model_key in self.model_keys: + # add data requirement for labels + data_requirement = self.loss[model_key].label_requirement + data_requirement += get_additional_data_requirement( + self.model[model_key] + ) + training_data[model_key].add_data_requirement(data_requirement) + if validation_data[model_key] is not None: + validation_data[model_key].add_data_requirement(data_requirement) + # Preload and apply modifiers to all data before computing statistics + training_data[model_key].preload_and_modify_all_data_torch() + if validation_data[model_key] is not None: + validation_data[model_key].preload_and_modify_all_data_torch() self.get_sample_func[model_key] = single_model_stat( self.model[model_key], model_params["model_dict"][model_key].get("data_stat_nbatch", 10), training_data[model_key], - validation_data[model_key], stat_file_path[model_key], - self.loss[model_key].label_requirement, finetune_has_new_type=self.finetune_links[ model_key ].get_has_new_type() if self.finetune_links is not None else False, ) + ( self.training_dataloader[model_key], self.training_data[model_key], diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 2669e3d832..ddb4a4323d 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -20,6 +20,7 @@ def __init__( loss: torch.nn.Module | dict = None, model_params: dict[str, Any] | None = None, shared_links: dict[str, Any] | None = None, + modifier: torch.nn.Module | None = None, ) -> None: """Construct a DeePMD model wrapper. 
@@ -57,6 +58,8 @@ def __init__( ) self.loss[task_key] = loss[task_key] self.inference_only = self.loss is None + # Modifier + self.modifier = modifier def share_params( self, @@ -185,6 +188,10 @@ def forward( if self.inference_only or inference_only: model_pred = self.model[task_key](**input_dict) + if self.modifier is not None: + modifier_pred = self.modifier(**input_dict) + for k, v in modifier_pred.items(): + model_pred[k] = model_pred[k] + v return model_pred, None, None else: natoms = atype.shape[-1] diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index c991e59daa..807c1f5ba1 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -25,6 +25,9 @@ DistributedSampler, ) +from deepmd.pt.modifier import ( + BaseModifier, +) from deepmd.pt.utils import ( dp_random, env, @@ -83,6 +86,7 @@ def __init__( type_map: list[str] | None, seed: int | None = None, shuffle: bool = True, + modifier: BaseModifier | None = None, ) -> None: if seed is not None: setup_seed(seed) @@ -94,6 +98,7 @@ def construct_dataset(system: str) -> DeepmdDataSetForLoader: return DeepmdDataSetForLoader( system=system, type_map=type_map, + modifier=modifier, ) self.systems: list[DeepmdDataSetForLoader] = [] @@ -233,6 +238,10 @@ def print_summary( [ss._data_system.pbc for ss in self.systems], ) + def preload_and_modify_all_data_torch(self) -> None: + for system in self.systems: + system.preload_and_modify_all_data_torch() + def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]: example = batch[0] diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 481fa04497..ce9a6c52c6 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -9,6 +9,12 @@ Dataset, ) +from deepmd.pt.modifier import ( + BaseModifier, +) +from deepmd.pt.utils.env import ( + NUM_WORKERS, +) from deepmd.utils.data import ( DataRequirementItem, DeepmdData, @@ -16,16 +22,25 @@ class DeepmdDataSetForLoader(Dataset): - def __init__(self, system: str, type_map: list[str] | None = None) -> None: + def __init__( + self, + system: str, + type_map: list[str] | None = None, + modifier: BaseModifier | None = None, + ) -> None: """Construct DeePMD-style dataset containing frames cross different systems. Args: - systems: Paths to systems. - type_map: Atom types. + - modifier: Data modifier. 
""" self.system = system self._type_map = type_map - self._data_system = DeepmdData(sys_path=system, type_map=self._type_map) + self.modifier = modifier + self._data_system = DeepmdData( + sys_path=system, type_map=self._type_map, modifier=self.modifier + ) self.mixed_type = self._data_system.mixed_type self._ntypes = self._data_system.get_ntypes() self._natoms = self._data_system.get_natoms() @@ -36,7 +51,7 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> dict[str, Any]: """Get a frame from the selected system.""" - b_data = self._data_system.get_item_torch(index) + b_data = self._data_system.get_item_torch(index, max(1, NUM_WORKERS)) b_data["natoms"] = self._natoms_vec return b_data @@ -55,3 +70,6 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N dtype=data_item["dtype"], output_natoms_for_type_sel=data_item["output_natoms_for_type_sel"], ) + + def preload_and_modify_all_data_torch(self) -> None: + self._data_system.preload_and_modify_all_data_torch(max(1, NUM_WORKERS)) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 287107a7ff..b8c0692231 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import bisect +import copy import functools import logging from concurrent.futures import ( @@ -139,6 +140,13 @@ def __init__( # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method self.prefix_sum = np.cumsum(frames_list).tolist() + self.use_modifier_cache = True + if self.modifier is not None: + if hasattr(self.modifier, "use_cache"): + self.use_modifier_cache = self.modifier.use_cache + # Cache for modified frames when use_modifier_cache is True + self._modified_frame_cache = {} + def add( self, key: str, @@ -245,17 +253,27 @@ def check_test_size(self, test_size: int) -> bool: """Check if the system can get a test dataset with `test_size` frames.""" return self.check_batch_size(test_size) - def get_item_torch(self, index: int) -> dict: + def get_item_torch( + self, + index: int, + num_worker: int = 1, + ) -> dict: """Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets. Parameters ---------- index index of the frame + num_worker + number of workers for parallel data modification """ - return self.get_single_frame(index) + return self.get_single_frame(index, num_worker) - def get_item_paddle(self, index: int) -> dict: + def get_item_paddle( + self, + index: int, + num_worker: int = 1, + ) -> dict: """Get a single frame data . The frame is picked from the data system by index. The index is coded across all the sets. Same with PyTorch backend. @@ -263,8 +281,10 @@ def get_item_paddle(self, index: int) -> dict: ---------- index index of the frame + num_worker + number of workers for parallel data modification """ - return self.get_single_frame(index) + return self.get_single_frame(index, num_worker) def get_batch(self, batch_size: int) -> dict: """Get a batch of data with `batch_size` frames. The frames are randomly picked from the data system. 
@@ -375,8 +395,16 @@ def get_natoms_vec(self, ntypes: int) -> np.ndarray:
             tmp = np.append(tmp, natoms_vec)
         return tmp.astype(np.int32)
 
-    def get_single_frame(self, index: int) -> dict:
+    def get_single_frame(self, index: int, num_worker: int) -> dict:
         """Orchestrates loading a single frame efficiently using memmap."""
+        # Return the cached modified frame if caching is enabled and it exists
+        if (
+            self.use_modifier_cache
+            and self.modifier is not None
+            and index in self._modified_frame_cache
+        ):
+            return self._modified_frame_cache[index]
+
         if index < 0 or index >= self.nframes:
             raise IndexError(f"Frame index {index} out of range [0, {self.nframes})")
         # 1. Find the correct set directory and local frame index
@@ -470,8 +498,38 @@
         frame_data["box"] = None
 
         frame_data["fid"] = index
+
+        if self.modifier is not None:
+            with ThreadPoolExecutor(max_workers=num_worker) as executor:
+                # Apply the modifier; result() propagates any exception
+                future = executor.submit(
+                    self.modifier.modify_data,
+                    frame_data,
+                    self,
+                )
+                future.result()
+            if self.use_modifier_cache:
+                # Cache the modified frame to avoid recomputation
+                self._modified_frame_cache[index] = copy.deepcopy(frame_data)
         return frame_data
 
+    def preload_and_modify_all_data_torch(self, num_worker: int) -> None:
+        """Preload all frames and apply the modifier to cache them.
+
+        This method is useful when use_modifier_cache is True and you want to
+        avoid applying the modifier repeatedly during training.
+        """
+        if not self.use_modifier_cache or self.modifier is None:
+            return
+
+        log.info("Preloading and modifying all data frames...")
+        for i in range(self.nframes):
+            if i not in self._modified_frame_cache:
+                self.get_single_frame(i, num_worker)
+            if (i + 1) % 100 == 0:
+                log.info(f"Processed {i + 1}/{self.nframes} frames")
+        log.info("All frames preloaded and modified.")
+
     def avg(self, key: str) -> float:
         """Return the average value of an item."""
         if key not in self.data_dict.keys():
diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/test_data_modifier.py
new file mode 100644
index 0000000000..18d66ef2ff
--- /dev/null
+++ b/source/tests/pt/test_data_modifier.py
@@ -0,0 +1,502 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Test data modification functionality.
+
+This module tests the data modification functionality: the BaseModifier
+implementations and their effects on training and validation data.
+It includes:
+
+1. Test modifier implementations (random_tester, zero_tester, and scaling_tester)
+2. Tests to verify data modification is applied correctly
+3. Tests to ensure data modification is only applied once
+
+The tests use parameterized testing with different batch sizes for training
+and validation data.
+"""
+
+import json
+import os
+import unittest
+from pathlib import (
+    Path,
+)
+
+import numpy as np
+import torch
+from dargs import (
+    Argument,
+)
+
+from deepmd.dpmodel.array_api import (
+    Array,
+)
+from deepmd.infer import (
+    DeepEval,
+)
+from deepmd.pt.entrypoints.main import (
+    freeze,
+    get_trainer,
+)
+from deepmd.pt.modifier.base_modifier import (
+    BaseModifier,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.utils import (
+    to_numpy_array,
+)
+from deepmd.utils.argcheck import (
+    modifier_args_plugin,
+)
+from deepmd.utils.data import (
+    DeepmdData,
+)
+
+from ..consistent.common import (
+    parameterized,
+)
+
+doc_random_tester = "A test modifier that multiplies energy, force, and virial data by a deterministic random factor."
+doc_zero_tester = "A test modifier that zeros out energy, force, and virial data by subtracting their original values." +doc_scaling_tester = "A test modifier that applies scaled model predictions as data modification using a frozen model." + + +@modifier_args_plugin.register("random_tester", doc=doc_random_tester) +def modifier_random_tester() -> list[Argument]: + doc_seed = "Random seed used to initialize the random number generator for deterministic scaling factors." + doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." + return [ + Argument("seed", int, optional=True, doc=doc_seed), + Argument("use_cache", bool, optional=True, doc=doc_use_cache), + ] + + +@modifier_args_plugin.register("zero_tester", doc=doc_zero_tester) +def modifier_zero_tester() -> list[Argument]: + doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." + return [ + Argument("use_cache", bool, optional=True, doc=doc_use_cache), + ] + + +@modifier_args_plugin.register("scaling_tester", doc=doc_scaling_tester) +def modifier_scaling_tester() -> list[Argument]: + doc_model_name = "The name of the frozen energy model file." + doc_sfactor = "The scaling factor for correction." + doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation." + return [ + Argument("model_name", str, optional=False, doc=doc_model_name), + Argument("sfactor", float, optional=False, doc=doc_sfactor), + Argument("use_cache", bool, optional=True, doc=doc_use_cache), + ] + + +@BaseModifier.register("random_tester") +class ModifierRandomTester(BaseModifier): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__( + self, + seed: int = 1, + use_cache: bool = True, + ) -> None: + """Construct a random_tester modifier that scales data by deterministic random factors for testing.""" + super().__init__(use_cache) + self.modifier_type = "random_tester" + # Use a fixed seed for deterministic behavior + self.rng = np.random.default_rng(seed) + self.sfactor = self.rng.random() + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + """Implementation of abstractmethod.""" + return {} + + def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: + """Multiply by a deterministic factor for testing.""" + if ( + "find_energy" not in data + and "find_force" not in data + and "find_virial" not in data + ): + return + + if "find_energy" in data and data["find_energy"] == 1.0: + data["energy"] = data["energy"] * self.sfactor + if "find_force" in data and data["find_force"] == 1.0: + data["force"] = data["force"] * self.sfactor + if "find_virial" in data and data["find_virial"] == 1.0: + data["virial"] = data["virial"] * self.sfactor + + +@BaseModifier.register("zero_tester") +class ModifierZeroTester(BaseModifier): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__( + self, + use_cache: bool = True, + ) -> None: + """Construct a modifier that zeros out data for testing.""" + super().__init__(use_cache) + self.modifier_type = "zero_tester" + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) 
-> dict[str, torch.Tensor]: + """Implementation of abstractmethod.""" + return {} + + def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None: + """Zero out energy, force, and virial data.""" + if ( + "find_energy" not in data + and "find_force" not in data + and "find_virial" not in data + ): + return + + if "find_energy" in data and data["find_energy"] == 1.0: + data["energy"] -= data["energy"] + if "find_force" in data and data["find_force"] == 1.0: + data["force"] -= data["force"] + if "find_virial" in data and data["find_virial"] == 1.0: + data["virial"] -= data["virial"] + + +@BaseModifier.register("scaling_tester") +class ModifierScalingTester(BaseModifier): + def __new__(cls, *args, **kwargs): + return super().__new__(cls) + + def __init__( + self, + model_name: str, + sfactor: float = 1.0, + use_cache: bool = True, + ) -> None: + """Initialize a test modifier that applies scaled model predictions using a frozen model.""" + super().__init__(use_cache) + self.modifier_type = "scaling_tester" + self.model_name = model_name + self.sfactor = sfactor + self.model = torch.jit.load(model_name, map_location=env.DEVICE) + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + """Take scaled model prediction as data modification.""" + model_pred = self.model( + coord=coord, + atype=atype, + box=box, + do_atomic_virial=do_atomic_virial, + fparam=fparam, + aparam=aparam, + ) + if isinstance(model_pred, tuple): + model_pred = model_pred[0] + for k in ["energy", "force", "virial"]: + model_pred[k] = model_pred[k] * self.sfactor + return model_pred + + +@parameterized( + (1, 2), # training data batch_size + (1, 2), # validation data batch_size + (True, False), # use_cache +) +class TestDataModifier(unittest.TestCase): + def setUp(self) -> None: + """Set up test fixtures.""" + input_json = str(Path(__file__).parent / "water/se_e2_a.json") + training_data = [ + str(Path(__file__).parent / "water/data/data_0"), + str(Path(__file__).parent / "water/data/data_1"), + ] + validation_data = [str(Path(__file__).parent / "water/data/data_1")] + with open(input_json, encoding="utf-8") as f: + config = json.load(f) + config["training"]["numb_steps"] = 1 + config["training"]["save_freq"] = 1 + config["learning_rate"]["start_lr"] = 1.0 + config["training"]["training_data"]["systems"] = training_data + config["training"]["training_data"]["batch_size"] = self.param[0] + config["training"]["validation_data"]["systems"] = validation_data + config["training"]["validation_data"]["batch_size"] = self.param[1] + self.config = config + + self.training_nframes = self.get_dataset_nframes(training_data) + self.validation_nframes = self.get_dataset_nframes(validation_data) + + def test_init_modify_data(self): + """Ensure modify_data applied.""" + tmp_config = self.config.copy() + # add tester data modifier + tmp_config["model"]["modifier"] = { + "type": "zero_tester", + "use_cache": self.param[2], + } + + # data modification is finished in __init__ + trainer = get_trainer(tmp_config) + + # training data + training_data = trainer.get_data(is_train=True) + # validation data + validation_data = trainer.get_data(is_train=False) + + for dataset in [training_data, validation_data]: + for kw in ["energy", "force"]: + data = to_numpy_array(dataset[1][kw]) + np.testing.assert_allclose(data, np.zeros_like(data)) + + 
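(The tests configure the modifier programmatically; assuming the dict-level interface shown here maps one-to-one onto the training input script, a user would enable the same modifier under the "model" section. A minimal sketch, using the test-only "zero_tester" type registered above:)

# In the input configuration, the modifier sits under "model",
# exactly as these tests set it programmatically:
config["model"]["modifier"] = {
    "type": "zero_tester",  # any registered modifier type
    "use_cache": True,  # cache modified frames after the first pass
}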
def test_full_modify_data(self): + """Ensure modify_data only applied once.""" + tmp_config = self.config.copy() + # add tester data modifier + tmp_config["model"]["modifier"] = { + "type": "random_tester", + "seed": 1024, + "use_cache": self.param[2], + } + + # data modification is finished in __init__ + trainer = get_trainer(tmp_config) + + # training data + training_data_before = self.get_sampled_data( + trainer, self.training_nframes, True + ) + # validation data + validation_data_before = self.get_sampled_data( + trainer, self.validation_nframes, False + ) + + trainer.run() + + # training data + training_data_after = self.get_sampled_data( + trainer, self.training_nframes, True + ) + # validation data + validation_data_after = self.get_sampled_data( + trainer, self.validation_nframes, False + ) + + for label_kw in ["energy", "force"]: + self.check_sampled_data(training_data_before, training_data_after, label_kw) + self.check_sampled_data( + validation_data_before, validation_data_after, label_kw + ) + + def test_inference(self): + """Test the inference with data modification by verifying scaled model predictions are properly applied.""" + # Generate frozen energy model used in modifier + trainer = get_trainer(self.config) + trainer.run() + freeze("model.ckpt.pt", "frozen_model_dm.pth") + + tmp_config = self.config.copy() + sfactor = np.random.default_rng(1).random() + # add tester data modifier + tmp_config["model"]["modifier"] = { + "type": "scaling_tester", + "model_name": "frozen_model_dm.pth", + "sfactor": sfactor, + "use_cache": self.param[2], + } + + trainer = get_trainer(tmp_config) + trainer.run() + freeze("model.ckpt.pt", "frozen_model.pth") + + cell = np.array( + [ + 5.122106549439247480e00, + 4.016537340154059388e-01, + 6.951654033828678081e-01, + 4.016537340154059388e-01, + 6.112136112297989143e00, + 8.178091365465004481e-01, + 6.951654033828678081e-01, + 8.178091365465004481e-01, + 6.159552512682983760e00, + ] + ).reshape(1, 3, 3) + coord = np.array( + [ + 2.978060152121375648e00, + 3.588469695887098077e00, + 2.792459820604495491e00, + 3.895592322591093115e00, + 2.712091020667753760e00, + 1.366836847133650501e00, + 9.955616170888935690e-01, + 4.121324820711413039e00, + 1.817239061889086571e00, + 3.553661462345699906e00, + 5.313046969500791583e00, + 6.635182659098815883e00, + 6.088601018589653080e00, + 6.575011420004332585e00, + 6.825240650611076099e00, + ] + ).reshape(1, -1, 3) + atype = np.array([0, 0, 0, 1, 1]).reshape(1, -1) + + model = DeepEval("frozen_model.pth") + modifier = DeepEval("frozen_model_dm.pth") + # model inference without modifier + model_ref = DeepEval("model.ckpt.pt") + + model_pred = model.eval(coord, cell, atype) + modifier_pred = modifier.eval(coord, cell, atype) + model_pred_ref = model_ref.eval(coord, cell, atype) + # expected: output_model = output_model_ref + sfactor * output_modifier + for ii in range(3): + np.testing.assert_allclose( + model_pred[ii], + model_pred_ref[ii] + sfactor * modifier_pred[ii], + rtol=1e-5, + atol=1e-8, + ) + + def tearDown(self) -> None: + """Clean up test artifacts after each test. + + Removes model files and other artifacts created during testing. 
+        """
+        for f in os.listdir("."):
+            try:
+                if f.startswith("frozen_model") and f.endswith(".pth"):
+                    os.remove(f)
+                elif f.startswith("model") and f.endswith(".pt"):
+                    os.remove(f)
+                elif f in ["lcurve.out", "checkpoint"]:
+                    os.remove(f)
+            except OSError:
+                # Ignore failures during cleanup to allow remaining files to be processed
+                pass
+
+    @staticmethod
+    def get_dataset_nframes(dataset: list[str]) -> int:
+        """Calculate the total number of frames in a dataset.
+
+        Parameters
+        ----------
+        dataset : list[str]
+            List of dataset paths.
+
+        Returns
+        -------
+        int
+            Total number of frames across all datasets.
+        """
+        nframes = 0
+        for _data in dataset:
+            _dpdata = DeepmdData(_data)
+            nframes += _dpdata.nframes
+        return nframes
+
+    @staticmethod
+    def get_sampled_data(trainer, nbatch: int, is_train: bool):
+        """Collect all data from the trainer and organize it by IDs for easy comparison.
+
+        Parameters
+        ----------
+        trainer
+            The trainer object.
+        nbatch : int
+            Number of batches to iterate through.
+        is_train : bool
+            Whether to get training data (True) or validation data (False).
+
+        Returns
+        -------
+        dict
+            A nested dictionary organized by system_id and frame_id.
+            Format: {system_id: {frame_id: label_dict}}
+        """
+        output = {}
+        # Keep track of all unique frames we've collected
+        collected_frames = set()
+
+        # Continue collecting data until we've gone through all batches
+        for _ in range(nbatch):
+            _, label_dict, log_dict = trainer.get_data(is_train=is_train)
+
+            system_id = log_dict["sid"]
+            frame_ids = log_dict["fid"]
+
+            # Initialize the system entry if it does not exist
+            if system_id not in output:
+                output[system_id] = {}
+
+            # Store label data for each frame ID
+            for idx, frame_id in enumerate(frame_ids):
+                # Skip if we already have this frame
+                frame_key = (system_id, frame_id)
+                if frame_key in collected_frames:
+                    continue
+
+                # Create a copy of label_dict for this specific frame
+                frame_data = {}
+                for key, value in label_dict.items():
+                    # If value is a tensor/array with a batch dimension, extract this frame
+                    if (
+                        hasattr(value, "shape")
+                        and len(value.shape) > 0
+                        and value.shape[0] == len(frame_ids)
+                    ):
+                        # Handle batched data - extract the data for this frame
+                        frame_data[key] = value[idx]
+                    else:
+                        # For scalar values or non-batched data, just copy as is
+                        frame_data[key] = value
+
+                output[system_id][frame_id] = frame_data
+                collected_frames.add(frame_key)
+
+        return output
+
+    @staticmethod
+    def check_sampled_data(
+        ref_data: dict[int, dict], test_data: dict[int, dict], label_kw: str
+    ):
+        """Compare sampled data between reference and test datasets.
+
+        Parameters
+        ----------
+        ref_data : dict[int, dict]
+            Reference data dictionary organized by system and frame IDs.
+        test_data : dict[int, dict]
+            Test data dictionary organized by system and frame IDs.
+        label_kw : str
+            Key of the label to compare (e.g., "energy", "force").
+
+        Raises
+        ------
+        AssertionError
+            If the data doesn't match between reference and test.
+        """
+        for sid in ref_data.keys():
+            ref_sys = ref_data[sid]
+            test_sys = test_data[sid]
+            for fid in ref_sys.keys():
+                # compare common elements
+                try:
+                    ref_label = to_numpy_array(ref_sys[fid][label_kw])
+                    test_label = to_numpy_array(test_sys[fid][label_kw])
+                except KeyError:
+                    continue
+                np.testing.assert_allclose(ref_label, test_label)
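Putting the serialization pieces together: freeze() embeds the scripted modifier in the frozen .pth under the "data_modifier.pth" extra-files key, and DeepEval re-attaches it at load time so its corrections are added to the model predictions. A sketch of the round trip done by hand, mirroring the logic in deep_eval.py above (the file name is a placeholder):

import io

import torch

extra_files = {"data_modifier.pth": ""}
model = torch.jit.load(
    "frozen_model.pth", map_location="cpu", _extra_files=extra_files
)
modifier = None
if len(extra_files["data_modifier.pth"]) > 0:
    # torch.jit.load fills extra_files with bytes; wrap them in a
    # file-like object to load the embedded TorchScript modifier.
    modifier = torch.jit.load(
        io.BytesIO(extra_files["data_modifier.pth"]), map_location="cpu"
    )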