5 changes: 4 additions & 1 deletion deepmd/pd/utils/dataset.py
@@ -5,6 +5,9 @@
     Dataset,
 )
 
+from deepmd.pd.utils.env import (
+    NUM_WORKERS,
+)
 from deepmd.utils.data import (
     DataRequirementItem,
     DeepmdData,
@@ -32,7 +35,7 @@ def __len__(self):
 
     def __getitem__(self, index):
         """Get a frame from the selected system."""
-        b_data = self._data_system.get_item_paddle(index)
+        b_data = self._data_system.get_item_paddle(index, max(1, NUM_WORKERS))
        b_data["natoms"] = self._natoms_vec
         return b_data
 
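`get_item_paddle` now receives `max(1, NUM_WORKERS)` as a second argument; the clamp guarantees the value is at least 1. A trivial sketch of the guard, assuming only that `NUM_WORKERS` can legitimately be 0 (main-process data loading):

```python
def effective_workers(num_workers: int) -> int:
    # get_item_paddle always receives at least 1, even when
    # NUM_WORKERS is configured as 0.
    return max(1, num_workers)


assert effective_workers(0) == 1
assert effective_workers(4) == 4
```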
28 changes: 26 additions & 2 deletions deepmd/pt/entrypoints/main.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import argparse
 import copy
+import io
 import json
 import logging
 import os
@@ -47,6 +48,9 @@
 from deepmd.pt.model.model import (
     BaseModel,
 )
+from deepmd.pt.modifier import (
+    get_data_modifier,
+)
 from deepmd.pt.train import (
     training,
 )
@@ -111,6 +115,12 @@ def prepare_trainer_input_single(
     rank: int = 0,
     seed: int | None = None,
 ) -> tuple[DpLoaderSet, DpLoaderSet | None, DPPath | None]:
+    # get data modifier
+    modifier = None
+    modifier_params = model_params_single.get("modifier", None)
+    if modifier_params is not None:
+        modifier = get_data_modifier(modifier_params).to(DEVICE)
+
     training_dataset_params = data_dict_single["training_data"]
     validation_dataset_params = data_dict_single.get("validation_data", None)
     validation_systems = (
@@ -145,6 +155,7 @@ def prepare_trainer_input_single(
             validation_dataset_params["batch_size"],
             model_params_single["type_map"],
             seed=rank_seed,
+            modifier=modifier,
         )
         if validation_systems
         else None
@@ -154,6 +165,7 @@
         training_dataset_params["batch_size"],
         model_params_single["type_map"],
         seed=rank_seed,
+        modifier=modifier,
     )
     return (
         train_data_single,
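With this wiring, an optional `modifier` block under the per-model parameters is instantiated once and handed to both the training and validation `DpLoaderSet`. A hedged sketch of the expected shape (only the `modifier` and `type` keys are taken from the diff; everything else is illustrative):

```python
model_params_single = {
    "type_map": ["O", "H"],
    "modifier": {  # optional; when absent, modifier stays None
        "type": "my_modifier",  # hypothetical registered type
    },
}
```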
@@ -372,10 +384,22 @@ def freeze(
     output: str = "frozen_model.pth",
     head: str | None = None,
 ) -> None:
-    model = inference.Tester(model, head=head).model
+    tester = inference.Tester(model, head=head)
+    model = tester.model
     model.eval()
     model = torch.jit.script(model)
-    extra_files = {}
+
+    dm_output = "data_modifier.pth"
+    extra_files = {dm_output: ""}
+    if tester.modifier is not None:
+        dm = tester.modifier
+        dm.eval()
+        buffer = io.BytesIO()
+        torch.jit.save(
+            torch.jit.script(dm),
+            buffer,
+        )
+        extra_files = {dm_output: buffer.getvalue()}
     torch.jit.save(
         model,
         output,
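`freeze()` embeds the scripted modifier in the frozen artifact through TorchScript's `_extra_files` channel. A self-contained sketch of the same save-side pattern, with placeholder modules (`Dummy` and `DummyModifier` are stand-ins, not DeePMD classes):

```python
import io

import torch


class Dummy(torch.nn.Module):  # stand-in for the trained model
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


class DummyModifier(torch.nn.Module):  # stand-in for a jitable modifier
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x - 1.0


model = torch.jit.script(Dummy().eval())

# Empty string marks "no modifier"; overwritten below when one exists.
extra_files = {"data_modifier.pth": ""}
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(DummyModifier().eval()), buffer)
extra_files["data_modifier.pth"] = buffer.getvalue()

torch.jit.save(model, "frozen_model.pth", _extra_files=extra_files)
```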
18 changes: 16 additions & 2 deletions deepmd/pt/infer/deep_eval.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import io
 import json
 import logging
 from collections.abc import (
@@ -171,8 +172,21 @@ def __init__(
             self.dp = ModelWrapper(model)
             self.dp.load_state_dict(state_dict)
         elif str(self.model_path).endswith(".pth"):
-            model = torch.jit.load(model_file, map_location=env.DEVICE)
-            self.dp = ModelWrapper(model)
+            extra_files = {"data_modifier.pth": ""}
+            model = torch.jit.load(
+                model_file, map_location=env.DEVICE, _extra_files=extra_files
+            )
+            modifier = None
+            # Load modifier if it exists in extra_files
+            if len(extra_files["data_modifier.pth"]) > 0:
+                # Create a file-like object from the in-memory data
+                modifier_data = extra_files["data_modifier.pth"]
+                if isinstance(modifier_data, bytes):
+                    modifier_data = io.BytesIO(modifier_data)
+                # Load the modifier directly from the file-like object
+                modifier = torch.jit.load(modifier_data, map_location=env.DEVICE)
+            self.dp = ModelWrapper(model, modifier=modifier)
+            self.modifier = modifier
         model_def_script = self.dp.model["Default"].get_model_def_script()
         if model_def_script:
             self.model_def_script = json.loads(model_def_script)
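The loading side mirrors the freeze path: `torch.jit.load` fills the `_extra_files` dict in place, and an empty value means the artifact was frozen without a modifier. A sketch continuing the save-side example above:

```python
import io

import torch

extra_files = {"data_modifier.pth": ""}
model = torch.jit.load("frozen_model.pth", _extra_files=extra_files)

modifier = None
payload = extra_files["data_modifier.pth"]
if len(payload) > 0:
    # torch.jit.load returns the extra file as bytes; wrap for file-like access.
    modifier = torch.jit.load(io.BytesIO(payload))
```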
8 changes: 8 additions & 0 deletions deepmd/pt/infer/inference.py
@@ -9,6 +9,9 @@
 from deepmd.pt.model.model import (
     get_model,
 )
+from deepmd.pt.modifier import (
+    get_data_modifier,
+)
 from deepmd.pt.train.wrapper import (
     ModelWrapper,
 )
@@ -60,6 +63,11 @@ def __init__(
         )  # wrapper Hessian to Energy model due to JIT limit
         self.model_params = deepcopy(model_params)
         self.model = get_model(model_params).to(DEVICE)
+        self.modifier = None
+        if "modifier" in model_params:
+            modifier = get_data_modifier(model_params["modifier"]).to(DEVICE)
+            if modifier.jitable:
+                self.modifier = modifier
 
         # Model Wrapper
         self.wrapper = ModelWrapper(self.model)  # inference only
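Only a modifier that declares itself scriptable is kept on the `Tester`, so the later `torch.jit.script(dm)` call in `freeze()` cannot fail on an un-scriptable module. A fragment sketching how a subclass might opt out, assuming nothing beyond the `jitable` attribute set by `BaseModifier.__init__` (`forward()` is omitted here; a full subclass sketch follows base_modifier.py below):

```python
from deepmd.pt.modifier import BaseModifier


class PythonOnlyModifier(BaseModifier):  # hypothetical subclass
    def __init__(self) -> None:
        super().__init__()
        # Relies on features TorchScript cannot compile, so Tester keeps
        # self.modifier = None and freeze() embeds nothing.
        self.jitable = False
```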
23 changes: 23 additions & 0 deletions deepmd/pt/modifier/__init__.py
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
    Any,
)

from .base_modifier import (
    BaseModifier,
)

__all__ = [
    "BaseModifier",
    "get_data_modifier",
]


def get_data_modifier(_modifier_params: dict[str, Any]) -> BaseModifier:
    modifier_params = copy.deepcopy(_modifier_params)
    try:
        modifier_type = modifier_params.pop("type")
    except KeyError:
        raise ValueError("Data modifier type not specified!") from None
    return BaseModifier.get_class_by_type(modifier_type).get_modifier(modifier_params)
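`get_data_modifier` deep-copies the incoming dict, pops the mandatory `type` key, and resolves the concrete class through `BaseModifier.get_class_by_type`. A usage sketch, assuming a hypothetical `"my_modifier"` type has already been registered on `BaseModifier`:

```python
from deepmd.pt.modifier import get_data_modifier

params = {"type": "my_modifier", "use_cache": True}  # hypothetical config
modifier = get_data_modifier(params)

# The caller's dict survives intact: "type" is popped from a deep copy.
assert params["type"] == "my_modifier"

get_data_modifier({})  # raises ValueError: "Data modifier type not specified!"
```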
187 changes: 187 additions & 0 deletions deepmd/pt/modifier/base_modifier.py
@@ -0,0 +1,187 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
    abstractmethod,
)

import numpy as np
import torch

from deepmd.dpmodel.array_api import (
    Array,
)
from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
from deepmd.dpmodel.modifier.base_modifier import (
    make_base_modifier,
)
from deepmd.pt.utils.env import (
    DEVICE,
    GLOBAL_PT_FLOAT_PRECISION,
    RESERVED_PRECISION_DICT,
)
from deepmd.pt.utils.utils import (
    to_numpy_array,
    to_torch_tensor,
)
from deepmd.utils.data import (
    DeepmdData,
)


class BaseModifier(torch.nn.Module, make_base_modifier()):
    def __init__(self, use_cache: bool = True) -> None:
        """Construct a base modifier for data modification tasks."""
        torch.nn.Module.__init__(self)
        self.modifier_type = "base"
        self.jitable = True

        self.use_cache = use_cache

    def serialize(self) -> dict:
        """Serialize the modifier.

        Returns
        -------
        dict
            The serialized data
        """
        data = {
            "@class": "Modifier",
            "type": self.modifier_type,
            "@version": 3,
        }
        return data

    @classmethod
    def deserialize(cls, data: dict) -> "BaseModifier":
        """Deserialize the modifier.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        BaseModifier
            The deserialized modifier
        """
        data = data.copy()
        # Remove serialization metadata before passing to constructor
        data.pop("@class", None)
        data.pop("type", None)
        data.pop("@version", None)
        modifier = cls(**data)
        return modifier

    @abstractmethod
    @torch.jit.export
    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        box: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Compute energy, force, and virial corrections."""

    @torch.jit.unused
    def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None:
        """Modify the data of a single frame in place.

        Parameters
        ----------
        data
            Internal data of DeepmdData: a dict with the following keys

            - coord: coordinates (nat, 3)
            - box: simulation box (9,)
            - atype: atom types (nat,)
            - fparam: frame parameter (nfp,)
            - aparam: atom parameter (nat, nap)
            - find_energy: tells if data has energy
            - find_force: tells if data has force
            - find_virial: tells if data has virial
            - energy: energy (1,)
            - force: force (nat, 3)
            - virial: virial (9,)
        data_sys : DeepmdData
            The data system the frame belongs to.
        """
        if (
            "find_energy" not in data
            and "find_force" not in data
            and "find_virial" not in data
        ):
            return

        prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]]

        nframes = 1
        natoms = len(data["atype"])
        atom_types = np.tile(data["atype"], nframes).reshape(nframes, -1)

        coord_input = torch.tensor(
            data["coord"].reshape([nframes, natoms, 3]).astype(prec),
            dtype=GLOBAL_PT_FLOAT_PRECISION,
            device=DEVICE,
        )
        type_input = torch.tensor(
            atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISION_DICT[torch.long]]),
            dtype=torch.long,
            device=DEVICE,
        )
        if data["box"] is not None:
            box_input = torch.tensor(
                data["box"].reshape([nframes, 3, 3]).astype(prec),
                dtype=GLOBAL_PT_FLOAT_PRECISION,
                device=DEVICE,
            )
        else:
            box_input = None
        if "fparam" in data:
            fparam_input = to_torch_tensor(data["fparam"].reshape(nframes, -1))
        else:
            fparam_input = None
        if "aparam" in data:
            aparam_input = to_torch_tensor(data["aparam"].reshape(nframes, natoms, -1))
        else:
            aparam_input = None
        do_atomic_virial = False

        # implement data modification method in forward
        modifier_data = self.forward(
            coord_input,
            type_input,
            box_input,
            fparam_input,
            aparam_input,
            do_atomic_virial,
        )

        if data.get("find_energy") == 1.0:
            if "energy" not in modifier_data:
                raise KeyError(
                    f"Modifier {self.__class__.__name__} did not provide 'energy' "
                    "in its output while 'find_energy' is set."
                )
            data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape(
                data["energy"].shape
            )
        if data.get("find_force") == 1.0:
            if "force" not in modifier_data:
                raise KeyError(
                    f"Modifier {self.__class__.__name__} did not provide 'force' "
                    "in its output while 'find_force' is set."
                )
            data["force"] -= to_numpy_array(modifier_data["force"]).reshape(
                data["force"].shape
            )
        if data.get("find_virial") == 1.0:
            if "virial" not in modifier_data:
                raise KeyError(
                    f"Modifier {self.__class__.__name__} did not provide 'virial' "
                    "in its output while 'find_virial' is set."
                )
            data["virial"] -= to_numpy_array(modifier_data["virial"]).reshape(
                data["virial"].shape
            )