diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py
index ce23c5981b..5f17ca8d5d 100644
--- a/deepmd/infer/deep_pot.py
+++ b/deepmd/infer/deep_pot.py
@@ -212,6 +212,14 @@ def eval(
             aparam=aparam,
             **kwargs,
         )
+        # The PyTorch backend returns only a "density" entry when a non-None
+        # grid is passed, so skip the energy/force/virial unpacking in that case.
+        # An explicit grid=None must still take the energy path below.
+        # TODO: consider returning the density under a dedicated key instead.
+        if kwargs.get("grid") is not None:
+            result = results["density"].reshape(nframes, -1)
+            return result
+
         energy = results["energy_redu"].reshape(nframes, 1)
         force = results["energy_derv_r"].reshape(nframes, natoms, 3)
         virial = results["energy_derv_c_redu"].reshape(nframes, 9)
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 7b45c46333..a74d6eb653 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -421,7 +421,11 @@ def train(
 
     # Initialize DDP
     if os.environ.get("LOCAL_RANK") is not None:
-        dist.init_process_group(backend="cuda:nccl,cpu:gloo")
+        import datetime
+
+        # 18000 s (5 h): a longer timeout for large datasets or slow file systems
+        timeout = datetime.timedelta(seconds=18000)
+        dist.init_process_group(backend="cuda:nccl,cpu:gloo", timeout=timeout)
 
     trainer = get_trainer(
         config,
diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py
index 3a44bde4fc..a97eaaed65 100644
--- a/deepmd/pt/infer/deep_eval.py
+++ b/deepmd/pt/infer/deep_eval.py
@@ -406,11 +406,7 @@ def eval(
             coords, atom_types, len(atom_types.shape) > 1
         )
         request_defs = self._get_request_defs(atomic)
-        if "spin" not in kwargs or kwargs["spin"] is None:
-            out = self._eval_func(self._eval_model, numb_test, natoms)(
-                coords, cells, atom_types, fparam, aparam, request_defs, charge_spin
-            )
-        else:
+        if "spin" in kwargs and kwargs["spin"] is not None:
             out = self._eval_func(self._eval_model_spin, numb_test, natoms)(
                 coords,
                 cells,
@@ -421,6 +417,21 @@ def eval(
                 request_defs,
                 charge_spin,
             )
+
elif "grid" in kwargs and kwargs["grid"] is not None:
+            out = self._eval_func(self._eval_model_density, numb_test, natoms)(
+                coords,
+                cells,
+                atom_types,
+                np.array(kwargs["grid"]),
+                fparam,
+                aparam,
+                request_defs,
+            )
+            return {"density": out[0] if isinstance(out, tuple) else out}  # unwrap 1-tuple
+        else:
+            out = self._eval_func(self._eval_model, numb_test, natoms)(
+                coords, cells, atom_types, fparam, aparam, request_defs, charge_spin
+            )
         return dict(
             zip(
                 [x.name for x in request_defs],
@@ -688,6 +699,106 @@ def _eval_model_spin(
                 )  # this is kinda hacky
         return tuple(results)
 
+    def _eval_model_density(
+        self,
+        coords: np.ndarray,
+        cells: np.ndarray | None,
+        atom_types: np.ndarray,
+        grid: np.ndarray,
+        fparam: np.ndarray | None,
+        aparam: np.ndarray | None,
+        request_defs: list[OutputVariableDef],
+    ) -> tuple[np.ndarray, ...]:
+        """Evaluate the density model on the given grid points.
+
+        Parameters
+        ----------
+        coords
+            Atomic coordinates, reshaped here to (nframes, natoms, 3).
+        cells
+            Cells, reshaped here to (nframes, 3, 3); None passes no box.
+        atom_types
+            Atom types: a 1-D array is tiled over the frames, otherwise one
+            row per frame is expected.
+        grid
+            Grid point coordinates, reshaped here to (nframes, -1, 3).
+        fparam
+            Frame parameters, or None.
+        aparam
+            Atomic parameters, or None.
+        request_defs
+            Requested outputs; only used to decide whether the model is called
+            with `do_atomic_virial=True`.
+
+        Returns
+        -------
+        tuple[np.ndarray, ...]
+            A 1-tuple holding the "density" output as an (nframes, ngrid) array.
+        """
+        model = self.dp.to(DEVICE)
+
+        nframes = coords.shape[0]
+        if len(atom_types.shape) == 1:
+            natoms = len(atom_types)
+            atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)
+        else:
+            natoms = len(atom_types[0])
+
+        coord_input = torch.tensor(
+            coords.reshape([nframes, natoms, 3]),
+            dtype=GLOBAL_PT_FLOAT_PRECISION,
+            device=DEVICE,
+        )
+        type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
+        grid_input = torch.tensor(
+            grid.reshape([nframes, -1, 3]),
+            dtype=GLOBAL_PT_FLOAT_PRECISION,
+            device=DEVICE,
+        )
+        ngrid = grid_input.shape[1]
+        if cells is not None:
+            box_input = torch.tensor(
+                cells.reshape([nframes, 3, 3]),
+                dtype=GLOBAL_PT_FLOAT_PRECISION,
+                device=DEVICE,
+            )
+        else:
+            box_input = None
+        if fparam is not None:
+            fparam_input = to_torch_tensor(
+                fparam.reshape(nframes, self.get_dim_fparam())
+            )
+        else:
+            fparam_input = None
+        if aparam is not None:
+            aparam_input = to_torch_tensor(
+                aparam.reshape(nframes, natoms, self.get_dim_aparam())
+            )
+        else:
+            aparam_input = None
+
+        do_atomic_virial = any(
+            x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs
+        )
+        batch_output = model(
+            coord_input,
+            type_input,
grid=grid_input,
+            box=box_input,
+            do_atomic_virial=do_atomic_virial,
+            fparam=fparam_input,
+            aparam=aparam_input,
+        )
+        if isinstance(batch_output, tuple):
+            batch_output = batch_output[0]
+
+        results = []
+        pt_name = "density"
+        density_shape = [nframes, ngrid]
+        out = batch_output[pt_name].reshape(density_shape).detach().cpu().numpy()
+        results.append(out)
+        return tuple(results)
+
     def _get_output_shape(
         self, odef: OutputVariableDef, nframes: int, natoms: int
     ) -> list[int]:
diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py
index 1d25c1e52f..79fec9ca41 100644
--- a/deepmd/pt/loss/__init__.py
+++ b/deepmd/pt/loss/__init__.py
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from .charge import (
+    GridDensityLoss,
+)
 from .denoise import (
     DenoiseLoss,
 )
@@ -28,6 +31,7 @@
     "EnergyHessianStdLoss",
     "EnergySpinLoss",
     "EnergyStdLoss",
+    "GridDensityLoss",
     "PropertyLoss",
     "TaskLoss",
     "TensorLoss",
diff --git a/deepmd/pt/loss/charge.py b/deepmd/pt/loss/charge.py
new file mode 100644
index 0000000000..5199d2b00d
--- /dev/null
+++ b/deepmd/pt/loss/charge.py
@@ -0,0 +1,138 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+import torch
+
+from deepmd.pt.loss.loss import (
+    TaskLoss,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.env import (
+    GLOBAL_PT_FLOAT_PRECISION,
+)
+from deepmd.utils.data import (
+    DataRequirementItem,
+)
+
+
+class GridDensityLoss(TaskLoss):
+    """Loss on the density predicted at grid points."""
+
+    def __init__(
+        self,
+        starter_learning_rate: float = 1.0,
+        start_pref_d: float = 0.0,
+        limit_pref_d: float = 0.0,
+        inference: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        r"""Construct a layer to compute loss on grid density.
+
+        Parameters
+        ----------
+        starter_learning_rate : float
+            The learning rate at the start of the training.
+        start_pref_d : float
+            The prefactor of charge density loss at the start of the training.
+        limit_pref_d : float
+            The prefactor of charge density loss at the end of the training.
+        inference : bool
+            If true, it will output all losses found in output, ignoring the pre-factors.
+        **kwargs
+            Other keyword arguments.
+        """
+        super().__init__()
+        self.starter_learning_rate = starter_learning_rate
+        # the density term is active when both prefactors are non-zero, or in inference
+        self.has_d = (start_pref_d != 0.0 and limit_pref_d != 0.0) or inference
+
+        self.start_pref_d = start_pref_d
+        self.limit_pref_d = limit_pref_d
+        self.inference = inference
+
+    def forward(
+        self,
+        input_dict: dict[str, torch.Tensor],
+        model: torch.nn.Module,
+        label: dict[str, torch.Tensor],
+        natoms: int,
+        learning_rate: float,
+        mae: bool = False,
+    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]:
+        """Return loss on grid density.
+
+        The minimized term is the mean absolute error scaled by the current
+        prefactor; the RMSE is only reported in `more_loss`.
+
+        Parameters
+        ----------
+        input_dict : dict[str, torch.Tensor]
+            Model inputs.
+        model : torch.nn.Module
+            Model to be used to output the predictions.
+        label : dict[str, torch.Tensor]
+            Labels.
+        natoms : int
+            The local atom number. Not used by this loss.
+        learning_rate : float
+            The current learning rate; its ratio to `starter_learning_rate`
+            interpolates the prefactor between `limit_pref_d` and `start_pref_d`.
+        mae : bool
+            Not used by this loss; the MAE is always reported.
+
+        Returns
+        -------
+        model_pred: dict[str, torch.Tensor]
+            Model predictions.
+        loss: torch.Tensor
+            Loss for model to minimize.
+        more_loss: dict[str, torch.Tensor]
+            Other losses for display.
+        """
+        model_pred = model(**input_dict)
+        coef = learning_rate / self.starter_learning_rate
+        pref_d = self.limit_pref_d + (self.start_pref_d - self.limit_pref_d) * coef
+
+        loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
+        more_loss = {}
+        if self.has_d and "density" in model_pred and "density" in label:
+            # a missing "find_density" flag zeroes the prefactor of this term
+            find_density = label.get("find_density", 0.0)
+            pref_d = pref_d * find_density
+            # element-wise error between flattened label and prediction
+            diff = label["density"].reshape(-1) - model_pred["density"].reshape(-1)
+            rmse_d = torch.square(diff).mean().sqrt()
+            more_loss["rmse_d"] = self.display_if_exist(rmse_d.detach(), find_density)
+            mae_d = torch.abs(diff).mean()
+            loss += (pref_d * mae_d).to(GLOBAL_PT_FLOAT_PRECISION)
+            more_loss["mae_d"] = self.display_if_exist(mae_d.detach(), find_density)
+        return model_pred, loss, more_loss
+
+    @property
+    def label_requirement(self) -> list[DataRequirementItem]:
+        """Return data label requirements needed for this loss calculation."""
+        label_requirement = []
+        label_requirement.append(
+            DataRequirementItem(
+                "grid",
+                ndof=3,
+                atomic=True,  # the grid is defined for each atom, so it is atomic
+                must=True,
+                high_prec=True,
+            )
+        )
+        if self.has_d:
+            label_requirement.append(
+                DataRequirementItem(
+                    "density",
+                    ndof=1,
+                    atomic=True,
+                    must=False,
+                    high_prec=True,
+                )
+            )
+        return label_requirement
diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py
index 4da9bf781b..aa28bf437c 100644
--- a/deepmd/pt/model/atomic_model/__init__.py
+++ b/deepmd/pt/model/atomic_model/__init__.py
@@ -17,6 +17,9 @@
 from .base_atomic_model import (
     BaseAtomicModel,
 )
+from .density_atomic_model import (
+    DPDensityAtomicModel,
+)
 from .dipole_atomic_model import (
     DPDipoleAtomicModel,
 )
@@ -47,6 +50,7 @@
     "BaseAtomicModel",
     "DPAtomicModel",
     "DPDOSAtomicModel",
+    "DPDensityAtomicModel",
     "DPDipoleAtomicModel",
     "DPEnergyAtomicModel",
     "DPPolarAtomicModel",
diff --git a/deepmd/pt/model/atomic_model/density_atomic_model.py b/deepmd/pt/model/atomic_model/density_atomic_model.py
new file mode 100644
index 0000000000..1eccdeb6d0
--- /dev/null
+++ b/deepmd/pt/model/atomic_model/density_atomic_model.py
@@ -0,0 +1,401 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
+from collections.abc import (
+    Callable,
+)
+from typing import (
+    Any,
+)
+
+import torch
+
+from deepmd.pt.model.descriptor.base_descriptor import (
+    BaseDescriptor,
+)
+from deepmd.pt.model.task.density import (
+    DensityFittingNet,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.nlist import (
+    build_directional_neighbor_list,
+    extend_input_and_build_neighbor_list,
+)
+from deepmd.utils.path import (
+    DPPath,
+)
+
+from .dp_atomic_model import (
+    DPAtomicModel,
+)
+
+log = logging.getLogger(__name__)
+
+
+class DPDensityAtomicModel(DPAtomicModel):
+    """Atomic model that predicts a density on grid points.
+
+    The grid points are prepended to the extended atoms, the descriptor is
+    evaluated on the merged system, and the fitting net is applied to the
+    grid rows only.
+    """
+
+    def __init__(
+        self,
+        descriptor: BaseDescriptor,
+        fitting: DensityFittingNet,
+        type_map: list[str],
+        **kwargs: Any,
+    ) -> None:
+        assert isinstance(fitting, DensityFittingNet)
+        super().__init__(descriptor, fitting, type_map, **kwargs)
+        self.rcut = self.descriptor.get_rcut()
+        self.rcut_smth = self.descriptor.get_rcut_smth()
+        self.env_protection = self.descriptor.get_env_protection()
+        # replace a zero protection reported by the descriptor with 1e-6
+        if self.env_protection == 0.0:
+            self.env_protection = 1e-6
+        self.sel = self.descriptor.get_sel()
+        self.nnei = self.descriptor.get_nsel()
+
+        wanted_shape = (1, self.nnei, 4)
+        mean = torch.zeros(
+            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
+        )
+        stddev = torch.ones(
+            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
+        )
+        self.register_buffer("mean", mean)
+        self.register_buffer("stddev", stddev)
+
+    def forward_atomic(
+        self,
+        extended_coord: torch.Tensor,
+        extended_atype: torch.Tensor,
+        nlist: torch.Tensor,
+        mapping: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        comm_dict: dict[str, torch.Tensor] | None = None,
+        grid: torch.Tensor | None = None,
+        grid_type: torch.Tensor | None = None,
+        grid_nlist: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Return the prediction of the fitting net on the grid points.
+
+        Parameters
+        ----------
+        extended_coord
+            coordinates in extended region
+        extended_atype
+            atomic type in extended region
+        nlist
+            neighbor list. nf x nloc x nsel
+        mapping
+            maps the extended indices to local indices
+        fparam
+            frame parameter. nf x ndf
+        aparam
+            atomic parameter. nf x nloc x nda
+        comm_dict
+            passed through to the descriptor
+        grid
+            coordinates of the grid points. nf x ngrid x 3
+        grid_type
+            type of the grid points; only its dtype and device are used here
+        grid_nlist
+            neighbor list of the grid points. nf x ngrid x nsel
+
+        Returns
+        -------
+        result_dict
+            the result dict, defined by the `FittingOutputDef`.
+
+        """
+        nframes, _, _ = nlist.shape
+        if self.do_grad_r() or self.do_grad_c():
+            extended_coord.requires_grad_(True)
+        assert mapping is not None
+        assert grid is not None
+        assert grid_type is not None
+        assert grid_nlist is not None
+        _, ngrid, _ = grid_nlist.shape
+        # The grid points are put in front of the extended atoms, so every atom
+        # index stored in a neighbor list or in the mapping is shifted by ngrid.
+        # nb x (ngrid+nall) x 3
+        merged_coord = torch.cat([grid, extended_coord], dim=1)
+
+        # in the merged system the grid points carry the last descriptor type
+        grid_atype = torch.ones(
+            [nframes, ngrid], device=extended_atype.device, dtype=extended_atype.dtype
+        ) * (self.descriptor.ntypes - 1)
+        # nb x (ngrid+nall)
+        merged_atype = torch.cat([grid_atype, extended_atype], dim=1)
+
+        # nb x ngrid x nsel
+        grid_nlist_mask = grid_nlist >= 0
+        shifted_grid_nlist = torch.where(grid_nlist_mask, grid_nlist + ngrid, -1)
+        # nb x nloc x nsel
+        nlist_mask = nlist >= 0
+        shifted_nlist = torch.where(nlist_mask, nlist + ngrid, -1)
+        # nb x (ngrid+nloc) x nsel
+        merged_nlist = torch.cat([shifted_grid_nlist, shifted_nlist], dim=1)
+
+        # nb x ngrid, every grid point maps to itself
+        grid_mapping = (
+            torch.arange(ngrid, device=mapping.device, dtype=mapping.dtype)
+            .unsqueeze(0)
+            .expand(nframes, -1)
+        )
+        # nb x (ngrid+nall)
+        merged_mapping = torch.cat([grid_mapping, mapping + ngrid], dim=1)
+
+        descriptor, rot_mat, g2, h2, sw = self.descriptor(
+            merged_coord,
+            merged_atype,
+            merged_nlist,
+            mapping=merged_mapping,
+            comm_dict=comm_dict,
+        )
+        assert descriptor is not None
+
+        # Only the grid rows are fitted, and the fitting net is given type 0 for
+        # all of them.
+        # NOTE(review): rot_mat, g2 and h2 are passed without the [:ngrid] slice
+        # applied to the descriptor -- confirm the fitting net tolerates this.
+        ret = self.fitting_net(
+            descriptor[:, :ngrid, :],
+            torch.zeros(
+                [nframes, ngrid], device=grid_type.device, dtype=grid_type.dtype
+            ),
+            gr=rot_mat,
+            g2=g2,
+            h2=h2,
+            fparam=fparam,
+            aparam=aparam,
+        )
+        return ret
+
+    def forward_common_atomic(
+        self,
+        extended_coord: torch.Tensor,
+        extended_atype: torch.Tensor,
+        nlist: torch.Tensor,
+        mapping: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        comm_dict: dict[str, torch.Tensor] | None = None,
+        grid: torch.Tensor | None = None,
+        grid_type: torch.Tensor | None = None,
+        grid_nlist: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Common interface for atomic inference.
+
+        This method accepts extended coordinates, extended atom types, neighbor list
+        and grid points, and predicts the contribution of the fit property on the grid.
+
+        Parameters
+        ----------
+        extended_coord
+            extended coordinates, shape: nf x (nall x 3)
+        extended_atype
+            extended atom types, shape: nf x nall
+            for a type < 0 indicating the atomic is virtual.
+        nlist
+            neighbor list, shape: nf x nloc x nsel
+        mapping
+            extended to local index mapping, shape: nf x nall
+        fparam
+            frame parameters, shape: nf x dim_fparam
+        aparam
+            atomic parameter, shape: nf x nloc x dim_aparam
+        comm_dict
+            The data needed for communication for parallel inference.
+        grid
+            grid point coordinates, shape: nf x ngrid x 3
+        grid_type
+            grid point types, shape: nf x ngrid
+        grid_nlist
+            neighbor list of the grid points, shape: nf x ngrid x nsel
+
+        Returns
+        -------
+        ret_dict
+            dict of output properties on the grid points.
+            should implement the definition of `fitting_output_def`.
+            ret_dict["mask"] of shape nf x ngrid will be provided.
+            It is all ones, multiplied by `self.atom_excl(grid_type)` when `atom_excl`
+            is set.
+
+        """
+        assert grid is not None
+        assert grid_type is not None
+        assert grid_nlist is not None
+        nframes, _, _ = nlist.shape
+        _, ngrid, _ = grid_nlist.shape
+
+        if self.pair_excl is not None:
+            pair_mask = self.pair_excl(nlist, extended_atype)
+            # exclude neighbors in the nlist
+            nlist = torch.where(pair_mask == 1, nlist, -1)
+
+        ext_atom_mask = self.make_atom_mask(extended_atype)
+        ret_dict = self.forward_atomic(
+            extended_coord,
+            torch.where(ext_atom_mask, extended_atype, 0),
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+            comm_dict=comm_dict,
+            grid=grid,
+            grid_type=grid_type,
+            grid_nlist=grid_nlist,
+        )
+        ret_dict = self.apply_out_stat(ret_dict, grid_type)
+
+        # nf x ngrid
+        grid_mask = torch.ones(
+            [nframes, ngrid], dtype=torch.int32, device=ext_atom_mask.device
+        )
+        if self.atom_excl is not None:
+            grid_mask *= self.atom_excl(grid_type)
+
+        # zero the outputs of masked grid points, keeping each output's shape
+        for kk in ret_dict.keys():
+            out_shape = ret_dict[kk].shape
+            out_shape2 = 1
+            for ss in out_shape[2:]:
+                out_shape2 *= ss
+            ret_dict[kk] = (
+                ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
+                * grid_mask[:, :, None]
+            ).view(out_shape)
+        ret_dict["mask"] = grid_mask
+
+        return ret_dict
+
+    def forward(
+        self,
+        extended_coord: torch.Tensor,
+        extended_atype: torch.Tensor,
+        nlist: torch.Tensor,
+        mapping: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        comm_dict: dict[str, torch.Tensor] | None = None,
+        grid: torch.Tensor | None = None,
+        grid_type: torch.Tensor | None = None,
+        grid_nlist: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Forward to `forward_common_atomic` with the grid inputs.
+
+        The grid arguments default to None only to keep the previous signature
+        callable; `forward_common_atomic` asserts that all three are given.
+        """
+        return self.forward_common_atomic(
+            extended_coord,
+            extended_atype,
+            nlist,
+            mapping=mapping,
+            fparam=fparam,
+            aparam=aparam,
+            comm_dict=comm_dict,
+            grid=grid,
+            grid_type=grid_type,
+            grid_nlist=grid_nlist,
+        )
+
+    def _get_forward_wrapper_func(self) -> Callable[..., torch.Tensor]:
+        """Get a forward wrapper of the atomic model for output bias calculation."""
+
+        def model_forward(
+            coord: torch.Tensor,
+            atype: torch.Tensor,
+            box: torch.Tensor | None,
+            fparam: torch.Tensor | None = None,
+            aparam: torch.Tensor | None = None,
+            grid: torch.Tensor | None = None,
+        ) -> dict[str, torch.Tensor]:
+            with torch.no_grad():
+                (
+                    extended_coord,
+                    extended_atype,
+                    mapping,
+                    nlist,
+                ) = extend_input_and_build_neighbor_list(
+                    coord,
+                    atype,
+                    self.get_rcut(),
+                    self.get_sel(),
+                    mixed_types=self.mixed_types(),
+                    box=box,
+                )
+                assert grid is not None
+                grid_type = torch.zeros(
+                    grid.shape[:-1], device=grid.device, dtype=atype.dtype
+                )
+                grid_nlist = build_directional_neighbor_list(
+                    grid,
+                    grid_type,
+                    extended_coord,
+                    extended_atype,
+                    self.get_rcut(),
+                    self.get_sel(),
+                    distinguish_types=(not self.mixed_types()),
+                )
+                atomic_ret = self.forward_common_atomic(
+                    extended_coord,
+                    extended_atype,
+                    nlist,
+                    mapping=mapping,
+                    fparam=fparam,
+                    aparam=aparam,
+                    grid=grid,
+                    grid_type=grid_type,
+                    grid_nlist=grid_nlist,
+                )
+            return {kk: vv.detach() for kk, vv in atomic_ret.items()}
+
+        return model_forward
+
+    def change_out_bias(
+        self,
+        sample_merged: Callable[[], list[dict]] | list[dict],
+        stat_file_path: DPPath | None = None,
+        bias_adjust_mode: str = "change-by-statistic",
+    ) -> None:
+        """Change the output bias according to the input data and the pretrained model.
+
+        For density models, this operation is skipped because the output is
+        grid-based rather than atomic-based, and the standard atomic bias
+        adjustment (change-by-statistic / set-by-statistic) does not apply.
+        The fitting net will adapt to the target dataset through normal
+        gradient descent during training.
+        """
+        log.warning("change_out_bias is not supported for density models; skipping.")
+
+    def compute_or_load_out_stat(
+        self,
+        merged: Callable[[], list[dict]] | list[dict],
+        stat_file_path: DPPath | None = None,
+    ) -> None:
+        """
+        Compute the output statistics (e.g. energy bias) for the fitting net from packed data.
+        Parameters
+        ----------
+        merged : Union[Callable[[], list[dict]], list[dict]]
+            - list[dict]: A list of data samples from various data systems.
+                Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
+                originating from the `i`-th data system.
+            - Callable[[], list[dict]]: A lazy function that returns data samples in the above format
+                only when needed. Since the sampling process can be slow and memory-intensive,
+                the lazy function helps by only sampling once.
+        stat_file_path : Optional[DPPath]
+            The path to the stat file.
+
+        """
+        log.warning("Not implemented yet for density out stat!")
+        pass
diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py
index 24075412db..9c0ca10c0f 100644
--- a/deepmd/pt/model/model/__init__.py
+++ b/deepmd/pt/model/model/__init__.py
@@ -33,6 +33,9 @@
     Spin,
 )
 
+from .density_model import (
+    GridDensityModel,
+)
 from .dipole_model import (
     DipoleModel,
 )
@@ -268,6 +271,8 @@ def get_standard_model(model_params: dict) -> BaseModel:
         modelcls = DOSModel
     elif fitting_net_type in ["ener", "direct_force_ener"]:
         modelcls = EnergyModel
+    elif fitting_net_type == "density":
+        modelcls = GridDensityModel
     elif fitting_net_type == "property":
         modelcls = PropertyModel
     else:
diff --git a/deepmd/pt/model/model/density_model.py b/deepmd/pt/model/model/density_model.py
new file mode 100644
index 0000000000..7c24f7bd05
--- /dev/null
+++ b/deepmd/pt/model/model/density_model.py
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from copy import (
+    deepcopy,
+)
+from typing import (
+    Any,
+)
+
+import torch
+
+from deepmd.pt.model.atomic_model import (
+    DPDensityAtomicModel,
+)
+from deepmd.pt.model.model.model import (
+    BaseModel,
+)
+
+from .dp_model import (
+    DPModelCommon,
+)
+from .make_density_model import (
+    make_density_model,
+)
+
+DPDensityModel_ = make_density_model(DPDensityAtomicModel)
+
+
+@BaseModel.register("grid_density")
+class GridDensityModel(DPModelCommon, DPDensityModel_):
+    """Model predicting a density on grid points, registered as "grid_density"."""
+
+    model_type = "grid_density"
+
+    def __init__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        DPModelCommon.__init__(self)
+        DPDensityModel_.__init__(self, *args, **kwargs)
+
+    def translated_output_def(self) -> dict[str, Any]:
+        """Return the output defs of "density" and, when present, of "mask"."""
+        out_def_data = self.model_output_def().get_data()
+        output_def = {
+            "density": deepcopy(out_def_data["density"]),
+        }
+        if "mask" in out_def_data:
+            output_def["mask"] = deepcopy(out_def_data["mask"])
+        return output_def
+
+    @torch.jit.export
+    def has_grid(self) -> bool:
+        """Returns whether it has grid input and output."""
+        return True
+
+    def forward(
+        self,
+        coord: torch.Tensor,
+        atype: torch.Tensor,
+        grid: torch.Tensor,
+        box: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        do_atomic_virial: bool = False,
+        charge_spin: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Return the density predicted on the grid points.
+
+        `charge_spin` is accepted but not used; all other arguments are passed to
+        `forward_common`.
+
+        Returns
+        -------
+        dict[str, torch.Tensor]
+            The "density" output and, when the model returns one, the "mask".
+        """
+        model_ret = self.forward_common(
+            coord,
+            atype,
+            grid,
+            box,
+            fparam=fparam,
+            aparam=aparam,
+            do_atomic_virial=do_atomic_virial,
+        )
+        model_predict = {}
+        model_predict["density"] = model_ret["density"]
+        if "mask" in model_ret:
+            model_predict["mask"] = model_ret["mask"]
+        return model_predict
+
+    @torch.jit.export
+    def forward_lower(
+        self,
+        extended_coord: torch.Tensor,
+        extended_atype: torch.Tensor,
+        nlist: torch.Tensor,
+        mapping: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        do_atomic_virial: bool = False,
+        comm_dict: dict[str, torch.Tensor] | None = None,
+        charge_spin: torch.Tensor | None = None,
+    ) -> None:
+        """Lower interface; not implemented for this model."""
+        raise NotImplementedError
diff --git a/deepmd/pt/model/model/make_density_model.py b/deepmd/pt/model/model/make_density_model.py
new file mode 100644
index 0000000000..d7f50ccfa3
--- /dev/null
+++ b/deepmd/pt/model/model/make_density_model.py
@@ -0,0 +1,657 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from collections.abc import (
+    Callable,
+)
+from typing import (
+    Any,
+
Optional, +) + +import torch + +from deepmd.dpmodel import ( + ModelOutputDef, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableCategory, + OutputVariableOperation, + check_operation_applied, +) +from deepmd.pt.model.atomic_model.base_atomic_model import ( + BaseAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, + fit_output_to_model_output, +) +from deepmd.pt.utils.env import ( + GLOBAL_PT_ENER_FLOAT_PRECISION, + GLOBAL_PT_FLOAT_PRECISION, + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.nlist import ( + build_directional_neighbor_list, + extend_input_and_build_neighbor_list, + nlist_distinguish_types, +) +from deepmd.utils.path import ( + DPPath, +) + + +def make_density_model(T_AtomicModel: type[BaseAtomicModel]) -> type[BaseModel]: + """Make a density model as a derived class of an atomic model. + + The model provide two interfaces. + + 1. the `forward_common_lower`, that takes extended coordinates, atyps and neighbor list, + and outputs the atomic and property and derivatives (if required) on the extended region. + + 2. the `forward_common`, that takes coordinates, atypes and cell and predicts + the atomic and reduced property, and derivatives (if required) on the local region. + + Parameters + ---------- + T_AtomicModel + The atomic model. + + Returns + ------- + CM + The model. 
+ + """ + + class CM(BaseModel): + def __init__( + self, + *args: Any, + # underscore to prevent conflict with normal inputs + atomic_model_: T_AtomicModel | None = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if atomic_model_ is not None: + self.atomic_model: T_AtomicModel = atomic_model_ + else: + self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs) + self.precision_dict = PRECISION_DICT + self.reverse_precision_dict = RESERVED_PRECISION_DICT + self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION + self.global_pt_ener_float_precision = GLOBAL_PT_ENER_FLOAT_PRECISION + + def model_output_def(self) -> ModelOutputDef: + """Get the output def for the model.""" + return ModelOutputDef(self.atomic_output_def()) + + @torch.jit.export + def model_output_type(self) -> list[str]: + """Get the output type for the model.""" + output_def = self.model_output_def() + var_defs = output_def.var_defs + # jit: Comprehension ifs are not supported yet + # type hint is critical for JIT + vars: list[str] = [] + for kk, vv in var_defs.items(): + # .value is critical for JIT + if vv.category == OutputVariableCategory.OUT.value: + vars.append(kk) + return vars + + @torch.jit.export + def has_chg_spin_ebd(self) -> bool: + """Check if the model has charge spin embedding.""" + return self.atomic_model.has_chg_spin_ebd() + + @torch.jit.export + def get_dim_chg_spin(self) -> int: + """Get the dimension of charge_spin input.""" + return self.atomic_model.get_dim_chg_spin() + + @torch.jit.export + def has_default_chg_spin(self) -> bool: + """Check if the model has default charge_spin values.""" + return self.atomic_model.has_default_chg_spin() + + @torch.jit.export + def get_default_chg_spin(self) -> torch.Tensor | None: + """Get the default charge_spin values.""" + return self.atomic_model.get_default_chg_spin() + + # cannot use the name forward. 
torch script does not work + def forward_common( + self, + coord: torch.Tensor, + atype: torch.Tensor, + grid: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + coord + The coordinates of the grids. + shape: nf x (ngrid x 3) + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,torch.Tensor]. + The keys are defined by the `ModelOutputDef`. + + """ + assert grid is not None + cc, gg, bb, fp, ap, input_prec = self.input_type_cast( + coord, grid, box=box, fparam=fparam, aparam=aparam + ) + del coord, grid, box, fparam, aparam + gg = gg.view(gg.shape[0], -1, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, + atype, + self.get_rcut(), + self.get_sel(), + mixed_types=self.mixed_types(), + box=bb, + ) + grid_type = torch.zeros( + gg.shape[0], gg.shape[1], device=gg.device, dtype=atype.dtype + ) + grid_nlist = build_directional_neighbor_list( + gg, + grid_type, + extended_coord, + extended_atype, + self.get_rcut(), + self.get_sel(), + distinguish_types=(not self.mixed_types()), + ) + model_predict_lower = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + grid=gg, + grid_type=grid_type, + grid_nlist=grid_nlist, + mapping=mapping, + do_atomic_virial=do_atomic_virial, + fparam=fp, + aparam=ap, + ) + model_predict = communicate_extended_output( + model_predict_lower, + self.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + 
model_predict = self.output_type_cast(model_predict, input_prec) + return model_predict + + def get_out_bias(self) -> torch.Tensor: + return self.atomic_model.get_out_bias() + + def set_out_bias(self, out_bias: torch.Tensor) -> None: + self.atomic_model.set_out_bias(out_bias) + + def change_out_bias( + self, + merged: Callable[[], list[dict]] | list[dict], + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change the output bias of atomic model according to the input data and the pretrained model. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + bias_adjust_mode : str + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on labels of target dataset, + and do least square on the errors to obtain the target shift as bias. + 'set-by-statistic' : directly use the statistic output bias in the target dataset. + """ + self.atomic_model.change_out_bias( + merged, + bias_adjust_mode=bias_adjust_mode, + ) + + def forward_common_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + grid: torch.Tensor, + grid_type: torch.Tensor, + grid_nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + extra_nlist_sort: bool = False, + ) -> dict[str, torch.Tensor]: + """Return model prediction. 
Lower interface that takes + extended atomic coordinates and types, nlist, and mapping + as input, and returns the predictions on the extended region. + The predictions are not reduced. + + Parameters + ---------- + extended_coord + coodinates in extended region. nf x (nall x 3) + extended_atype + atomic type in extended region. nf x nall + nlist + neighbor list. nf x nloc x nsel. + mapping + mapps the extended indices to local indices. nf x nall. + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + whether calculate atomic virial. + comm_dict + The data needed for communication for parallel inference. + extra_nlist_sort + whether to forcibly sort the nlist. + + Returns + ------- + result_dict + the result dict, defined by the `FittingOutputDef`. + + """ + nframes, nall = extended_atype.shape[:2] + extended_coord = extended_coord.view(nframes, -1, 3) + nlist = self.format_nlist( + extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort + ) + assert grid is not None + cc_ext, gg, _, fp, ap, input_prec = self.input_type_cast( + extended_coord, grid, fparam=fparam, aparam=aparam + ) + del extended_coord, grid, fparam, aparam + atomic_ret = self.atomic_model.forward_common_atomic( + cc_ext, + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + comm_dict=comm_dict, + grid=gg, + grid_type=grid_type, + grid_nlist=grid_nlist, + ) + model_predict = fit_output_to_model_output( + atomic_ret, + self.atomic_output_def(), + cc_ext, + do_atomic_virial=do_atomic_virial, + create_graph=self.training, + ) + model_predict = self.output_type_cast(model_predict, input_prec) + return model_predict + + def input_type_cast( + self, + coord: torch.Tensor, + grid: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + 
str, + ]: + """Cast the input data to global float type.""" + input_prec = self.reverse_precision_dict[coord.dtype] + ### + ### type checking would not pass jit, convert to coord prec anyway + ### + # for vv, kk in zip([fparam, aparam], ["frame", "atomic"]): + # if vv is not None and self.reverse_precision_dict[vv.dtype] != input_prec: + # log.warning( + # f"type of {kk} parameter {self.reverse_precision_dict[vv.dtype]}" + # " does not match" + # f" that of the coordinate {input_prec}" + # ) + _lst: list[torch.Tensor | None] = [ + vv.to(coord.dtype) if vv is not None else None + for vv in [grid, box, fparam, aparam] + ] + grid, box, fparam, aparam = _lst + assert grid is not None + if ( + input_prec + == self.reverse_precision_dict[self.global_pt_float_precision] + ): + return coord, grid, box, fparam, aparam, input_prec + else: + pp = self.global_pt_float_precision + return ( + coord.to(pp), + grid.to(pp), + box.to(pp) if box is not None else None, + fparam.to(pp) if fparam is not None else None, + aparam.to(pp) if aparam is not None else None, + input_prec, + ) + + def output_type_cast( + self, + model_ret: dict[str, torch.Tensor], + input_prec: str, + ) -> dict[str, torch.Tensor]: + """Convert the model output to the input prec.""" + do_cast = ( + input_prec + != self.reverse_precision_dict[self.global_pt_float_precision] + ) + pp = self.precision_dict[input_prec] + odef = self.model_output_def() + for kk in odef.keys(): + if kk not in model_ret.keys(): + # do not return energy_derv_c if not do_atomic_virial + continue + if check_operation_applied(odef[kk], OutputVariableOperation.REDU): + model_ret[kk] = ( + model_ret[kk].to(self.global_pt_ener_float_precision) + if model_ret[kk] is not None + else None + ) + elif do_cast: + model_ret[kk] = ( + model_ret[kk].to(pp) if model_ret[kk] is not None else None + ) + return model_ret + + def format_nlist( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extra_nlist_sort: 
bool = False, + ) -> torch.Tensor: + """Format the neighbor list. + + 1. If the number of neighbors in the `nlist` is equal to sum(self.sel), + it does nothong + + 2. If the number of neighbors in the `nlist` is smaller than sum(self.sel), + the `nlist` is pad with -1. + + 3. If the number of neighbors in the `nlist` is larger than sum(self.sel), + the nearest sum(sel) neighbors will be preseved. + + Known limitations: + + In the case of not self.mixed_types, the nlist is always formatted. + May have side effact on the efficiency. + + Parameters + ---------- + extended_coord + coodinates in extended region. nf x nall x 3 + extended_atype + atomic type in extended region. nf x nall + nlist + neighbor list. nf x nloc x nsel + extra_nlist_sort + whether to forcibly sort the nlist. + + Returns + ------- + formated_nlist + the formated nlist. + + """ + mixed_types = self.mixed_types() + nlist = self._format_nlist( + extended_coord, + nlist, + sum(self.get_sel()), + extra_nlist_sort=extra_nlist_sort, + ) + if not mixed_types: + nlist = nlist_distinguish_types(nlist, extended_atype, self.get_sel()) + return nlist + + def _format_nlist( + self, + extended_coord: torch.Tensor, + nlist: torch.Tensor, + nnei: int, + extra_nlist_sort: bool = False, + ) -> torch.Tensor: + n_nf, n_nloc, n_nnei = nlist.shape + # nf x nall x 3 + extended_coord = extended_coord.view([n_nf, -1, 3]) + rcut = self.get_rcut() + + if n_nnei < nnei: + nlist = torch.cat( + [ + nlist, + -1 + * torch.ones( + [n_nf, n_nloc, nnei - n_nnei], + dtype=nlist.dtype, + device=nlist.device, + ), + ], + dim=-1, + ) + + if n_nnei > nnei or extra_nlist_sort: + n_nf, n_nloc, n_nnei = nlist.shape + m_real_nei = nlist >= 0 + nlist = torch.where(m_real_nei, nlist, 0) + # nf x nloc x 3 + coord0 = extended_coord[:, :n_nloc, :] + # nf x (nloc x nnei) x 3 + index = nlist.view(n_nf, n_nloc * n_nnei, 1).expand(-1, -1, 3) + coord1 = torch.gather(extended_coord, 1, index) + # nf x nloc x nnei x 3 + coord1 = coord1.view(n_nf, 
n_nloc, n_nnei, 3) + # nf x nloc x nnei + rr = torch.linalg.norm(coord0[:, :, None, :] - coord1, dim=-1) + rr = torch.where(m_real_nei, rr, float("inf")) + rr, nlist_mapping = torch.sort(rr, dim=-1) + nlist = torch.gather(nlist, 2, nlist_mapping) + nlist = torch.where(rr > rcut, -1, nlist) + nlist = nlist[..., :nnei] + else: # not extra_nlist_sort and n_nnei <= nnei: + pass # great! + assert nlist.shape[-1] == nnei + return nlist + + def do_grad_r( + self, + var_name: str | None = None, + ) -> bool: + """Tell if the output variable `var_name` is r_differentiable. + if var_name is None, returns if any of the variable is r_differentiable. + """ + return self.atomic_model.do_grad_r(var_name) + + def do_grad_c( + self, + var_name: str | None = None, + ) -> bool: + """Tell if the output variable `var_name` is c_differentiable. + if var_name is None, returns if any of the variable is c_differentiable. + """ + return self.atomic_model.do_grad_c(var_name) + + def change_type_map( + self, + type_map: list[str], + model_with_new_type_stat: Optional["CM"] = None, + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. 
+ """ + self.atomic_model.change_type_map( + type_map=type_map, + model_with_new_type_stat=model_with_new_type_stat.atomic_model + if model_with_new_type_stat is not None + else None, + ) + + def serialize(self) -> dict: + return self.atomic_model.serialize() + + @classmethod + def deserialize(cls, data: dict) -> "CM": + return cls(atomic_model_=T_AtomicModel.deserialize(data)) + + @torch.jit.export + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.atomic_model.get_dim_fparam() + + @torch.jit.export + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.atomic_model.get_dim_aparam() + + @torch.jit.export + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.atomic_model.get_sel_type() + + @torch.jit.export + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). 
+ """ + return self.atomic_model.is_aparam_nall() + + @torch.jit.export + def get_rcut(self) -> float: + """Get the cut-off radius.""" + return self.atomic_model.get_rcut() + + @torch.jit.export + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.atomic_model.get_type_map() + + @torch.jit.export + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.atomic_model.get_nsel() + + @torch.jit.export + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.atomic_model.get_nnei() + + def atomic_output_def(self) -> FittingOutputDef: + """Get the output def of the atomic model.""" + return self.atomic_model.atomic_output_def() + + def compute_or_load_stat( + self, + sampled_func: Callable[[], list[dict]] | list[dict], + stat_file_path: DPPath | None = None, + preset_observed_type: list[str] | None = None, + ) -> None: + """Compute or load the statistics.""" + return self.atomic_model.compute_or_load_stat( + sampled_func, + stat_file_path, + preset_observed_type=preset_observed_type, + ) + + def get_sel(self) -> list[int]: + """Returns the number of selected atoms for each type.""" + return self.atomic_model.get_sel() + + def mixed_types(self) -> bool: + """If true, the model + 1. assumes total number of atoms aligned across frames; + 2. uses a neighbor list that does not distinguish different atomic types. + + If false, the model + 1. assumes total number of atoms of each atom type aligned across frames; + 2. uses a neighbor list that distinguishes different atomic types. 
+ + """ + return self.atomic_model.mixed_types() + + @torch.jit.export + def has_message_passing(self) -> bool: + """Returns whether the model has message passing.""" + return self.atomic_model.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the model needs sorted nlist when using `forward_lower`.""" + return self.atomic_model.need_sorted_nlist_for_lower() + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + # directly call the forward_common method when no specific transform rule + return self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return CM diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 37ffec2725..79e1db3bce 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -5,6 +5,9 @@ from .denoise import ( DenoiseNet, ) +from .density import ( + DensityFittingNet, +) from .dipole import ( DipoleFittingNet, ) @@ -32,6 +35,7 @@ "BaseFitting", "DOSFittingNet", "DenoiseNet", + "DensityFittingNet", "DipoleFittingNet", "EnergyFittingNet", "EnergyFittingNetDirect", diff --git a/deepmd/pt/model/task/density.py b/deepmd/pt/model/task/density.py new file mode 100644 index 0000000000..49853ff05c --- /dev/null +++ b/deepmd/pt/model/task/density.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import logging +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.task.fitting import ( + Fitting, + GeneralFitting, +) +from deepmd.pt.model.task.invar_fitting import ( + InvarFitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + 
@Fitting.register("density")
class DensityFittingNet(InvarFitting):
    """Fitting net registered under the ``"density"`` fitting type.

    It is an ``InvarFitting`` whose output variable is named ``"density"`` and
    has a single output channel.  The output definition marks the variable as
    non-reducible and not differentiable with respect to coordinates or cell.
    """

    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: list[int] = [128, 128, 128],
        bias_atom_e: torch.Tensor | None = None,
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = True,
        seed: int | list[int] | None = None,
        type_map: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Forward all settings to ``InvarFitting`` with ``var_name="density"``
        and an output dimension of 1.

        Extra keyword arguments are passed through to the parent unchanged.
        """
        super().__init__(
            "density",
            ntypes,
            dim_descrpt,
            1,
            # Copy the list: the default above is a single object shared by
            # every call, so it must never be aliased into instance state.
            neuron=list(neuron),
            bias_atom_e=bias_atom_e,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            seed=seed,
            type_map=type_map,
            **kwargs,
        )

    def output_def(self) -> FittingOutputDef:
        """Return the output definition: one variable named ``self.var_name``
        of shape ``[self.dim_out]``, neither reducible nor differentiable.
        """
        return FittingOutputDef(
            [
                OutputVariableDef(
                    self.var_name,
                    [self.dim_out],
                    reducible=False,
                    r_differentiable=False,
                    c_differentiable=False,
                ),
            ]
        )

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        """Rebuild the fitting net from a serialized dict.

        Works on a deep copy of ``data``, so the caller's dict is not modified.
        The ``@version`` entry (default 1) is checked for compatibility, and
        ``var_name`` / ``dim_out`` are dropped because ``__init__`` fixes them.
        """
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 2, 1)
        data.pop("var_name")
        data.pop("dim_out")
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict, with ``"type"`` set to ``"density"``."""
        return {
            **super().serialize(),
            "type": "density",
        }

    # make jit happy with torch 2.0.0
    exclude_types: list[int]
def get_data( "coord", "atype", "spin", + "grid", "box", "fparam", "aparam", @@ -2216,6 +2218,9 @@ def get_loss( tensor_name = "polar" loss_params["tensor_name"] = tensor_name return TensorLoss(**loss_params) + elif loss_type == "grid_density": + loss_params["starter_learning_rate"] = start_lr + return GridDensityLoss(**loss_params) elif loss_type == "property": task_dim = _model.get_task_dim() var_name = _model.get_var_name() diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py index 1d741dd534..b098e98dd3 100644 --- a/deepmd/pt/train/wrapper.py +++ b/deepmd/pt/train/wrapper.py @@ -157,6 +157,7 @@ def forward( coord: torch.Tensor, atype: torch.Tensor, spin: torch.Tensor | None = None, + grid: torch.Tensor | None = None, box: torch.Tensor | None = None, cur_lr: torch.Tensor | None = None, label: torch.Tensor | None = None, @@ -188,6 +189,12 @@ def forward( if has_spin: input_dict["spin"] = spin + has_grid = getattr(self.model[task_key], "has_grid", False) + if callable(has_grid): + has_grid = has_grid() + if has_grid: + input_dict["grid"] = grid + if self.inference_only or inference_only: model_pred = self.model[task_key](**input_dict) if self.modifier is not None: diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 807c1f5ba1..176e60bf09 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -222,6 +222,7 @@ def print_summary( name: str, prob: list[float], ) -> None: + return rank = dist.get_rank() if dist.is_initialized() else 0 if rank == 0: print_summary( diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index f8c3685b78..5bf1bcf702 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -200,8 +200,17 @@ def model_forward_auto_batch_size(*args: Any, **kwargs: Any) -> Any: **kwargs, ) + grid_kwargs = {} + if "grid" in system: + grid_kwargs["grid"] = system["grid"] sample_predict = model_forward_auto_batch_size( - coord, atype, box, fparam=fparam, aparam=aparam, 
charge_spin=charge_spin + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + **grid_kwargs, ) for kk in keys: model_predict[kk].append( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 060bb90524..0338da5d4e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1883,6 +1883,54 @@ def fitting_ener() -> list[Argument]: ] +@fitting_args_plugin.register("density") +def fitting_density() -> list[Argument]: + doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." + doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built." + doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' + doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' + doc_trainable = f"Whether the parameters in the fitting net are trainable. This option can be\n\n\ +- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\ +- list of bool{doc_only_tf_supported}: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1." + doc_rcond = "The condition number used to determine the inital density shift for each type of atoms. 
See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." + doc_seed = "Random seed for parameter initialization of the fitting net" + + return [ + Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), + Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), + Argument( + "neuron", + list[int], + optional=True, + default=[120, 120, 120], + alias=["n_neuron"], + doc=doc_neuron, + ), + Argument( + "activation_function", + str, + optional=True, + default="tanh", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_trainable, + ), + Argument( + "rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond + ), + Argument("seed", [int, None], optional=True, doc=doc_seed), + ] + + @fitting_args_plugin.register("dos", doc=doc_dos) def fitting_dos() -> list[Argument]: doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." 
@loss_args_plugin.register("grid_density")
def loss_grid_density() -> list[Argument]:
    """Declare the input arguments of the ``grid_density`` loss.

    Two optional numeric prefactors are accepted, ``start_pref_d`` and
    ``limit_pref_d``, each defaulting to ``1.00``.  Their help texts come from
    the ``start_pref`` / ``limit_pref`` helpers.
    """
    # (argument name, help text); start_pref is evaluated before limit_pref
    pref_docs = (
        ("start_pref_d", start_pref("density", abbr="d")),
        ("limit_pref_d", limit_pref("density")),
    )
    return [
        Argument(
            arg_name,
            [float, int],
            optional=True,
            default=1.00,
            doc=arg_doc,
        )
        for arg_name, arg_doc in pref_docs
    ]
+├── dpa2/ # DPA-2 descriptor example +│ └── input.json +├── dpa3/ # DPA-3 descriptor example +│ └── input.json +├── dataset/ +│ └── qm9/ +│ ├── C7H15NO_train/ # Training data (deepmd/npy format) +│ └── C7H15NO_val/ # Validation data (deepmd/npy format) +└── dptest_density_script.py # Evaluation script for density models +``` + +______________________________________________________________________ + +## Data Format + +The training/validation data follows the standard **`deepmd/npy`** format: `type.raw` and `type_map.raw` sit in the system root, and the `.npy` files sit in each `set.000/` directory, which holds two additional files (`grid.npy` and `density.npy`): + +| File | Shape | Description | +| ----------------- | ----------------------- | ---------------------------------- | +| `coord.npy` | `[nframes, natoms * 3]` | Atomic coordinates | +| `box.npy` | `[nframes, 9]` | Simulation cell vectors | +| `type.raw` | `[natoms]` | Atom type indices | +| `type_map.raw` | `[ntypes]` | Type map (e.g., `C H N O ...`) | +| **`grid.npy`** | `[nframes, ngrid, 3]` | **Grid point coordinates** | +| **`density.npy`** | `[nframes, ngrid, 1]` | **Charge density labels on grids** | + +> **Note:** `grid.npy` and `density.npy` are required for the density model. The number of grid points (`ngrid`) must match between `grid.npy` and `density.npy`. + +______________________________________________________________________ + +## Training + +### 1. Choose a Configuration + +Two example configurations are provided: + +- **`dpa2/input.json`** — Uses the DPA-2 descriptor. +- **`dpa3/input.json`** — Uses the DPA-3 descriptor (recommended). 
+ +Key parameters in `input.json`: + +```json +{ + "model": { + "type_map": [ + "Li", + "Ni", + "Co", + "Mn", + "O", + "C", + "H", + "N", + "F", + "X" + ], + "descriptor": { + "type": "dpa3" + }, + "fitting_net": { + "type": "density", + "neuron": [ + 240, + 240, + 240 + ] + } + }, + "loss": { + "type": "grid_density", + "start_pref_d": 1, + "limit_pref_d": 1 + }, + "training": { + "training_data": { + "systems": [ + "../dataset/qm9/C7H15NO_train" + ], + "batch_size": "auto:128" + }, + "validation_data": { + "systems": [ + "../dataset/qm9/C7H15NO_val" + ], + "batch_size": 1, + "numb_btch": 3 + } + } +} +``` + +### 2. Run Training + +```bash +cd dpa3 # or cd dpa2 +dp --pt train input.json +``` + +The training will output: + +- `model.ckpt-*.pt` — Model checkpoints +- `lcurve.out` — Training/validation loss curves +- `out.json` — Final training parameters + +### 3. Finetune from a Pretrained Model + +To finetune an existing density checkpoint: + +```bash +cd dpa3 +dp --pt train input.json --finetune model.ckpt-*.pt +``` + +> **Note:** For density models, `change_out_bias` (the energy-bias adjustment used in standard finetuning) is **automatically skipped** because density outputs are grid-based, not atomic-based. The descriptor weights are inherited, and the fitting net adapts via normal gradient descent. + +### 4. Freeze the Model + +To export a trained checkpoint into a frozen model for inference: + +```bash +cd dpa3 +dp --pt freeze -c . -o frozen_model +``` + +This generates `frozen_model.pth` (PyTorch backend). + +______________________________________________________________________ + +## Testing / Evaluation + +Use the provided `dptest_density_script.py` to evaluate a trained model on validation or test data. 
+ +### Basic Usage + +```bash +cd examples/density + +python dptest_density_script.py \ + dpa3/model.ckpt-100.pt \ + dataset/qm9/C7H15NO_val \ + --ratio 0.1 \ + --output val_result.txt +``` + +Arguments: + +| Argument | Description | +| ------------- | ------------------------------------------------------------------------- | +| `model` | Path to the model file (`.pt` checkpoint or `.pth` frozen model) | +| `data_dir` | Root directory of deepmd/npy datasets | +| `--ratio` | Fraction of frames to randomly sample (default: `0.1`) | +| `--output` | If provided, save screen output to this file | +| `--pred-file` | File to save paired `[prediction, label]` array (default: `result.d.out`) | + +### Evaluate the Full QM9 Dataset + +```bash +python dptest_density_script.py \ + dpa3/model.ckpt-100.pt \ + dataset/qm9 \ + --ratio 0.1 +``` + +The script recursively searches all subdirectories containing `type.raw`. + +## Notes + +- **Backend:** This example uses the PyTorch backend (`--pt`). Make sure you have installed DeePMD-kit with PyTorch support. +- **Stat File:** `dpa3/input.json` specifies `"stat_file": "./qm9_charge_density.hdf5"` (`dpa2/input.json` uses `"./dpa2.hdf5"`) for caching descriptor statistics. The file is generated automatically on the first run. +- **Checkpoint vs. Frozen Model:** `dptest_density_script.py` uses `DeepPot()` to load the model. If loading a training checkpoint (`.pt`) fails, freeze it first with `dp --pt freeze` and use the resulting `.pth` file. 
diff --git a/examples/density/dataset/qm9/C7H15NO_train/set.000/box.npy b/examples/density/dataset/qm9/C7H15NO_train/set.000/box.npy new file mode 100644 index 0000000000..8d064b2261 Binary files /dev/null and b/examples/density/dataset/qm9/C7H15NO_train/set.000/box.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_train/set.000/coord.npy b/examples/density/dataset/qm9/C7H15NO_train/set.000/coord.npy new file mode 100644 index 0000000000..4e5b579726 Binary files /dev/null and b/examples/density/dataset/qm9/C7H15NO_train/set.000/coord.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_train/set.000/density.npy b/examples/density/dataset/qm9/C7H15NO_train/set.000/density.npy new file mode 100644 index 0000000000..520f77fc60 Binary files /dev/null and b/examples/density/dataset/qm9/C7H15NO_train/set.000/density.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_train/set.000/grid.npy b/examples/density/dataset/qm9/C7H15NO_train/set.000/grid.npy new file mode 100644 index 0000000000..d16f0901a0 Binary files /dev/null and b/examples/density/dataset/qm9/C7H15NO_train/set.000/grid.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_train/type.raw b/examples/density/dataset/qm9/C7H15NO_train/type.raw new file mode 100644 index 0000000000..c46e48ed0a --- /dev/null +++ b/examples/density/dataset/qm9/C7H15NO_train/type.raw @@ -0,0 +1,24 @@ +5 +5 +5 +5 +5 +5 +5 +4 +7 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 diff --git a/examples/density/dataset/qm9/C7H15NO_train/type_map.raw b/examples/density/dataset/qm9/C7H15NO_train/type_map.raw new file mode 100644 index 0000000000..36c071bf5b --- /dev/null +++ b/examples/density/dataset/qm9/C7H15NO_train/type_map.raw @@ -0,0 +1,9 @@ +Li +Ni +Co +Mn +O +C +H +N +F diff --git a/examples/density/dataset/qm9/C7H15NO_val/set.000/box.npy b/examples/density/dataset/qm9/C7H15NO_val/set.000/box.npy new file mode 100644 index 0000000000..2299ac43d9 Binary files /dev/null and 
b/examples/density/dataset/qm9/C7H15NO_val/set.000/box.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_val/set.000/coord.npy b/examples/density/dataset/qm9/C7H15NO_val/set.000/coord.npy new file mode 100644 index 0000000000..832913a48d Binary files /dev/null and b/examples/density/dataset/qm9/C7H15NO_val/set.000/coord.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_val/set.000/density.npy b/examples/density/dataset/qm9/C7H15NO_val/set.000/density.npy new file mode 100644 index 0000000000..75192b3ec1 Binary files /dev/null and b/examples/density/dataset/qm9/C7H15NO_val/set.000/density.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_val/set.000/grid.npy b/examples/density/dataset/qm9/C7H15NO_val/set.000/grid.npy new file mode 100644 index 0000000000..961264ac8b Binary files /dev/null and b/examples/density/dataset/qm9/C7H15NO_val/set.000/grid.npy differ diff --git a/examples/density/dataset/qm9/C7H15NO_val/type.raw b/examples/density/dataset/qm9/C7H15NO_val/type.raw new file mode 100644 index 0000000000..c46e48ed0a --- /dev/null +++ b/examples/density/dataset/qm9/C7H15NO_val/type.raw @@ -0,0 +1,24 @@ +5 +5 +5 +5 +5 +5 +5 +4 +7 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 +6 diff --git a/examples/density/dataset/qm9/C7H15NO_val/type_map.raw b/examples/density/dataset/qm9/C7H15NO_val/type_map.raw new file mode 100644 index 0000000000..36c071bf5b --- /dev/null +++ b/examples/density/dataset/qm9/C7H15NO_val/type_map.raw @@ -0,0 +1,9 @@ +Li +Ni +Co +Mn +O +C +H +N +F diff --git a/examples/density/dpa2/input.json b/examples/density/dpa2/input.json new file mode 100644 index 0000000000..2454c9b26c --- /dev/null +++ b/examples/density/dpa2/input.json @@ -0,0 +1,117 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "Li", + "Ni", + "Co", + "Mn", + "O", + "C", + "H", + "N", + "F", + "X" + ], + "descriptor": { + "type": "dpa2", + "repinit": { + "tebd_dim": 8, + "rcut": 6.0, + "rcut_smth": 0.5, + "nsel": 120, + "neuron": [ + 25, 
+ 50, + 100 + ], + "axis_neuron": 12, + "activation_function": "tanh", + "three_body_sel": 48, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + "use_three_body": true + }, + "repformer": { + "rcut": 4.0, + "rcut_smth": 3.5, + "nsel": 48, + "nlayers": 12, + "g1_dim": 128, + "g2_dim": 32, + "attn2_hidden": 32, + "attn2_nhead": 4, + "attn1_hidden": 128, + "attn1_nhead": 4, + "axis_neuron": 4, + "update_h2": false, + "update_g1_has_conv": true, + "update_g1_has_grrg": true, + "update_g1_has_drrd": true, + "update_g1_has_attn": false, + "update_g2_has_g1g1": false, + "update_g2_has_attn": true, + "update_style": "res_residual", + "update_residual": 0.01, + "update_residual_init": "norm", + "attn2_has_gate": true, + "use_sqrt_nnei": true, + "g1_out_conv": true, + "g1_out_mlp": true + }, + "precision": "float32", + "add_tebd_to_repinit_out": false + }, + "fitting_net": { + "type": "density", + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float32", + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-08, + "_comment": "that's all" + }, + "loss": { + "type": "grid_density", + "start_pref_d": 1, + "limit_pref_d": 1, + "_comment": " that's all" + }, + "training": { + "stat_file": "./dpa2.hdf5", + "training_data": { + "systems": [ + "../dataset/qm9/C7H15NO_train" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../dataset/qm9/C7H15NO_val" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 200, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/examples/density/dpa3/input.json b/examples/density/dpa3/input.json new file mode 100644 index 0000000000..0ee1270fae --- /dev/null +++ b/examples/density/dpa3/input.json @@ -0,0 
+1,107 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "Li", + "Ni", + "Co", + "Mn", + "O", + "C", + "H", + "N", + "F", + "X" + ], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 12, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 120, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 30, + "axis_neuron": 4, + "skip_stat": true, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "use_dynamic_sel": true, + "sel_reduce_factor": 10.0, + "edge_use_dist": true, + "use_env_envelope": true, + "optim_update": true, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "custom_silu:10.0", + "use_tebd_bias": false, + "env_protection": 0.1, + "precision": "float32", + "concat_output_tebd": false + }, + "fitting_net": { + "type": "density", + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "seed": 1, + "precision": "float32", + "activation_function": "silu", + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 1e-05, + "_comment": "that's all" + }, + "loss": { + "type": "grid_density", + "start_pref_d": 1, + "limit_pref_d": 1, + "_comment": " that's all" + }, + "training": { + "stat_file": "./qm9_charge_density.hdf5", + "training_data": { + "systems": [ + "../dataset/qm9/C7H15NO_train" + ], + "batch_size": "auto:128", + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../dataset/qm9/C7H15NO_val" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment": "that's all" + }, + "numb_steps": 100, + "warmup_steps": 0, + "gradient_max_norm": 5.0, + "seed": 10, + "max_ckpt_keep": 100, + "opt_type": "AdamW", + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 10000, + "_comment": "that's all" + } +} diff 
--git a/examples/density/dptest_density_script.py b/examples/density/dptest_density_script.py new file mode 100644 index 0000000000..7de198466a --- /dev/null +++ b/examples/density/dptest_density_script.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Evaluate a charge-density model on validation/test datasets. + +Usage +----- +python dptest_density_script.py model.pt /path/to/data --ratio 0.1 --output result.txt +""" + +import argparse +import glob +import math +import os +import random +import sys + +import dpdata +import numpy as np +from dpdata.data_type import ( + Axis, + DataType, +) +from tqdm import ( + tqdm, +) + +from deepmd.infer import ( + DeepPot, +) + +# Register custom dpdata types for grid density (the shapes below hard-code 125 grid points per frame) +_GRID_DATA_TYPE = DataType( + "grid", + np.ndarray, + shape=(Axis.NFRAMES, 125, 3), + required=False, +) +_DENSITY_DATA_TYPE = DataType( + "density", + np.ndarray, + shape=(Axis.NFRAMES, 125, 1), + required=False, +) +dpdata.System.register_data_type(_GRID_DATA_TYPE) +dpdata.System.register_data_type(_DENSITY_DATA_TYPE) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Evaluate charge-density DeepPot model." + ) + parser.add_argument("model", type=str, help="Path to the model file (.pt or .pth).") + parser.add_argument( + "data_dir", type=str, help="Root directory of deepmd/npy datasets." 
class TeeLogger:
    """Redirect stdout to both the terminal and a file.

    The stream that is ``sys.stdout`` at construction time is kept as
    ``terminal``; everything written to this object goes to ``terminal`` and,
    while the log file is open, to the file as well.

    The object can be assigned to ``sys.stdout`` by hand, or used as a context
    manager, which installs it as ``sys.stdout`` on entry and restores the
    previous stream and closes the file on exit, even if an exception occurs.
    """

    def __init__(self, filepath: str) -> None:
        """Open ``filepath`` for writing (truncating it) and remember stdout."""
        self.terminal = sys.stdout
        # Explicit encoding: the platform default may be unable to encode
        # non-ASCII text such as system paths.
        self.log = open(filepath, "w", encoding="utf-8")

    def write(self, message: str) -> None:
        """Write ``message`` to the terminal and, if still open, the log file."""
        self.terminal.write(message)
        # After close() keep the terminal output working instead of raising
        # "I/O operation on closed file".
        if not self.log.closed:
            self.log.write(message)

    def flush(self) -> None:
        """Flush the terminal and, if still open, the log file."""
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self) -> None:
        """Close the log file; calling it more than once is harmless."""
        self.log.close()

    def __enter__(self) -> "TeeLogger":
        """Install this object as ``sys.stdout``."""
        sys.stdout = self
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Restore the original stdout and close the log file."""
        sys.stdout = self.terminal
        self.close()

    def __getattr__(self, name: str):
        """Delegate other stream attributes (``isatty``, ``encoding``, ...) to
        the terminal so code that inspects ``sys.stdout`` keeps working.
        """
        # Only reached when normal lookup fails; refuse our own attributes to
        # avoid infinite recursion on a partially initialised instance.
        if name in ("terminal", "log"):
            raise AttributeError(name)
        return getattr(self.terminal, name)
def print_summary(pred: np.ndarray, label: np.ndarray) -> None:
    """Write error statistics of ``pred`` against ``label`` to ``sys.stdout``.

    Reported values: number of label entries, standard deviation of the
    labels, RMSE, MAE, mean absolute label, and MAE divided by the mean
    absolute label.

    Parameters
    ----------
    pred
        Predicted values.
    label
        Reference values; must have the same shape as ``pred``.

    Raises
    ------
    ValueError
        If the shapes differ.  Without this check NumPy could broadcast the
        two arrays against each other and report meaningless statistics.
    """
    pred = np.asarray(pred)
    label = np.asarray(label)
    if pred.shape != label.shape:
        raise ValueError(
            f"prediction shape {pred.shape} does not match label shape {label.shape}"
        )

    nan = float("nan")
    if label.size:
        diff = pred - label
        mae = float(np.mean(np.abs(diff)))
        rmse = float(np.sqrt(np.mean(diff**2)))
        label_mean_abs = float(np.mean(np.abs(label)))
        label_std = float(np.std(label))
    else:
        # Nothing to average: report nan instead of triggering NumPy's
        # "mean of empty slice" warnings.
        mae = rmse = label_mean_abs = label_std = nan

    # Same result as the floating-point division (inf or nan), but without a
    # divide-by-zero RuntimeWarning when every label is zero.
    if label_mean_abs == 0.0:
        epsilon_mae = float("inf") if mae > 0.0 else nan
    else:
        epsilon_mae = mae / label_mean_abs

    rule = "=" * 60
    lines = [
        "",
        rule,
        "Summary",
        rule,
        f" Number of grid points : {label.size}",
        f" Label std : {label_std:.6e}",
        f" RMSE : {rmse:.6e}",
        f" MAE : {mae:.6e}",
        f" Mean |label| : {label_mean_abs:.6e}",
        f" epsilon_MAE (MAE/Mean|label|) : {epsilon_mae:.6e}",
        rule,
    ]
    # single write so the summary cannot be interleaved with other output
    sys.stdout.write("\n".join(lines) + "\n")