From c202284f63c83be7df801b4a1e01034eeae26065 Mon Sep 17 00:00:00 2001 From: Roman Zubatyuk Date: Fri, 3 May 2024 22:54:27 -0400 Subject: [PATCH 01/27] First commit --- aimnet2calc/__init__.py | 12 ++ aimnet2calc/aimnet2ase.py | 73 ++++++++ aimnet2calc/calculator.py | 229 +++++++++++++++++++++++ aimnet2calc/ensemble.py | 93 ++++++++++ aimnet2calc/models.py | 37 ++++ aimnet2calc/nblist.py | 151 +++++++++++++++ test/1008775.cif | 108 +++++++++++ test/mols_size_36.xyz | 380 ++++++++++++++++++++++++++++++++++++++ test/mols_size_var.xyz | 299 ++++++++++++++++++++++++++++++ test/test_calculator.py | 104 +++++++++++ 10 files changed, 1486 insertions(+) create mode 100644 aimnet2calc/__init__.py create mode 100644 aimnet2calc/aimnet2ase.py create mode 100644 aimnet2calc/calculator.py create mode 100644 aimnet2calc/ensemble.py create mode 100644 aimnet2calc/models.py create mode 100644 aimnet2calc/nblist.py create mode 100644 test/1008775.cif create mode 100644 test/mols_size_36.xyz create mode 100644 test/mols_size_var.xyz create mode 100644 test/test_calculator.py diff --git a/aimnet2calc/__init__.py b/aimnet2calc/__init__.py new file mode 100644 index 0000000..99f862c --- /dev/null +++ b/aimnet2calc/__init__.py @@ -0,0 +1,12 @@ +from .calculator import AIMNet2Calculator +__all__ = ['AIMNet2Calculator'] + +try: + from .aimnet2ase import AIMNet2ASE + __all__.append('AIMNet2ASE') +except ImportError: + import warnings + warnings.warn('ASE is not installed. AIMNet2ASE will not be available.') + pass + + diff --git a/aimnet2calc/aimnet2ase.py b/aimnet2calc/aimnet2ase.py new file mode 100644 index 0000000..dd6031e --- /dev/null +++ b/aimnet2calc/aimnet2ase.py @@ -0,0 +1,73 @@ +from ase.calculators.calculator import Calculator, all_changes +from aimnet2calc import AIMNet2Calculator +from typing import Union +import torch + + +class AIMNet2ASE(Calculator): + implemented_properties = ['energy', 'forces', 'free_energy', 'charges', 'stress'] + def __init__(self, base_calc: Union[AIMNet2Calculator, str], charge=0, mult=1): + super().__init__() + if isinstance(base_calc, str): + base_calc = AIMNet2Calculator(base_calc) + self.base_calc = base_calc + self.charge = charge + self.mult = mult + self.do_reset() + + def do_reset(self): + self._t_numbers = None + self._t_charge = None + self._t_mult = None + self._t_mol_idx = None + self.charge = 0.0 + self.mult = 1.0 + + def set_atoms(self, atoms): + self.atoms = atoms + self.do_reset() + + def set_charge(self, charge): + self.charge = charge + + def set_mult(self, mult): + self.mult = mult + + def uptade_tensors(self): + if self._t_numbers is None: + self._t_numbers = torch.tensor(self.atoms.numbers, dtype=torch.int64, device=self.base_calc.device) + if self._t_charge is None: + self._t_charge = torch.tensor(self.charge, dtype=torch.float32, device=self.base_calc.device) + if self._t_mult is None: + self._t_mult = torch.tensor(self.mult, dtype=torch.float32, device=self.base_calc.device) + if self._t_mol_idx is None: + self.mol_idx = torch.zeros(len(self.atoms), dtype=torch.int64, device=self.base_calc.device) + + def calculate(self, atoms=None, properties=['energy'], system_changes=all_changes): + super().calculate(atoms, properties, system_changes) + self.uptade_tensors() + + if self.atoms.cell is not None and self.atoms.pbc.any(): + assert self.base_calc.cutoff_lr < float('inf'), 'Long-range cutoff must be finite for PBC' + cell = self.atoms.cell.array + else: + cell = None + + results = self.base_calc({ + 'coord': torch.tensor(self.atoms.positions, dtype=torch.float32, device=self.base_calc.device), + 'numbers': self._t_numbers, + 'cell': cell, + 'mol_idx': self._t_mol_idx, + 'charge': self._t_charge, + 'mult': self._t_mult, + }, forces='forces' in properties, stress='stress' in properties) + for k, v in results.items(): + results[k] = v.detach().cpu().numpy() + + self.results['energy'] = results['energy'] + self.results['charges'] = results['charges'] + if 'forces' in properties: + self.results['forces'] = results['forces'] + if 'stress' in properties: + self.results['stress'] = results['stress'] + diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py new file mode 100644 index 0000000..1957ebb --- /dev/null +++ b/aimnet2calc/calculator.py @@ -0,0 +1,229 @@ +import torch +from torch import nn, Tensor +from typing import Union, Dict, Any +from aimnet2calc.nblist import nblist_torch_cluster, nblists_torch_pbc +from aimnet2calc.models import get_model_path + + +class AIMNet2Calculator: + """ Genegic AIMNet2 calculator + A helper class to load AIMNet2 models and perform inference. + """ + + keys_in = ['coord', 'numbers', 'charge'] + keys_in_optional = ['mult', 'mol_idx', 'nbmat', 'nbmat_lr', 'nb_pad_mask', 'nb_pad_mask_lr', 'shifts', 'cell'] + keys_out = ['energy', 'charges', 'forces'] + atom_feature_keys = ['coord', 'numbers', 'charges', 'forces'] + + def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if isinstance(model, str): + p = get_model_path(model) + self.model = torch.jit.load(p, map_location=self.device) + elif isinstance(model, nn.Module): + self.model = model.to(self.device) + else: + raise AttributeError('Invalid model type/name.') + + self.cutoff = self.model.cutoff + self.lr = hasattr(self.model, 'cutoff_lr') + self.cutoff_lr = getattr(self.model, 'cutoff_lr', float('inf')) + + # indicator if input was flattened + self._batch = None + self._saved_for_grad = None + + def __call__(self, *args, **kwargs): + return self.eval(*args, **kwargs) + + def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]: + data = self.prepare_input(data) + if hessian and data['mol_idx'][-1] > 0: + raise NotImplementedError('Hessian calculation is not supported for multiple molecules') + data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian) + with torch.jit.optimized_execution(True): + data = self.model(data) + data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian) + data = self.process_output(data) + return data + + def prepare_input(self, data: Dict[str, Any]) -> Dict[str, Tensor]: + data = self.to_input_tensors(data) + data = self.mol_flatten(data) + if 'cell' in data and data['cell'] is not None and data['mol_idx'][-1] > 0: + raise NotImplementedError('PBC with multiple molecules is not supported') + data = self.make_nbmat(data) + data = self.pad_input(data) + return data + + def process_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + data = self.unpad_output(data) + data = self.mol_unflatten(data) + data = self.keep_only(data) + return data + + def to_input_tensors(self, data: Dict[str, Any]) -> Dict[str, Tensor]: + ret = dict() + for k in self.keys_in: + assert k in data, f'Missing key {k} in the input data' + # always detach !! + ret[k] = torch.as_tensor(data[k], device=self.device).detach() + for k in self.keys_in_optional: + if k in data and data[k] is not None: + ret[k] = torch.as_tensor(data[k], device=self.device).detach() + # convert any scalar tensors to shape (1,) tensors + for k, v in ret.items(): + if v.ndim == 0: + ret[k] = v.unsqueeze(0) + return ret + + def mol_flatten(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + assert data['coord'].ndim in {2, 3}, 'Expected 2D or 3D tensor for coord' + if data['coord'].ndim == 3: + B, N = data['coord'].shape[:2] + self._batch = B + data['mol_idx'] = torch.repeat_interleave(torch.arange(0, B, device=self.device), torch.full((B,), N, device=self.device)) + for k, v in data.items(): + if k in self.atom_feature_keys: + assert v.ndim >= 2, f'Expected at least 2D tensor for {k}, got {v.ndim}D' + data[k] = v.flatten(0, 1) + else: + self._batch = None + if 'mol_idx' not in data: + data['mol_idx'] = torch.zeros(data['coord'].shape[0], device=self.device) + return data + + def mol_unflatten(self, data: Dict[str, Tensor], batch=None) -> Dict[str, Tensor]: + batch = batch or self._batch + if batch is not None: + for k, v in data.items(): + if k in self.atom_feature_keys: + data[k] = v.view(self._batch, -1, *v.shape[1:]) + return data + + def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + if 'cell' in data and data['cell'] is not None: + assert data['cell'].ndim == 2, 'Expected 2D tensor for cell' + if 'nbmat' not in data: + data['coord'] = move_coord_to_cell(data['coord'], data['cell']) + mat_idxj, mat_pad, mat_S = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff) + data['nbmat'], data['nb_pad_mask'], data['shifts'] = mat_idxj, mat_pad, mat_S + if self.lr: + if 'nbmat_lr' not in data: + assert self.cutoff_lr < torch.inf, 'Long-range cutoff must be finite for PBC' + data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr) + else: + if 'nbmat' not in data: + data['nbmat'] = nblist_torch_cluster(data['coord'], self.cutoff, data['mol_idx'], max_nb=128) + if self.lr: + if 'nbmat_lr' not in data: + data['nbmat_lr'] = nblist_torch_cluster(data['coord'], self.cutoff_lr, data['mol_idx'], max_nb=1024) + return data + + def pad_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + N = data['nbmat'].shape[0] + data['coord'] = maybe_pad_dim0(data['coord'], N) + data['numbers'] = maybe_pad_dim0(data['numbers'], N) + data['mol_idx'] = maybe_pad_dim0(data['mol_idx'], N, value=data['mol_idx'][-1]) + return data + + def unpad_output(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + N = data['nbmat'].shape[0] - 1 + for k, v in data.items(): + if k in self.atom_feature_keys: + data[k] = maybe_unpad_dim0(v, N) + return data + + def set_grad_tensors(self, data: Dict[str, Tensor], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]: + self._saved_for_grad = dict() + if forces or hessian: + data['coord'].requires_grad_(True) + self._saved_for_grad['coord'] = data['coord'] + if stress: + assert 'cell' in data, 'Stress calculation requires cell' + scaling = torch.eye(3, requires_grad=True, dtype=data['cell'].dtype, device=data['cell'].dtype) + data['coord'] = data['coord'] @ scaling + data['cell'] = data['cell'] @ scaling + self._saved_for_grad['scaling'] = scaling + return data + + def keep_only(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + ret = dict() + for k, v in data.items(): + if k in self.keys_out or (k.endswith('_std') and k[:-4] in self.keys_out): + ret[k] = v + return ret + + def get_derivatives(self, data: Dict[str, Tensor], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]: + training = getattr(self.model, 'training', False) + _create_graph = hessian or training + x = [] + if hessian: + forces = True + if forces and ('forces' not in data or (_create_graph and not data['forces'].requires_grad)): + forces = True + x.append(self._saved_for_grad['coord']) + if stress: + x.append(self._saved_for_grad['scaling']) + if x: + tot_energy = data['energy'].sum() + deriv = torch.autograd.grad(tot_energy, x, create_graph=_create_graph) + if forces: + data['forces'] = - deriv[0] + if stress: + if not forces: + dedc = deriv[0] + else: + dedc = deriv[1] + data['stress'] = dedc / data['cell'].detach().det().abs() + if hessian: + data['hessian'] = self.calculate_hessian(data['forces'], self._saved_for_grad['coord']) + return data + + @staticmethod + def calculate_hessian(forces, coord): + # here forces have shape (N, 3) and coord has shape (N+1, 3) + # return hessian with shape (N, 3, N, 3) + hessian = - torch.stack([ + torch.autograd.grad(_f, coord, retain_graph=True)[0] + for _f in forces.flatten().unbind() + ]).view(-1, 3, coord.shape[0], 3)[:-1, :, :-1, :] + return hessian + + +def maybe_pad_dim0(a: Tensor, N: int, value=0.0) -> Tensor: + _shape_diff = N - a.shape[0] + assert _shape_diff == 0 or _shape_diff == 1, 'Invalid shape' + if _shape_diff == 1: + a = pad_dim0(a, value=value) + return a + +def pad_dim0(a: Tensor, value=0.0) -> Tensor: + shapes = [0] * ((a.ndim - 1)*2) + [0, 1] + a = torch.nn.functional.pad(a, shapes, mode='constant', value=value) + return a + +def maybe_unpad_dim0(a: Tensor, N: int) -> Tensor: + _shape_diff = a.shape[0] - N + assert _shape_diff == 0 or _shape_diff == 1, 'Invalid shape' + if _shape_diff == 1: + a = a[:-1] + return a + +def move_coord_to_cell(coord, cell): + coord_f = coord @ cell.inverse() + coord_f = coord_f % 1 + return coord_f @ cell + + +def _named_children_rec(module): + if isinstance(module, torch.nn.Module): + for name, module in module.named_children(): + yield name, module + yield from _named_children_rec(module) + + +def set_lrcoulomb_method(model, method): + for name, module in _named_children_rec(model): + if name == 'lrcoulomb': + module.set_method(method) diff --git a/aimnet2calc/ensemble.py b/aimnet2calc/ensemble.py new file mode 100644 index 0000000..9a3ce25 --- /dev/null +++ b/aimnet2calc/ensemble.py @@ -0,0 +1,93 @@ +import torch +from torch import Tensor, nn +from typing import List, Dict, Union + + + +class Forces(nn.Module): + """Compute forces from energy using autograd. + """ + def __init__(self, module: nn.Module, + x: str = 'coord', y: str = 'energy', key_out: str = 'forces', + detach: bool = True): + super().__init__() + self.add_module('module', module) + self.x = x + self.y = y + self.key_out = key_out + self.detach = detach + + def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + prev = torch.is_grad_enabled() + torch.set_grad_enabled(True) + data[self.x].requires_grad_(True) + data = self.module(data) + y = data[self.y] + create_graph = self.training or not self.detach + g = torch.autograd.grad( + [y.sum()], [data[self.x]], create_graph=create_graph)[0] + assert g is not None + data[self.key_out] = - g + torch.set_grad_enabled(prev) + return data + + +class EnsembledModel(nn.Module): + def __init__(self, models: List[nn.Module], + out=['energy', 'forces', 'charges'], + detach=False): + super().__init__() + self.models = nn.ModuleList(models) + self.out = out + self.detach = detach + + def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: + res : List[Dict[str, Tensor]] = [] + for model in self.models: + _in = dict() + for k in data: + _in[k] = data[k] + _out = model(_in) + _r = dict() + for k in _out: + if k in self.out: + _r[k] = _out[k] + if self.detach: + _r[k] = _r[k].detach() + res.append(_r) + + for k in res[0]: + v = [] + for x in res: + v.append(x[k]) + vv = torch.stack(v, dim=0) + data[k] = vv.mean(dim=0) + data[k + '_std'] = vv.std(dim=0) + + return data + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--models', type=str, nargs='+', required=True) + parser.add_argument('--out-keys', type=str, nargs='+', default=['energy', 'forces', 'charges']) + parser.add_argument('--detach', action='store_true') + parser.add_argument('--forces', action='store_true') + parser.add_argument('--grad-of-forces', action='store_true') + parser.add_argument('output', type=str) + args = parser.parse_args() + + models = [torch.jit.load(m, map_location='cpu') for m in args.models] + if args.forces: + models = [Forces(m, detach=not args.grad_on_forces) for m in models] + + print('Ensembling {len(models)} models.') + ens = EnsembledModel(models, out=args.out_keys, detach=args.detach) + ens = torch.jit.script(ens) + ens.save(args.output) + + + + \ No newline at end of file diff --git a/aimnet2calc/models.py b/aimnet2calc/models.py new file mode 100644 index 0000000..a6608a3 --- /dev/null +++ b/aimnet2calc/models.py @@ -0,0 +1,37 @@ +import os +import requests + +# model registry aliases +model_registry_aliases = {} +model_registry_aliases['aimnet2'] = 'aimnet2_wb97m_0_240428' +model_registry_aliases['aimnet2_wb97m'] = model_registry_aliases['aimnet2'] +model_registry_aliases['aimnet2_wb97m_ens'] = 'aimnet2_wb97m_ens_240428' +model_registry_aliases['aimnet2_ens'] = model_registry_aliases['aimnet2_wb97m_ens'] +model_registry_aliases['aimnet2_b973c'] = 'aimnet2_b973c_0_240428' +model_registry_aliases['aimnet2_b973c_ens'] = 'aimnet2_b973c_ens_240428' +model_registry_aliases['aimnet2_qr'] = 'aimnet2_qr_b97m_qzvp' + + +def get_model_path(s: str): + # direct file path + if os.path.isfile(s): + print('Found model file:', s) + return s + # check aliases + if s in model_registry_aliases: + s = model_registry_aliases[s] + # add jpt extension + if not s.endswith('.jpt'): + s = s + '.jpt' + s_local = os.path.join(os.path.dirname(__file__), 'assets', s) + if os.path.isfile(s_local): + print('Found model file:', s_local) + else: + url = f'https://github.com/zubatyuk/aimnet-model-zoo/raw/main/aimnet2/{s}' + print('Downloading model file from', url) + r = requests.get(url) + r.raise_for_status() + with open(s_local, 'wb') as f: + f.write(r.content) + print('Saved to ', s_local) + return s_local diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py new file mode 100644 index 0000000..17f5015 --- /dev/null +++ b/aimnet2calc/nblist.py @@ -0,0 +1,151 @@ +import torch +from torch import Tensor +from typing import Optional, Tuple +from torch_cluster import radius_graph +import numba +try: + # optionaly use numba cuda + import numba.cuda + _numba_cuda_available = True +except ImportError: + _numba_cuda_available = False +import numpy as np + + +@numba.njit(cache=True) +def sparse_nb_to_dense_half(idx, natom, max_nb): + dense_nb = np.full((natom+1, max_nb), natom, dtype=np.int32) + last_idx = np.zeros((natom,), dtype=np.int32) + for k in range(idx.shape[0]): + i, j = idx[k] + il, jl = last_idx[i], last_idx[j] + dense_nb[i, il] = j + dense_nb[j, jl] = i + last_idx[i] += 1 + last_idx[j] += 1 + return dense_nb + + +def nblist_torch_cluster(coord: Tensor, cutoff: float, mol_idx: Optional[Tensor] = None, max_nb: int = 256): + device = coord.device + assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' + assert coord.shape[0] < 2147483646, 'Too many atoms, max supported is 2147483646' + max_num_neighbors = max_nb + while max_num_neighbors == max_nb: + sparse_nb = radius_graph(coord, batch=mol_idx, r=cutoff, max_num_neighbors=max_nb).to(torch.int32) + max_num_neighbors = torch.unique(sparse_nb[0], return_counts=True)[1].max().item() + max_nb *= 2 + #assert max_num_neighbors < max_nb, f'Increase max_nb in nblist_torch_cluster (current value {max_nb}, cutoff {cutoff})' + sparse_nb_half = sparse_nb[:, sparse_nb[0] > sparse_nb[1]] + dense_nb = sparse_nb_to_dense_half(sparse_nb_half.mT.cpu().numpy(), coord.shape[0], max_num_neighbors) + dense_nb = torch.as_tensor(dense_nb, device=device) + return dense_nb + + +### dense neighbor matrix kernels + +@numba.njit(cache=True, parallel=True) +def _cpu_dense_nb_mat_sft(conn_matrix): + N, S = conn_matrix.shape[:2] + # figure out max number of neighbors + _s_flat_conn_matrix = conn_matrix.reshape(N, -1) + maxnb = np.max(np.sum(_s_flat_conn_matrix, axis=-1)) + M = maxnb + # atom idx matrix + mat_idxj = np.full((N + 1, M), N, dtype=np.int_) + # padding matrix + mat_pad = np.ones((N + 1, M), dtype=np.bool_) + # shitfs matrix + mat_S_idx = np.zeros((N + 1, M), dtype=np.int_) + for _n in numba.prange(N): + _i = 0 + for _s in range(S): + for _m in range(N): + if conn_matrix[_n, _s, _m] == True: + mat_idxj[_n, _i] = _m + mat_pad[_n, _i] = False + mat_S_idx[_n, _i] = _s + _i += 1 + return mat_idxj, mat_pad, mat_S_idx + +if _numba_cuda_available: + @numba.cuda.jit(cache=True) + def _cuda_dense_nb_mat_sft(conn_matrix, mat_idxj, mat_pad, mat_S_idx): + i = numba.cuda.grid(1) + if i < conn_matrix.shape[0]: + k = 0 + for s in range(conn_matrix.shape[1]): + for j in range(conn_matrix.shape[2]): + if conn_matrix[i, s, j] > 0: + mat_idxj[i, k] = j + mat_pad[i, k] = 0 + mat_S_idx[i, k] = s + k += 1 + + +def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]: + """ Compute dense neighbor lists for periodic boundary conditions case. + Coordinates must be in cartesian coordinates and be within the unit cell. + Single crystal only, no support for batched coord or multiple unit cells. + """ + assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' + # non-PBC version + device = coord.device + + reciprocal_cell = cell.inverse().t() + inv_distances = reciprocal_cell.norm(2, -1) + shifts = _calc_shifts(inv_distances, cutoff) + d = torch.cdist(coord.unsqueeze(0), coord.unsqueeze(0) + (shifts @ cell).unsqueeze(1)) + conn_mat = ((d < cutoff) & (d > 0.1)).transpose(0, 1).contiguous() + if device.type == 'cuda' and _numba_cuda_available: + _fn = _nblist_pbc_cuda + else: + _fn = _nblist_pbc_cpu + mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts) + return mat_idxj, mat_pad, mat_S + + +def _calc_shifts(inv_distances, cutoff): + num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long) + dc = [torch.arange(-num_repeats[i], num_repeats[i] + 1, device=inv_distances.device) for i in range(len(num_repeats))] + shifts = torch.cartesian_prod(*dc).to(torch.float) + return shifts + + +def _nblist_pbc_cuda(conn_mat, shifts): + N = conn_mat.shape[0] + M = conn_mat.view(N, -1).sum(-1).max() + threadsperblock = 32 + blockspergrid = (N + (threadsperblock - 1)) // threadsperblock + idx_j = torch.full((N + 1, M), N, dtype=torch.int64, device=conn_mat.device) + mat_pad = torch.ones((N + 1, M), dtype=torch.int8, device=conn_mat.device) + S_idx = torch.zeros((N + 1, M), dtype=torch.int64, device=conn_mat.device) + conn_mat = conn_mat.to(torch.int8) + _conn_mat = numba.cuda.as_cuda_array(conn_mat) + _idx_j = numba.cuda.as_cuda_array(idx_j) + _mat_pad = numba.cuda.as_cuda_array(mat_pad) + _S_idx = numba.cuda.as_cuda_array(S_idx) + _cuda_dense_nb_mat_sft[blockspergrid, threadsperblock](_conn_mat, _idx_j, _mat_pad, _S_idx) + mat_pad = mat_pad.to(torch.bool) + return idx_j, mat_pad, shifts[S_idx] + + +def _nblist_pbc_cpu(conn_mat, shifts, device): + conn_mat = conn_mat.cpu().numpy() + mat_idxj, mat_pad, mat_S_idx = _cpu_dense_nb_mat_sft(conn_mat) + mat_idxj = torch.from_numpy(mat_idxj).to(device) + mat_pad = torch.from_numpy(mat_pad).to(device) + mat_S_idx = torch.from_numpy(mat_S_idx).to(device) + mat_S = shifts[mat_S_idx] + return mat_idxj, mat_pad, mat_S + + + + + + + + + + + diff --git a/test/1008775.cif b/test/1008775.cif new file mode 100644 index 0000000..873cc9f --- /dev/null +++ b/test/1008775.cif @@ -0,0 +1,108 @@ +#------------------------------------------------------------------------------ +#$Date: 2023-03-26 11:09:57 +0300 (Sun, 26 Mar 2023) $ +#$Revision: 282068 $ +#$URL: file:///home/coder/svn-repositories/cod/cif/1/00/87/1008775.cif $ +#------------------------------------------------------------------------------ +# +# This file is available in the Crystallography Open Database (COD), +# http://www.crystallography.net/ +# +# All data on this site have been placed in the public domain by the +# contributors. +# +data_1008775 +loop_ +_publ_author_name +'Guth, H' +'Heger, G' +'Klein, S' +'Treutmann, W' +'Scheringer, C' +_publ_section_title +; +Strukturverfeinerung von Harnstoff mit Neutronenbeugungsdaten bei 60, +123, 293 K und X-N- und X-X (1S2)-Synthesen bei etwa 100 K +; +_journal_coden_ASTM ZEKRDZ +_journal_name_full +; +Zeitschrift fuer Kristallographie (149,1979-) +; +_journal_page_first 237 +_journal_page_last 254 +_journal_volume 153 +_journal_year 1980 +_chemical_compound_source 'synthetic from solution' +_chemical_formula_structural 'C O (N H2)2' +_chemical_formula_sum 'C H4 N2 O' +_chemical_name_mineral Urea +_chemical_name_systematic Urea +_space_group_IT_number 113 +_symmetry_cell_setting tetragonal +_symmetry_Int_Tables_number 113 +_symmetry_space_group_name_Hall 'P -4 2ab' +_symmetry_space_group_name_H-M 'P -4 21 m' +_cell_angle_alpha 90 +_cell_angle_beta 90 +_cell_angle_gamma 90 +_cell_formula_units_Z 2 +_cell_length_a 5.576 +_cell_length_b 5.576 +_cell_length_c 4.692 +_cell_volume 145.9 +_database_code_amcsd 0010831 +_diffrn_ambient_temperature 100 +_exptl_crystal_density_diffrn 1.367 +_cod_original_formula_sum 'H4 N2 O' +_cod_database_code 1008775 +loop_ +_space_group_symop_operation_xyz +x,y,z +1/2-x,1/2+y,-z +-x,-y,z +1/2+x,1/2-y,-z +-y,x,-z +1/2+y,1/2+x,z +y,-x,-z +1/2-y,1/2-x,z +loop_ +_atom_site_aniso_label +_atom_site_aniso_U_11 +_atom_site_aniso_U_12 +_atom_site_aniso_U_13 +_atom_site_aniso_U_22 +_atom_site_aniso_U_23 +_atom_site_aniso_U_33 +C1 0.0127 0.0006 0. 0.0127 0. 0.0056 +O1 0.0161 0.0015 0. 0.0161 0. 0.0052 +N1 0.0244 -0.0118 0.0001 0.0244 0.0001 0.0082 +H1 0.0403 -0.0192 -0.0022 0.0403 -0.0022 0.0256 +H2 0.0403 -0.0132 0.0021 0.0403 0.0021 0.0153 +loop_ +_atom_site_label +_atom_site_type_symbol +_atom_site_symmetry_multiplicity +_atom_site_Wyckoff_symbol +_atom_site_fract_x +_atom_site_fract_y +_atom_site_fract_z +_atom_site_occupancy +_atom_site_attached_hydrogens +_atom_site_calc_flag +C1 C4+ 2 c 0. 0.5 0.329 1. 0 d +O1 O2- 2 c 0. 0.5 0.597 1. 0 d +N1 N3- 4 e 0.1455 0.6455 0.1791 1. 0 d +H1 H1+ 4 e 0.2561 0.7561 0.2855 1. 0 d +H2 H1+ 4 e 0.1429 0.6429 -0.037 1. 0 d +loop_ +_atom_type_symbol +_atom_type_oxidation_number +C4+ 4.000 +O2- -2.000 +N3- -3.000 +H1+ 1.000 +loop_ +_cod_related_entry_id +_cod_related_entry_database +_cod_related_entry_code +1 AMCSD 0010831 diff --git a/test/mols_size_36.xyz b/test/mols_size_36.xyz new file mode 100644 index 0000000..f57a514 --- /dev/null +++ b/test/mols_size_36.xyz @@ -0,0 +1,380 @@ +36 +2100cb04033933fc8faa29c1 +Si 3.065708 0.746566 0.151857 +C 2.675602 -0.746797 -0.915662 +Si 2.514801 -2.443970 -0.143056 +C 0.695983 -2.778497 -0.419477 +C 0.253433 -1.413334 -0.933630 +C 1.204815 -0.430063 -1.151225 +C 0.823636 0.904258 -1.124970 +C 1.842274 1.925912 -0.619470 +Si 0.676436 2.959379 0.414138 +C -0.840754 2.688975 -0.657744 +C -0.519536 1.251652 -1.051507 +C -1.473277 0.267376 -0.855913 +C -1.069774 -1.054134 -0.714658 +C -1.904397 -1.956973 0.194041 +Si -0.497902 -2.862977 1.026420 +Si -2.877739 -0.632033 1.079031 +C -2.836613 0.647153 -0.293570 +Si -2.526226 2.447430 0.113533 +H 4.471417 1.207376 0.124933 +H 2.730928 0.456145 1.567125 +H 3.237299 -0.707145 -1.843924 +H 2.843153 -2.426245 1.293908 +H 3.399128 -3.447617 -0.775498 +H 0.462361 -3.561426 -1.134502 +H 2.262668 2.493977 -1.445227 +H 1.077270 4.365283 0.637861 +H 0.468853 2.351635 1.753319 +H -0.803736 3.349854 -1.518199 +H -2.533348 -2.621146 -0.392855 +H -0.812862 -4.215280 1.536497 +H 0.032383 -2.084956 2.173689 +H -2.145257 -0.120714 2.264902 +H -4.221599 -1.040845 1.544330 +H -3.626235 0.452390 -1.012967 +H -2.490455 2.677028 1.569661 +H -3.558442 3.351762 -0.441192 +36 +218ef4420d58cba4e1ad8e92 +C -2.411395 -0.390155 -2.408391 +Si -0.545926 -0.185703 -2.686560 +Si 0.065763 0.072074 -0.390497 +C -0.552263 1.707147 0.381437 +C -0.108039 2.902137 -0.469337 +C -2.065063 1.747942 0.620685 +C -0.679275 -1.330309 0.685245 +C -0.257070 -2.739783 0.273802 +C -0.419690 -1.116213 2.180250 +C 1.980604 0.150739 -0.382780 +C 2.609284 0.367605 0.996766 +C 2.609641 -1.051326 -1.089252 +H -2.939528 -0.519687 -3.360241 +H -2.842939 0.484130 -1.920994 +H -2.621964 -1.256719 -1.780895 +H -0.054439 1.780546 1.352833 +H -0.566703 2.871735 -1.459424 +H 0.972947 2.913253 -0.603290 +H -0.405278 3.840725 0.000461 +H -2.613584 1.765966 -0.319604 +H -2.337429 2.653991 1.164660 +H -2.412858 0.890882 1.196376 +H -1.759326 -1.245101 0.531358 +H 0.791126 -2.924064 0.505921 +H -0.405964 -2.923764 -0.789141 +H -0.845244 -3.478981 0.819609 +H 0.641968 -1.190914 2.412071 +H -0.932943 -1.879515 2.768336 +H -0.769870 -0.142140 2.517444 +H 2.207764 1.041078 -0.977617 +H 2.498810 -0.514624 1.626528 +H 2.158834 1.213968 1.512308 +H 3.678211 0.562304 0.898147 +H 2.449309 -1.968481 -0.526283 +H 3.686480 -0.911644 -1.191880 +H 2.196064 -1.197097 -2.088048 +36 +2144d266e32159c22efd9887 +C -2.520928 3.876528 -0.737659 +O -2.946203 2.535672 -0.765245 +C -2.550853 1.827484 0.400505 +C -2.968713 0.378358 0.255446 +C -2.285284 -0.315715 -0.916852 +N -0.826362 -0.263614 -0.813028 +C -0.156940 -1.044083 0.067543 +O -0.705735 -1.775777 0.873123 +C 1.374773 -0.975828 -0.010165 +C 1.992335 -0.920491 1.382928 +C 3.512897 -0.834705 1.309811 +C 1.432797 0.247499 2.189145 +N 1.819597 -2.176970 -0.716561 +C 1.799014 -2.261313 -2.081638 +O 1.704904 -1.322660 -2.882393 +C 1.811772 -3.424362 -2.914373 +H -2.879263 4.357607 -1.643818 +H -1.429852 3.949999 -0.708954 +H -2.926965 4.408507 0.126872 +H -3.020313 2.280898 1.279373 +H -1.465362 1.899183 0.531177 +H -4.047257 0.347691 0.103584 +H -2.739493 -0.164408 1.170935 +H -2.562167 0.168592 -1.849279 +H -2.601188 -1.357356 -0.960307 +H -0.309307 0.049168 -1.616838 +H 1.694281 -0.118353 -0.597204 +H 1.703238 -1.850002 1.878024 +H 3.819691 0.097489 0.831833 +H 3.935256 -1.658968 0.738560 +H 3.941732 -0.855913 2.310356 +H 1.648963 1.195947 1.695274 +H 1.892881 0.272043 3.174706 +H 0.356774 0.169647 2.323019 +H 1.656837 -3.042088 -0.222482 +H 2.844459 -3.699701 -3.205413 +36 +21da48e25953c9fe77b5024b +O -0.691028 5.415777 -1.014183 +C -0.654877 4.365708 -0.464323 +O 0.154934 4.155951 0.630805 +C 0.042443 2.865448 1.104297 +O 0.713551 2.493696 2.010159 +C -1.034696 2.124198 0.326621 +C -1.414466 3.100132 -0.777547 +C -0.633322 0.733036 -0.146544 +C 0.633867 0.720737 -1.005146 +C 0.786419 -0.652357 -1.635447 +C 2.089587 -0.999202 -2.332987 +O 2.112149 -2.423373 -2.359906 +C 1.541696 -2.892867 -1.142398 +C 0.727578 -1.742087 -0.591137 +C 0.100860 -1.628316 0.576247 +C -0.477496 -0.230390 1.024847 +C -1.689963 -0.805737 1.691905 +C -1.369207 -1.439263 2.853965 +C -0.897210 -2.799573 2.603794 +C -0.136372 -2.803954 1.481454 +H -1.866220 2.034995 1.025330 +H -2.471053 3.345522 -0.804169 +H -1.134709 2.750884 -1.767197 +H -1.456554 0.352102 -0.759595 +H 1.501755 0.931163 -0.376225 +H 0.605179 1.493068 -1.771451 +H -0.034313 -0.817315 -2.341078 +H 2.936512 -0.614240 -1.756495 +H 2.161452 -0.642216 -3.354828 +H 2.328768 -3.196294 -0.445863 +H 0.916691 -3.757820 -1.364277 +H 0.220348 0.186209 1.753260 +H -2.281324 -1.400229 0.995862 +H -0.813837 -0.888295 3.611462 +H -0.904137 -3.622398 3.302508 +H 0.386989 -3.712701 1.218303 +36 +2111eb1010a628e1df9afdba +C 4.212005 -0.918802 0.294283 +C 3.249982 -1.866028 0.650497 +C 1.919327 -1.487427 0.697437 +C 1.550992 -0.188829 0.397094 +C 2.492347 0.770245 -0.076338 +C 3.859081 0.370224 -0.015108 +N 0.194022 0.201066 0.321523 +C -0.944181 -0.299459 1.028971 +C -0.917918 -0.973087 2.316917 +C -2.071222 -1.334528 2.873397 +C -3.419244 -1.099709 2.241770 +C -3.444747 -0.264572 0.935244 +C -2.085590 0.009548 0.375209 +C -1.693154 0.723776 -0.849377 +C -0.277159 0.858308 -0.739945 +C 0.305463 1.648466 -1.830692 +C 1.647582 1.124878 -2.194963 +C -0.892279 1.568211 -2.867897 +C -2.105150 1.200793 -2.059686 +H 5.258335 -1.195236 0.270875 +H 3.531841 -2.877662 0.884678 +H 1.152406 -2.218527 0.930112 +H 2.266233 1.820647 0.055701 +H 4.627888 1.095819 -0.231830 +H 0.026664 -1.150089 2.808038 +H -2.069685 -1.843364 3.828929 +H -4.063238 -0.621474 2.982030 +H -3.879776 -2.072674 2.059700 +H -3.942872 0.692567 1.114849 +H -4.055693 -0.774527 0.187159 +H 0.349244 2.703516 -1.531257 +H -0.624246 0.747288 -3.542586 +H 1.622215 0.137000 -2.651250 +H 2.303432 1.806594 -2.720425 +H -0.981589 2.469697 -3.473250 +H -3.101318 1.237354 -2.469807 +36 +2186c48c411011a3c63bad30 +O -5.010432 -0.188321 -1.038047 +C -3.787674 -0.642816 -0.681077 +O -3.642421 -1.753706 -0.242283 +N -2.776561 0.267147 -0.846778 +C -1.408040 -0.104155 -0.730545 +C -0.755865 0.849383 0.435887 +C 0.649331 0.613743 0.452980 +O -0.664660 1.603329 -2.150784 +C -0.679816 0.281053 -2.033267 +C 0.717907 -0.399495 -2.077406 +N 1.816570 0.500431 -1.712752 +C 1.457556 1.346505 -0.551521 +C 1.303851 -0.316505 1.427566 +C 0.463799 -1.580230 1.641470 +C 2.719892 -0.702742 0.984488 +C 1.376008 0.478185 2.761219 +H -4.943058 0.720377 -1.340428 +H -2.888088 1.064453 -1.462371 +H -1.321537 -1.144793 -0.454343 +H -1.293919 0.614406 1.343328 +H -0.977517 1.834116 0.045146 +H -1.302478 -0.233693 -2.806541 +H 0.948923 -0.759817 -3.080382 +H 0.717390 -1.264385 -1.409235 +H 1.873569 1.144847 -2.488481 +H 0.836325 2.174801 -0.893717 +H 2.376649 1.702965 -0.094538 +H 0.467610 -2.207672 0.751460 +H -0.570680 -1.341731 1.882462 +H 0.876085 -2.162577 2.462319 +H 2.729258 -1.034096 -0.049473 +H 3.096781 -1.503905 1.616846 +H 3.400135 0.141971 1.065493 +H 0.379784 0.676766 3.148731 +H 1.892089 1.424979 2.621642 +H 1.923191 -0.098823 3.502929 +36 +210e3fac5ae68f0d37c3677e +C -1.787608 -0.268572 2.576578 +C -0.491794 0.261160 1.965575 +C -0.794069 1.272238 0.855592 +S -2.069793 0.713242 -0.320187 +C -2.369045 2.248984 -1.215373 +C 1.003253 2.245603 -2.238620 +S 1.164300 -0.309179 -1.045501 +C 0.303046 -1.395604 0.151441 +C 0.502589 -0.850362 1.559100 +C 0.500506 -2.008965 2.546975 +C 0.366749 -0.829624 -2.568919 +C 2.657000 -1.357148 -1.178835 +H -2.496389 0.546025 2.702825 +H -1.620813 -0.721981 3.551156 +H -2.258326 -1.011812 1.934855 +H 0.017975 0.831115 2.748058 +H 0.091137 1.551565 0.281321 +H -1.189274 2.173251 1.321591 +H -2.652714 3.033200 -0.516805 +H -3.202872 2.052096 -1.884687 +H -1.486397 2.546833 -1.783222 +H 0.846273 3.044903 -1.508325 +H 0.433992 2.448024 -3.148347 +H 2.065718 2.217388 -2.493791 +H 0.714247 -2.394802 0.040768 +H -0.740560 -1.422990 -0.147673 +H 1.502116 -0.403037 1.572989 +H -0.426000 -2.581136 2.486560 +H 0.600936 -1.658568 3.573030 +H 1.323591 -2.690736 2.335767 +H -0.614294 -0.368350 -2.547605 +H 0.300301 -1.909497 -2.672494 +H 0.963135 -0.379559 -3.352435 +H 3.130024 -1.345833 -0.203163 +H 3.313009 -0.906946 -1.914398 +H 2.400043 -2.370923 -1.463800 +36 +21fee233cd9b8153b7db0b4d +C -2.627654 1.300146 -1.334164 +S -0.731146 1.313485 -1.022909 +C -0.311595 1.339685 -2.821385 +C 0.029647 0.048687 -0.214912 +C -0.169112 -1.338986 -0.356269 +C 0.407668 -2.191366 0.545695 +C 1.195379 -1.659290 1.613456 +C 1.369059 -0.317510 1.773630 +C 0.783467 0.598818 0.871312 +C 0.797560 1.997612 0.877779 +C 0.072114 2.560601 -0.228713 +C -0.122509 3.955453 -0.365709 +C 0.465483 4.779589 0.543024 +C 1.229661 4.247674 1.620196 +C 1.387472 2.903986 1.789837 +C 0.206430 -3.639073 0.423982 +C -0.493562 -4.391756 -0.460378 +C -0.329140 -5.755177 -0.059904 +C 0.459637 -5.729982 1.036615 +O 0.796754 -4.451313 1.346547 +H -2.959192 0.418020 -1.875087 +H -2.951317 2.214776 -1.826534 +H -3.010084 1.284359 -0.314284 +H -0.726819 0.452173 -3.286130 +H 0.772132 1.330259 -2.840840 +H -0.701591 2.246026 -3.268900 +H -0.770951 -1.734967 -1.159880 +H 1.655697 -2.341110 2.309612 +H 1.958326 0.038587 2.606086 +H -0.725379 4.358634 -1.172108 +H 0.350724 5.848215 0.450471 +H 1.686264 4.932800 2.316602 +H 1.955758 2.516834 2.623303 +H -1.061018 -4.023529 -1.297763 +H -0.746578 -6.627703 -0.530293 +H 0.858418 -6.484653 1.688011 +36 +21b1e5a32f7725b34bff3620 +C -1.512259 -4.427968 -3.456075 +O -2.136328 -3.255044 -2.992028 +C -1.663293 -2.104874 -3.658239 +C -2.341209 -0.890706 -3.088021 +O -1.832180 -0.629161 -1.796496 +C -2.391527 0.535388 -1.236727 +C -1.951613 0.660353 0.198256 +O -0.580949 1.058432 0.188536 +C 0.071656 1.156704 1.365776 +C -0.523359 1.080376 2.619309 +C 0.274467 1.188329 3.748266 +C 1.649032 1.353326 3.647982 +C 2.227433 1.466989 2.381771 +C 1.441310 1.365593 1.251352 +S 3.980487 1.801635 2.202283 +O 4.195811 3.207474 2.326128 +O 4.660792 0.862220 3.044070 +N 4.278943 1.392872 0.644813 +I 2.670147 1.359143 5.489076 +O 2.230330 3.002599 6.157678 +H -1.906039 -5.258590 -2.877827 +H -0.427763 -4.385617 -3.319641 +H -1.723635 -4.603753 -4.515002 +H -1.887995 -2.170818 -4.728996 +H -0.578813 -2.008176 -3.542317 +H -3.418913 -1.072258 -3.041864 +H -2.166912 -0.031045 -3.744476 +H -3.484983 0.494795 -1.272108 +H -2.064130 1.422038 -1.787659 +H -2.062394 -0.299695 0.703737 +H -2.550380 1.410873 0.716812 +H -1.588579 0.949543 2.731854 +H -0.183784 1.150953 4.727260 +H 1.860894 1.434996 0.262444 +H 4.548919 2.190310 0.089997 +H 4.886849 0.592762 0.560087 +36 +21f6db1a42745e5cef04e599 +C -3.440030 1.944994 0.391180 +O -2.185167 2.259583 -0.164609 +C -1.695983 1.403913 -1.209640 +C -2.654784 1.365867 -2.390558 +C -1.431020 -0.008735 -0.695315 +C -0.432664 -0.030657 0.450279 +N 0.809402 0.621258 0.064542 +C 0.652221 1.973612 -0.465774 +C -0.356312 2.011047 -1.608392 +C 2.041933 0.060221 0.066520 +O 2.983972 0.540063 -0.540422 +C 2.268973 -1.240245 0.837519 +C 3.229288 -2.090185 0.302409 +C 3.489381 -3.322279 0.885901 +C 2.799460 -3.698379 2.028618 +C 1.885424 -2.831097 2.611244 +C 1.641505 -1.596033 2.028534 +B 3.074919 -5.108156 2.612509 +H -3.595419 2.643329 1.207790 +H -3.469849 0.930960 0.795801 +H -4.252047 2.060728 -0.327652 +H -3.586681 0.874022 -2.122959 +H -2.202959 0.812104 -3.211107 +H -2.877885 2.374027 -2.730309 +H -1.032770 -0.611937 -1.512264 +H -2.362073 -0.471281 -0.373933 +H -0.237131 -1.062035 0.712769 +H -0.858684 0.484172 1.313340 +H 1.623090 2.308620 -0.809885 +H 0.309587 2.630170 0.332978 +H 0.053334 1.467952 -2.461228 +H -0.516017 3.045199 -1.911785 +H 3.776203 -1.773840 -0.571239 +H 4.233023 -3.964399 0.436811 +H 1.346528 -3.093020 3.510609 +H 0.969215 -0.909563 2.517716 diff --git a/test/mols_size_var.xyz b/test/mols_size_var.xyz new file mode 100644 index 0000000..8ebadc7 --- /dev/null +++ b/test/mols_size_var.xyz @@ -0,0 +1,299 @@ +24 +894b83f876d0e615b55e2bfb +C -3.948132 1.061743 -1.267402 +N -2.514775 1.129775 -1.012968 +C -1.776592 0.157334 -1.800394 +C -0.280265 0.312549 -1.606070 +O 0.036499 -0.034997 -0.239254 +B 1.158158 -0.246256 0.414658 +C 2.147089 0.290482 1.500052 +C 3.608858 0.164002 1.149637 +C 3.572156 -1.111747 0.774896 +C 2.101558 -1.329272 1.026798 +C 1.673518 -1.551184 2.461941 +C 1.713278 -0.274088 2.836354 +H -4.473478 1.756024 -0.616374 +H -4.373355 0.062106 -1.117094 +H -4.138451 1.354760 -2.298996 +H -2.353188 0.945490 -0.031611 +H -1.991255 0.326497 -2.856629 +H -2.061680 -0.878390 -1.570186 +H 0.015668 1.343332 -1.791375 +H 0.255427 -0.348806 -2.286583 +H 4.382909 0.912914 1.220882 +H 4.301304 -1.819670 0.411552 +H 1.425479 -2.482143 2.945471 +H 1.519268 0.259546 3.752690 +34 +e35b3f96a6f628894e98e52c +C -2.104213 -1.683405 -2.906723 +C -0.933091 -2.531512 -2.430102 +C 0.269877 -1.697274 -1.979091 +C -0.073278 -0.767177 -0.836887 +C -0.575590 -1.227787 0.374393 +C -0.803916 -0.370189 1.439254 +C -0.543892 0.989774 1.339257 +C -0.057369 1.423361 0.118408 +C 0.162630 0.602761 -0.956082 +Cl 0.310522 3.194224 -0.207901 +O 0.798583 3.213918 -1.579941 +O -0.983881 3.832343 -0.046975 +C -0.788295 1.885476 2.531100 +O 0.373303 2.644559 2.889405 +C 1.369787 1.939339 3.482596 +O 1.292922 0.776441 3.790071 +N 2.446995 2.738928 3.664102 +C 1.452495 -2.593287 -1.604985 +H -2.942694 -2.303994 -3.218342 +H -1.817809 -1.061537 -3.754757 +H -2.449602 -1.021750 -2.114142 +H -0.608725 -3.195166 -3.233318 +H -1.250174 -3.177823 -1.607940 +H 0.574361 -1.071846 -2.819544 +H -0.781398 -2.282487 0.503369 +H -1.165541 -0.767741 2.378175 +H 0.530525 1.038046 -1.872102 +H -1.089567 1.266060 3.371290 +H -1.554021 2.627198 2.334495 +H 2.417718 3.700883 3.388933 +H 3.266959 2.351714 4.089226 +H 1.191426 -3.260118 -0.783552 +H 2.310200 -2.001518 -1.291924 +H 1.754761 -3.210397 -2.449759 +44 +47c0faa7a2195e86ce042ca8 +C -3.251350 2.438660 3.091167 +C -1.983738 3.122685 2.607094 +I -0.441355 1.575227 1.911035 +C 0.286397 -0.347744 1.008956 +C 0.334493 -0.306430 -0.507639 +C -0.534985 0.850841 -0.963846 +C -1.051270 1.744297 -0.126859 +C -1.958662 2.885759 -0.473882 +C -2.099531 3.126844 -1.994698 +C -1.524861 2.053851 -2.863908 +C -0.806661 1.035996 -2.406699 +C 1.766664 -0.151500 -1.074492 +C 2.671364 -1.283347 -0.624552 +N 2.054350 -2.610397 -0.830507 +C 0.612663 -2.807968 -0.580232 +C -0.179773 -1.635617 -1.113790 +S 2.965394 -3.827565 -0.201408 +O 2.302557 -5.064984 -0.489194 +O 4.323429 -3.609302 -0.602870 +C 2.894932 -3.594688 1.559849 +H -3.836015 3.097108 3.732481 +H -3.013666 1.548516 3.676051 +H -3.886809 2.118037 2.268515 +H -1.410887 3.524394 3.443473 +H -2.174543 3.943068 1.926368 +H -0.465064 -1.052027 1.365961 +H 1.224473 -0.588054 1.505261 +H -1.571985 3.792296 -0.013336 +H -2.936416 2.701022 -0.026935 +H -3.150762 3.258182 -2.244916 +H -1.617173 4.068185 -2.258155 +H -1.698445 2.159692 -3.925846 +H -0.397284 0.325761 -3.107040 +H 2.185234 0.801156 -0.750523 +H 1.733078 -0.137809 -2.163433 +H 2.937412 -1.137607 0.423193 +H 3.596258 -1.295544 -1.188203 +H 0.415106 -2.946179 0.484694 +H 0.344734 -3.730924 -1.079789 +H -1.237040 -1.762647 -0.890067 +H -0.081208 -1.612594 -2.196900 +H 1.860726 -3.655539 1.881937 +H 3.319144 -2.625585 1.797890 +H 3.481080 -4.387516 2.015791 +34 +f26c9b85fcb5c2450f53354c +C -0.941774 -2.812179 3.638077 +C 0.024964 -1.992409 3.256691 +C 0.779655 -2.236028 1.996546 +O 0.581946 -3.155815 1.244378 +O 1.704199 -1.278933 1.815131 +C 2.489141 -1.302570 0.598882 +C 2.119588 -0.074288 -0.214416 +C 0.643501 -0.034806 -0.596799 +C 0.250169 1.161749 -1.455235 +C -1.247181 1.188678 -1.757730 +C -1.691039 2.401791 -2.588307 +Si -3.476840 2.460983 -3.243286 +C -4.427617 2.467669 -1.556009 +Br -3.995520 4.064796 -0.534851 +Br -6.307364 2.382355 -2.005847 +C 3.943962 -1.350046 1.006949 +H -1.495455 -2.636042 4.548193 +H -1.196814 -3.679748 3.045593 +H 0.305328 -1.119503 3.827040 +H 2.224489 -2.201153 0.045811 +H 2.738274 -0.077934 -1.114125 +H 2.380729 0.821613 0.353861 +H 0.043447 -0.021761 0.315989 +H 0.386866 -0.964901 -1.112606 +H 0.806521 1.141847 -2.393959 +H 0.522767 2.086583 -0.940987 +H -1.774495 1.179462 -0.800704 +H -1.518695 0.262880 -2.269415 +H -1.078013 2.455945 -3.493262 +H -1.474431 3.324601 -2.038873 +H -4.241204 1.629717 -0.896496 +H 4.204778 -0.476076 1.601966 +H 4.578110 -1.372928 0.122750 +H 4.138002 -2.243547 1.595034 +22 +a1e4f0fd5077c51485d871bf +O -2.761952 -2.344092 1.508507 +C -1.432569 -2.083247 1.484961 +O -0.689712 -2.763663 2.131542 +C -0.979997 -0.932181 0.613117 +C -2.044506 -0.151285 -0.197637 +C -1.200108 0.681110 -1.097474 +C -0.486740 1.631718 -0.496916 +C 0.958773 1.406383 -0.620063 +C 1.320699 0.183116 -0.043617 +C 0.364055 -0.766112 0.608538 +C 2.675411 -0.125110 -0.068327 +N 3.595803 0.648446 -0.623402 +N 3.246115 1.793185 -1.182706 +C 1.972924 2.167793 -1.178711 +H -3.253205 -1.696781 0.997588 +H -2.684339 -0.852271 -0.734974 +H -0.764826 0.156518 -1.945856 +H -2.665575 0.442491 0.474074 +H -0.860386 2.034506 0.442423 +H 0.871300 -1.509329 1.214332 +H 3.057605 -1.039946 0.365955 +H 1.761215 3.118748 -1.651352 +30 +151f42f5540e42c1bafe54e4 +C 0.632688 0.311523 3.235998 +C 1.266442 -0.788801 2.386033 +B 0.949823 -0.684665 1.027859 +C 0.337499 -0.266726 -0.267832 +C 0.755693 -0.594238 -1.516917 +N 1.733153 -1.496116 -1.753831 +C 0.157439 0.033653 -2.705255 +C -0.861264 0.977012 -2.501738 +N -1.289439 1.260839 -1.291259 +C -0.837834 0.658061 -0.031742 +C -2.021724 -0.007059 0.663361 +C -1.538970 1.736345 -3.607680 +B 0.565991 -0.246445 -4.099158 +C 2.104218 -1.764962 3.191174 +H -0.027964 0.957396 2.660768 +H 1.399057 0.937544 3.693096 +H 0.044776 -0.124241 4.046168 +H 2.218467 -1.931679 -0.993037 +H 2.070454 -1.655934 -2.687387 +H -2.016985 1.951989 -1.191696 +H -0.517984 1.486985 0.606197 +H -2.428169 -0.808986 0.048455 +H -2.797384 0.730860 0.860445 +H -1.708102 -0.429465 1.613823 +H -1.138526 1.456338 -4.576606 +H -1.397272 2.804524 -3.456105 +H -2.607016 1.529764 -3.587445 +H 2.116948 -1.445603 4.233707 +H 3.136525 -1.809778 2.842064 +H 1.699452 -2.778128 3.158541 +25 +3497522cc00e5cfd7bca4fb4 +N 3.906988 1.954006 2.160210 +C 3.007133 2.386573 1.594463 +C 1.859406 2.892590 0.881720 +C 1.182176 2.099169 0.072972 +C 0.612752 1.323357 -0.814533 +C -0.149460 0.131481 -0.467273 +O -0.162743 -0.252895 0.832773 +C -0.820804 -1.447870 0.883780 +C -1.242125 -1.819596 -0.353025 +C -0.806270 -0.775523 -1.231063 +C -1.974757 -3.068913 -0.727301 +C -0.934205 -2.084876 2.220576 +C 0.764317 1.619107 -2.292977 +C 1.498435 4.273368 1.086614 +N 1.212145 5.377534 1.212044 +H -0.963580 -0.722300 -2.293399 +H -2.180250 -3.679531 0.149784 +H -1.389092 -3.668498 -1.422187 +H -2.923939 -2.838937 -1.207367 +H -0.313743 -1.548393 2.933655 +H -0.601820 -3.120597 2.183991 +H -1.960001 -2.073592 2.584947 +H 1.509392 2.386687 -2.464477 +H -0.190010 1.945494 -2.706817 +H 1.060068 0.712162 -2.817120 +16 +39e2c45440ca87e56cb7fa29 +O -1.014588 -2.634654 1.793546 +C -1.552462 -1.430617 1.602442 +O -2.574103 -1.075324 2.097097 +C -0.722871 -0.523958 0.673357 +N -1.129161 0.722200 0.455676 +C -0.334164 1.398226 -0.334334 +N 0.771282 0.987744 -0.924501 +C 1.051195 -0.257860 -0.630474 +N 0.371810 -1.064771 0.147354 +N 2.309724 -0.845101 -1.254917 +O 2.942280 -0.118095 -1.973016 +O 2.556923 -1.990427 -0.969595 +N -0.735048 2.845030 -0.597600 +O -1.765243 3.208772 -0.093545 +O 0.015946 3.491838 -1.281499 +H -0.191523 -2.712999 1.290010 +28 +90eecab78ed1b35bd990ab41 +O 3.330770 2.903872 1.761034 +C 3.413864 1.558937 1.821045 +O 4.366066 1.004366 2.293595 +C 2.264332 0.809637 1.271064 +C 1.102421 1.245924 0.766043 +C 0.279306 0.097306 0.353499 +C -0.867349 0.103768 -0.448325 +C -1.231359 1.381181 -1.021114 +C -2.167358 1.542756 -1.977763 +C -2.891615 0.402210 -2.465832 +C -2.668823 -0.803996 -1.925868 +C -1.717808 -1.038852 -0.850671 +C -1.886896 -2.214570 -0.135804 +C -0.871536 -2.363240 0.845817 +C 0.385549 -2.188303 0.340337 +N 0.999701 -1.017180 0.698412 +C 2.311205 -0.685235 1.230577 +H 2.508280 3.159099 1.330286 +H 0.809387 2.278598 0.670799 +H -0.699529 2.252224 -0.676044 +H -2.375257 2.519224 -2.385321 +H -3.628546 0.538853 -3.246742 +H -3.249921 -1.660348 -2.243270 +H -2.793298 -2.803685 -0.226813 +H -1.086475 -2.104573 1.873415 +H 0.790065 -2.766803 -0.476750 +H 3.122324 -1.048905 0.598228 +H 2.452499 -1.102261 2.226163 +22 +a49d311aff2713cbe7d857cf +O 1.928443 -1.565235 -0.983191 +C 1.360832 -2.685709 -0.521925 +O 1.935457 -3.735882 -0.511664 +C -0.042404 -2.560774 0.003075 +C -1.021272 -1.688218 -0.555504 +C -0.939070 -0.664086 -1.455165 +C -0.597411 0.529293 -0.738165 +C -0.269185 0.171532 0.562007 +C -0.569120 -1.230091 0.778149 +C 0.171037 1.195853 1.565951 +N -0.206354 2.528855 1.090431 +C -0.321085 2.832592 -0.237700 +C -0.504149 1.898171 -1.188421 +H 1.309913 -0.830485 -0.868006 +H -0.318259 -3.431509 0.581345 +H -0.972513 -0.747064 -2.529442 +H -0.967080 -1.553664 1.727926 +H 1.254733 1.155195 1.732794 +H -0.307140 1.031609 2.534712 +H 0.003515 3.288737 1.708415 +H -0.303212 3.887120 -0.473734 +H -0.625683 2.173759 -2.221885 diff --git a/test/test_calculator.py b/test/test_calculator.py new file mode 100644 index 0000000..38ec8b7 --- /dev/null +++ b/test/test_calculator.py @@ -0,0 +1,104 @@ +import ase.io +from aimnet2calc.calculator import AIMNet2Calculator +import os +import numpy as np + + +MODELS = ('aimnet2', 'aimnet2_b973c') +DIR = os.path.dirname(__file__) + + +def _struct_pbc(): + filename = os.path.join(DIR, '1008775.cif') + atoms = ase.io.read(filename) + ret = dict() + ret['coord'] = atoms.positions + ret['numbers'] = atoms.numbers + ret['charge'] = 0.0 + ret['cell'] = atoms.cell.array + return ret + + +def _struct_list(): + filename = os.path.join(DIR, 'mols_size_var.xyz') + atoms = ase.io.read(filename, index=':') + ret = dict() + ret['coord'] = np.concatenate([a.positions for a in atoms]) + ret['numbers'] = np.concatenate([a.numbers for a in atoms]) + ret['mol_idx'] = np.concatenate([[i] * len(a) for i, a in enumerate(atoms)]) + ret['charge'] = [0.0] * len(atoms) + return ret + + +def _stuct_batch(): + filename = os.path.join(DIR, 'mols_size_36.xyz') + atoms = ase.io.read(filename, index=':') + ret = dict() + ret['coord'] = [a.positions for a in atoms] + ret['numbers'] = [a.numbers for a in atoms] + ret['charge'] = [0.0] * len(atoms) + return ret + + +def _test_energy(calc, data): + _out = calc(data) + assert 'energy' in _out + assert len(_out['energy']) == len(data['charge']) + assert _out['energy'].requires_grad == False + + +def _test_forces(calc, data): + _out = calc(data, forces=True) + assert 'energy' in _out + assert 'forces' in _out + assert len(_out['energy']) == len(data['charge']) + assert _out['energy'].requires_grad == True + assert len(_out['forces']) == len(data['coord']), _out['forces'].shape + assert _out['forces'].requires_grad == False + + +def _test_forces_stress(calc, data): + _out = calc(data, forces=True, stress=True) + assert 'energy' in _out + assert 'forces' in _out + assert 'stress' in _out + assert len(_out['energy']) == len(data['charge']) + assert _out['energy'].requires_grad == True + assert len(_out['forces']) == len(data['coord']) + assert _out['forces'].requires_grad == False + assert len(_out['stress']) == 3 + assert _out['stress'].requires_grad == False + + +def _test_hessian(calc, data): + _out = calc(data, hessian=True) + assert 'energy' in _out + assert 'forces' in _out + assert 'hessian' in _out + assert len(_out['energy']) == len(data['charge']) + assert _out['energy'].requires_grad == True + assert len(_out['forces']) == len(data['coord']) + assert _out['forces'].requires_grad == False + assert len(_out['hessian']) == len(data['coord']) + assert _out['hessian'].requires_grad == False + + +def test_calculator(): + for model in MODELS: + print('Testing model:', model) + calc = AIMNet2Calculator(model) + for data, typ in zip((_stuct_batch(), _struct_list(), _struct_pbc()), ('batch', 'list', 'pbc')): + if typ == 'pbc' and not (calc.cutoff_lr < float('inf')): + print('Skipping PBC with LR') + continue + print('Testing data:', typ) + print('energy: ', _test_energy(calc, data)) + print('forces: ', _test_forces(calc, data)) + if len(data['charge']) == 1: + print('hessian: ', _test_hessian(calc, data)) + if typ == 'pbc': + print('forces+stress: ', _test_forces_stress(calc, data)) + + +if __name__ == '__main__': + test_calculator() \ No newline at end of file From 3a8e983d4ed46a9c6a2ccd584c5988a54c9ebf6b Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 15:42:55 -0400 Subject: [PATCH 02/27] fixes --- .gitignore | 40 +++++++++++++++++++++++++++ aimnet2calc/calculator.py | 57 +++++++++++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 8 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..570674b --- /dev/null +++ b/.gitignore @@ -0,0 +1,40 @@ +# vscode history +.history/ + +# Python +__pycache__/ +*.pyc +*.pyo +*.pyd + +# Virtual Environment +venv/ +env/ +env.bak/ +env1/ +env2/ +.env + +# IDE +.vscode/ +*.code-workspace + +# Logs +*.log + +# Build +build/ +dist/ +*.egg-info/ +*.egg + +# Tests +test-reports/ +.coverage + +# Miscellaneous +*.bak +*.swp +*.tmp +*.tmp.* +*~ \ No newline at end of file diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index 1957ebb..ca7b951 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -10,8 +10,22 @@ class AIMNet2Calculator: A helper class to load AIMNet2 models and perform inference. """ - keys_in = ['coord', 'numbers', 'charge'] - keys_in_optional = ['mult', 'mol_idx', 'nbmat', 'nbmat_lr', 'nb_pad_mask', 'nb_pad_mask_lr', 'shifts', 'cell'] + keys_in = { + 'coord': torch.float, + 'numbers': torch.int, + 'charge': torch.float + } + keys_in_optional = { + 'mult': torch.float, + 'mol_idx': torch.int, + 'nbmat': torch.int, + 'nbmat_lr': torch.int, + 'nb_pad_mask': torch.bool, + 'nb_pad_mask_lr': torch.bool, + 'shifts': torch.float, + 'shifts_lr': torch.float, + 'cell': torch.float + } keys_out = ['energy', 'charges', 'forces'] atom_feature_keys = ['coord', 'numbers', 'charges', 'forces'] @@ -31,11 +45,32 @@ def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): # indicator if input was flattened self._batch = None + # placeholder for tensors that require grad self._saved_for_grad = None + # set flag of current Coulomb model + coul_methods = set(getattr(mod, 'method', None) for mod in iter_lrcoulomb_mods(self.model)) + assert len(coul_methods) == 1, 'Multiple Coulomb methods found.' + self._coulomb_method = coul_methods.pop() def __call__(self, *args, **kwargs): return self.eval(*args, **kwargs) + def set_lrcoulomb_method(self, method, dsf_cutoff=15.0, dsf_alpha=0.2): + assert method in ('simple', 'dsf', 'ewald'), f'Invalid method: {method}' + if method == 'simple': + for mod in iter_lrcoulomb_mods(self.model): + mod.method = 'simple' + self.cutoff_lr = float('inf') + elif method == 'dsf': + for mod in iter_lrcoulomb_mods(self.model): + mod.method = 'dsf' + self.cutoff_lr = dsf_cutoff + mod.dsf_alpha = dsf_alpha + elif method == 'ewald': + for mod in iter_lrcoulomb_mods(self.model): + mod.method = 'ewald' + self._coulomb_method = method + def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]: data = self.prepare_input(data) if hessian and data['mol_idx'][-1] > 0: @@ -50,8 +85,12 @@ def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) def prepare_input(self, data: Dict[str, Any]) -> Dict[str, Tensor]: data = self.to_input_tensors(data) data = self.mol_flatten(data) - if 'cell' in data and data['cell'] is not None and data['mol_idx'][-1] > 0: - raise NotImplementedError('PBC with multiple molecules is not supported') + if data.get('cell') is not None: + if data['mol_idx'][-1] > 0: + raise NotImplementedError('PBC with multiple molecules is not implemented yet.') + if self._coulomb_method == 'simple': + print('Switching to DSF Coulomb for PBC') + self.set_lrcoulomb_method('dsf') data = self.make_nbmat(data) data = self.pad_input(data) return data @@ -67,10 +106,10 @@ def to_input_tensors(self, data: Dict[str, Any]) -> Dict[str, Tensor]: for k in self.keys_in: assert k in data, f'Missing key {k} in the input data' # always detach !! - ret[k] = torch.as_tensor(data[k], device=self.device).detach() + ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in[k]).detach() for k in self.keys_in_optional: if k in data and data[k] is not None: - ret[k] = torch.as_tensor(data[k], device=self.device).detach() + ret[k] = torch.as_tensor(data[k], device=self.device, dtype=self.keys_in_optional[k]).detach() # convert any scalar tensors to shape (1,) tensors for k, v in ret.items(): if v.ndim == 0: @@ -112,12 +151,14 @@ def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: if 'nbmat_lr' not in data: assert self.cutoff_lr < torch.inf, 'Long-range cutoff must be finite for PBC' data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr) + data['cutoff_lr'] = torch.tensor(self.cutoff_lr, device=self.device) else: if 'nbmat' not in data: data['nbmat'] = nblist_torch_cluster(data['coord'], self.cutoff, data['mol_idx'], max_nb=128) if self.lr: if 'nbmat_lr' not in data: data['nbmat_lr'] = nblist_torch_cluster(data['coord'], self.cutoff_lr, data['mol_idx'], max_nb=1024) + data['cutoff_lr'] = torch.tensor(self.cutoff_lr, device=self.device) return data def pad_input(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: @@ -223,7 +264,7 @@ def _named_children_rec(module): yield from _named_children_rec(module) -def set_lrcoulomb_method(model, method): +def iter_lrcoulomb_mods(model): for name, module in _named_children_rec(model): if name == 'lrcoulomb': - module.set_method(method) + yield module From 7eb032cf53fd05031748f2cd31d0d97a9ae97560 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 15:55:07 -0400 Subject: [PATCH 03/27] fix model zoo --- aimnet2calc/models.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/aimnet2calc/models.py b/aimnet2calc/models.py index a6608a3..f14e93f 100644 --- a/aimnet2calc/models.py +++ b/aimnet2calc/models.py @@ -3,13 +3,9 @@ # model registry aliases model_registry_aliases = {} -model_registry_aliases['aimnet2'] = 'aimnet2_wb97m_0_240428' +model_registry_aliases['aimnet2'] = 'aimnet2_wb97m_0' model_registry_aliases['aimnet2_wb97m'] = model_registry_aliases['aimnet2'] -model_registry_aliases['aimnet2_wb97m_ens'] = 'aimnet2_wb97m_ens_240428' -model_registry_aliases['aimnet2_ens'] = model_registry_aliases['aimnet2_wb97m_ens'] -model_registry_aliases['aimnet2_b973c'] = 'aimnet2_b973c_0_240428' -model_registry_aliases['aimnet2_b973c_ens'] = 'aimnet2_b973c_ens_240428' -model_registry_aliases['aimnet2_qr'] = 'aimnet2_qr_b97m_qzvp' +model_registry_aliases['aimnet2_b973c'] = 'aimnet2_b973c_0' def get_model_path(s: str): From 96222c62e175d5dcfc51361872b252fad676052d Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 15:55:07 -0400 Subject: [PATCH 04/27] fix model zoo --- aimnet2calc/models.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/aimnet2calc/models.py b/aimnet2calc/models.py index a6608a3..985f1d1 100644 --- a/aimnet2calc/models.py +++ b/aimnet2calc/models.py @@ -3,13 +3,9 @@ # model registry aliases model_registry_aliases = {} -model_registry_aliases['aimnet2'] = 'aimnet2_wb97m_0_240428' +model_registry_aliases['aimnet2'] = 'aimnet2_wb97m_0' model_registry_aliases['aimnet2_wb97m'] = model_registry_aliases['aimnet2'] -model_registry_aliases['aimnet2_wb97m_ens'] = 'aimnet2_wb97m_ens_240428' -model_registry_aliases['aimnet2_ens'] = model_registry_aliases['aimnet2_wb97m_ens'] -model_registry_aliases['aimnet2_b973c'] = 'aimnet2_b973c_0_240428' -model_registry_aliases['aimnet2_b973c_ens'] = 'aimnet2_b973c_ens_240428' -model_registry_aliases['aimnet2_qr'] = 'aimnet2_qr_b97m_qzvp' +model_registry_aliases['aimnet2_b973c'] = 'aimnet2_b973c_0' def get_model_path(s: str): @@ -23,6 +19,7 @@ def get_model_path(s: str): # add jpt extension if not s.endswith('.jpt'): s = s + '.jpt' + os.makedirs(os.path.join(os.path.dirname(__file__), 'assets'), exist_ok=True) s_local = os.path.join(os.path.dirname(__file__), 'assets', s) if os.path.isfile(s_local): print('Found model file:', s_local) From e9b20956b9c7907dca151734eb7be20eea62cf76 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 16:07:11 -0400 Subject: [PATCH 05/27] fix --- aimnet2calc/calculator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index ca7b951..f87acfb 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -26,7 +26,7 @@ class AIMNet2Calculator: 'shifts_lr': torch.float, 'cell': torch.float } - keys_out = ['energy', 'charges', 'forces'] + keys_out = ['energy', 'charges', 'forces', 'hessian', 'stress'] atom_feature_keys = ['coord', 'numbers', 'charges', 'forces'] def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): @@ -182,7 +182,7 @@ def set_grad_tensors(self, data: Dict[str, Tensor], forces=False, stress=False, self._saved_for_grad['coord'] = data['coord'] if stress: assert 'cell' in data, 'Stress calculation requires cell' - scaling = torch.eye(3, requires_grad=True, dtype=data['cell'].dtype, device=data['cell'].dtype) + scaling = torch.eye(3, requires_grad=True, dtype=data['cell'].dtype, device=data['cell'].device) data['coord'] = data['coord'] @ scaling data['cell'] = data['cell'] @ scaling self._saved_for_grad['scaling'] = scaling From fe031a9a48f8af4fdca3fef845be6a4d3a7ceab5 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 16:55:22 -0400 Subject: [PATCH 06/27] fix jit.optimized_execution --- aimnet2calc/calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index f87acfb..ad96128 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -76,7 +76,7 @@ def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) if hessian and data['mol_idx'][-1] > 0: raise NotImplementedError('Hessian calculation is not supported for multiple molecules') data = self.set_grad_tensors(data, forces=forces, stress=stress, hessian=hessian) - with torch.jit.optimized_execution(True): + with torch.jit.optimized_execution(False): data = self.model(data) data = self.get_derivatives(data, forces=forces, stress=stress, hessian=hessian) data = self.process_output(data) From e0dea01ea213922d134885f768e0f4162ebb5a39 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 18:31:16 -0400 Subject: [PATCH 07/27] add pysisiphus interface --- aimnet2calc/aimnet2ase.py | 2 +- aimnet2calc/aimnet2pysis.py | 67 +++++++++++++++++++++++++++++++++++++ test/mol_single.xyz | 24 +++++++++++++ test/mol_single_opt.yml | 12 +++++++ 4 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 aimnet2calc/aimnet2pysis.py create mode 100644 test/mol_single.xyz create mode 100644 test/mol_single_opt.yml diff --git a/aimnet2calc/aimnet2ase.py b/aimnet2calc/aimnet2ase.py index dd6031e..7429b4c 100644 --- a/aimnet2calc/aimnet2ase.py +++ b/aimnet2calc/aimnet2ase.py @@ -6,7 +6,7 @@ class AIMNet2ASE(Calculator): implemented_properties = ['energy', 'forces', 'free_energy', 'charges', 'stress'] - def __init__(self, base_calc: Union[AIMNet2Calculator, str], charge=0, mult=1): + def __init__(self, base_calc: Union[AIMNet2Calculator, str] = 'aimnet2', charge=0, mult=1): super().__init__() if isinstance(base_calc, str): base_calc = AIMNet2Calculator(base_calc) diff --git a/aimnet2calc/aimnet2pysis.py b/aimnet2calc/aimnet2pysis.py new file mode 100644 index 0000000..07b31d7 --- /dev/null +++ b/aimnet2calc/aimnet2pysis.py @@ -0,0 +1,67 @@ +from pysisyphus.calculators.Calculator import Calculator +from pysisyphus.elem_data import ATOMIC_NUMBERS +from pysisyphus.constants import BOHR2ANG, ANG2BOHR, AU2EV +from aimnet2calc import AIMNet2Calculator +from typing import Union +import torch + + +EV2AU = 1 / AU2EV + + +class AIMNet2Pysis(Calculator): + implemented_properties = ['energy', 'forces', 'free_energy', 'charges', 'stress'] + def __init__(self, model: Union[AIMNet2Calculator, str] = 'aimnet2', charge=0, mult=1, **kwargs): + super().__init__(charge=charge, mult=mult, **kwargs) + if isinstance(model, str): + model = AIMNet2Calculator(model) + self.model = model + + def _prepere_input(self, atoms, coord): + device = self.base_calc.device + numbers = torch.as_tensor([ATOMIC_NUMBERS[a.lower()] for a in atoms], device=device) + coord = torch.as_tensor(coord, dtype=torch.float, device=device).view(-1, 3) * BOHR2ANG + charge = torch.as_tensor([self.charge], dtype=torch.float, device=device) + mult = torch.as_tensor([self.mult], dtype=torch.float, device=device) + return dict(coord=coord, numbers=numbers, charge=charge, mult=mult) + + @staticmethod + def _results_get_energy(results): + return results['energy'].item() * EV2AU + + @staticmethod + def _results_get_forces(results): + return (results['forces'].detach() * (EV2AU / ANG2BOHR)).flatten().to(torch.double).cpu().numpy() + + @staticmethod + def _results_get_hessian(results): + return (results['hessian'].flatten(0, 1).flatten(-2, -1) * (EV2AU / ANG2BOHR / ANG2BOHR)).to(torch.double).cpu().numpy() + + + def get_energy(self, atoms, coords): + _in = self._prepere_input(atoms, coords) + res = self.model(_in) + energy = self._results_get_energy(res) + return dict(energy=energy) + + def get_forces(self, atoms, coords): + _in = self._prepere_input(atoms, coords) + res = self.model(_in, forces=True) + energy = self._results_get_energy(res) + forces = self._results_get_forces(res) + return dict(energy=energy, forces=forces) + + def get_hessian(self, atoms, coords): + _in = self._prepere_input(atoms, coords) + res = self.model(_in, forces=True, hessian=True) + energy = self._results_get_energy(res) + forces = self._results_get_forces(res) + hessian = self._results_get_hessian(res) + return dict(energy=energy, forces=forces, hessian=hessian) + + +def run_pysis(): + from pysisyphus import run + run.CALC_DICT['aimnet'] = AIMNet2Pysis + run.run() + diff --git a/test/mol_single.xyz b/test/mol_single.xyz new file mode 100644 index 0000000..ad2b77d --- /dev/null +++ b/test/mol_single.xyz @@ -0,0 +1,24 @@ +22 +a49d311aff2713cbe7d857cf +O 1.928443 -1.565235 -0.983191 +C 1.360832 -2.685709 -0.521925 +O 1.935457 -3.735882 -0.511664 +C -0.042404 -2.560774 0.003075 +C -1.021272 -1.688218 -0.555504 +C -0.939070 -0.664086 -1.455165 +C -0.597411 0.529293 -0.738165 +C -0.269185 0.171532 0.562007 +C -0.569120 -1.230091 0.778149 +C 0.171037 1.195853 1.565951 +N -0.206354 2.528855 1.090431 +C -0.321085 2.832592 -0.237700 +C -0.504149 1.898171 -1.188421 +H 1.309913 -0.830485 -0.868006 +H -0.318259 -3.431509 0.581345 +H -0.972513 -0.747064 -2.529442 +H -0.967080 -1.553664 1.727926 +H 1.254733 1.155195 1.732794 +H -0.307140 1.031609 2.534712 +H 0.003515 3.288737 1.708415 +H -0.303212 3.887120 -0.473734 +H -0.625683 2.173759 -2.221885 diff --git a/test/mol_single_opt.yml b/test/mol_single_opt.yml new file mode 100644 index 0000000..b4d086c --- /dev/null +++ b/test/mol_single_opt.yml @@ -0,0 +1,12 @@ +calc: + type: aimnet + base_calc: aimnet2_b973c + +geom: + type: dlc + fn: mol_single.xyz + +opt: + type: rfo + max_cycles: 50 + From 8fc04e7d1b2e74106123812b59b4dfe0da7b46fb Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 18:37:55 -0400 Subject: [PATCH 08/27] add setup.py --- setup.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..06df753 --- /dev/null +++ b/setup.py @@ -0,0 +1,30 @@ +from setuptools import setup, find_packages + +setup( + name='aimnet2calc', + version='0.0.1', + author='Roman Zubatyuk', + author_email='zubatyuk@gmail.com', + description='Interface for AIMNet2 models', + packages=find_packages(), + install_requires=[ + 'torch>2.0,<3', + 'torch_cluster', + 'numpy', + 'numba', + 'ase', + 'pysisyphus', + 'requests', + # 'openbabel' + ], + classifiers=[ + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: MIT License', + 'Operating System :: OS Independent', + ], + entry_points={ + 'console_scripts': [ + 'aimnet2pysis=aimnet2calc.aimnet2pysis:run_pysis' + ], + }, +) \ No newline at end of file From c18ce04fa30453d0bead92889a7771799e9d1744 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 19:40:41 -0400 Subject: [PATCH 09/27] fix --- aimnet2calc/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/aimnet2calc/__init__.py b/aimnet2calc/__init__.py index 99f862c..5a59e3a 100644 --- a/aimnet2calc/__init__.py +++ b/aimnet2calc/__init__.py @@ -9,4 +9,11 @@ warnings.warn('ASE is not installed. AIMNet2ASE will not be available.') pass +try: + from .aimnet2pysis import AIMNet2Pysis + __all__.append('AIMNet2Pysis') +except ImportError: + import warnings + warnings.warn('PySisiphus is not installed. AIMNet2Pysis will not be available.') + pass From b029813d32915a41e993d51778012b684d4a35a9 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 20:03:50 -0400 Subject: [PATCH 10/27] add README --- README.md | 127 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..fab3a6b --- /dev/null +++ b/README.md @@ -0,0 +1,127 @@ +# AIMNet2 Calculator: Fast, Accurate Molecular Simulations + +This package integrates the powerful AIMNet2 neural network potential into your simulation workflows. AIMNet2 provides fast and reliable energy, force, and property calculations for molecules containing a diverse range of elements. + +## Key Features: + +- **Accurate and Versatile:** AIMNet2 excels at modeling neutral, charged, organic, and elemental-organic systems. +- **Flexible Interfaces:** Use AIMNet2 through convenient calculators for popular simulation packages like ASE and PySisiphus. +- **Flexible Long-Range Interactions:** Optionally employ the Dumped-Shifted Force (DSF) or Ewald summation Coulomb models for accurate calculations in large or periodic systems. + + +## Getting Started + +### 1. Installation + +While package is in alpha stage and repository is private, please install manually with +``` +git clone git@github.com:zubatyuk/aimnet2calc.git +cd aimnet2calc +python setup.py install +``` + +### 2. Available interfaces + +#### ASE [[https://wiki.fysik.dtu.dk/ase]](https://wiki.fysik.dtu.dk/ase) + +``` +from aimnet2calc import AIMNet2ASE +calc = AIMNet2ASE('aimnet2') +``` + +To specify total molecular charge and spin multiplicity, use optional `charge` and `mult` keyword arguments, or `set_charge` and `set_mult` methods: + +``` +calc = AIMNet2ASE('aimnet2', charge=1) +atoms1.calc = calc +# calculations on atoms1 will be done with charge 1 +.... +atoms2.calc = calc +calc.set_charge(-2) +# calculations on atoms1 will be done with charge -2 +``` + +#### PySisiphus [[https://pysisyphus.readthedocs.io]](https://pysisyphus.readthedocs.io/) + +``` +from aimnet2calc import AIMNet2Pysis +calc = AIMNet2Pysis('aimnet2') +``` + +This produces standard PySisiphus calculator. + +Instead of `pysis` command line utility, use `aimnet2pysis`. This registeres AIMNet2 calculator with PySisiphus. +Example `calc` section for PySisiphus YAML files: + +``` +calc: + type: aimnet # use AIMNet2 calculator + model: aimnet2_b973c # use aimnet2_b973c_0.jpt model +``` + +### 3. Base calculator + +``` +from aimnet2calc import AIMNetCalculator +``` + +#### Initialization + +``` +calc = AIMNetCalculator('aimnet2') +``` +will load default AIMNet2 model aimnet2_wb97m_0.jpt as defined at `aimnet2calc/models.py` . If file does not exist on the machine, it will be downloaded from [aimnet-model-zoo](http://github.com/zubatyuk/aimnet-model-zoo) repository. + +``` +calc = AIMNetCalculator('/path/to_a/model.jpt') +``` +will load model from the file. + +#### Input structure + +The calculator accepts a dictionary containig lists, numpy arrays, torch tensors, or anything that could be accepted by `torch.as_tensor`. + +The input could be for a single molecule (dict keys and shapes): + +``` +coord: (B, N, 3) # atomic coordinates in Angstrom +numbers (B, N) # atomic numbers +charge (N,) # molecular charge +mult (N,) # spin multiplicity, optional +``` + +or for a concatenation of molecules: + +``` +coord: (N, 3) # atomic coordinates in Angstrom +numbers (N,) # atomic numbers +charge (B,) # molecular charge +mult (B,) # spin multiplicity, optional +mol_idx (N,) # molecule index for each atom, should contain integers in increasing order. +``` + +where `B` is the number of molecules, `N` is number of atoms. + + +#### Calling calculator + +``` +results = calc(data, forces=False, stress=False, hessian=False) +``` + +`results` would be a dictionary of PyTorch tensors containing `energy`, `charges`, and possibly `forces`, `stress` and `hessian` if requested. + +### 4. Long range Coulomb model + +By default, Coulomb energy is calculated in O(N^2) manner, e.g. pair interaction between every pair of atoms in system. For very large or periodic systems, O(N) Dumped-Shifted Force Coulomb model could be employed [doi: 10.1063/1.2206581](https://doi.org/10.1063/1.2206581). With `AIMNetCalculator` interface, switch between standard and DSF Coulomb implementations im AIMNet2 models: + +``` +# switch to O(N^2) +calc.set_lrcoulomb_method('sdf', dsf_cutoff=15.0, dsf_alpha=0.2) +# switch to O(N) +calc.set_lrcoulomb_method('simple') +``` + + + + From 9646e33a4a0c04e4f3f0022ee95d69f699dcf35b Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 20:06:11 -0400 Subject: [PATCH 11/27] fix pysis yaml --- test/mol_single_opt.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mol_single_opt.yml b/test/mol_single_opt.yml index b4d086c..f35dbf1 100644 --- a/test/mol_single_opt.yml +++ b/test/mol_single_opt.yml @@ -1,6 +1,6 @@ calc: type: aimnet - base_calc: aimnet2_b973c + model: aimnet2_b973c geom: type: dlc From 5f6e3808cdd9e065a3e81b4a852c5646f6202cec Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 20:44:18 -0400 Subject: [PATCH 12/27] update readme --- README.md | 27 ++++++++++++++++++--------- setup.py | 2 +- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index fab3a6b..04f4759 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ This package integrates the powerful AIMNet2 neural network potential into your ## Key Features: - **Accurate and Versatile:** AIMNet2 excels at modeling neutral, charged, organic, and elemental-organic systems. -- **Flexible Interfaces:** Use AIMNet2 through convenient calculators for popular simulation packages like ASE and PySisiphus. +- **Flexible Interfaces:** Use AIMNet2 through convenient calculators for popular simulation packages like ASE and PySisyphus. - **Flexible Long-Range Interactions:** Optionally employ the Dumped-Shifted Force (DSF) or Ewald summation Coulomb models for accurate calculations in large or periodic systems. @@ -13,8 +13,17 @@ This package integrates the powerful AIMNet2 neural network potential into your ### 1. Installation -While package is in alpha stage and repository is private, please install manually with -``` +While package is in alpha stage and repository is private, please install into your conda envoronment manually with +``` +# install requirements +conda install -y pytorch pytorch-cuda=12.1 -c pytorch -c nvidia +conda install -y -c pyg pytorch-cluster +conda install -y -c conda-forge openbabel ase +## pysis requirwements +conda install -y -c conda-forge autograd dask distributed h5py fabric jinja2 joblib matplotlib numpy natsort psutil pyyaml rmsd scipy sympy scikit-learn +# now should not do any pip installs +pip install git+https://github.com/eljost/pysisyphus.git +# finally, this repo git clone git@github.com:zubatyuk/aimnet2calc.git cd aimnet2calc python setup.py install @@ -41,17 +50,17 @@ calc.set_charge(-2) # calculations on atoms1 will be done with charge -2 ``` -#### PySisiphus [[https://pysisyphus.readthedocs.io]](https://pysisyphus.readthedocs.io/) +#### PySisyphus [[https://pysisyphus.readthedocs.io]](https://pysisyphus.readthedocs.io/) ``` -from aimnet2calc import AIMNet2Pysis -calc = AIMNet2Pysis('aimnet2') +from aimnet2calc import AIMNet2PySis +calc = AIMNet2PySis('aimnet2') ``` -This produces standard PySisiphus calculator. +This produces standard PySisyphus calculator. -Instead of `pysis` command line utility, use `aimnet2pysis`. This registeres AIMNet2 calculator with PySisiphus. -Example `calc` section for PySisiphus YAML files: +Instead of `Pysis` command line utility, use `aimnet2Pysis`. This registeres AIMNet2 calculator with PySisyphus. +Example `calc` section for PySisyphus YAML files: ``` calc: diff --git a/setup.py b/setup.py index 06df753..6e23ee9 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ packages=find_packages(), install_requires=[ 'torch>2.0,<3', - 'torch_cluster', + 'torch-cluster', 'numpy', 'numba', 'ase', From 1bccd014b8a90280eb2b6e9c1ad272c6f3e90a07 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Sat, 4 May 2024 20:50:28 -0400 Subject: [PATCH 13/27] update readme --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 04f4759..804220a 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ While package is in alpha stage and repository is private, please install into y conda install -y pytorch pytorch-cuda=12.1 -c pytorch -c nvidia conda install -y -c pyg pytorch-cluster conda install -y -c conda-forge openbabel ase -## pysis requirwements +## pysis requirements conda install -y -c conda-forge autograd dask distributed h5py fabric jinja2 joblib matplotlib numpy natsort psutil pyyaml rmsd scipy sympy scikit-learn # now should not do any pip installs pip install git+https://github.com/eljost/pysisyphus.git @@ -59,7 +59,7 @@ calc = AIMNet2PySis('aimnet2') This produces standard PySisyphus calculator. -Instead of `Pysis` command line utility, use `aimnet2Pysis`. This registeres AIMNet2 calculator with PySisyphus. +Instead of `Pysis` command line utility, use `aimnet2pysis`. This registeres AIMNet2 calculator with PySisyphus. Example `calc` section for PySisyphus YAML files: ``` @@ -95,8 +95,8 @@ The input could be for a single molecule (dict keys and shapes): ``` coord: (B, N, 3) # atomic coordinates in Angstrom numbers (B, N) # atomic numbers -charge (N,) # molecular charge -mult (N,) # spin multiplicity, optional +charge (B,) # molecular charge +mult (B,) # spin multiplicity, optional ``` or for a concatenation of molecules: @@ -106,7 +106,7 @@ coord: (N, 3) # atomic coordinates in Angstrom numbers (N,) # atomic numbers charge (B,) # molecular charge mult (B,) # spin multiplicity, optional -mol_idx (N,) # molecule index for each atom, should contain integers in increasing order. +mol_idx (N,) # molecule index for each atom, should contain integers in increasing order, with (B-1) is the maximum number. ``` where `B` is the number of molecules, `N` is number of atoms. From 75c4a081d613a705518abf69eba62a79516d446d Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Mon, 6 May 2024 14:07:42 -0400 Subject: [PATCH 14/27] fix --- aimnet2calc/aimnet2ase.py | 2 +- aimnet2calc/calculator.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/aimnet2calc/aimnet2ase.py b/aimnet2calc/aimnet2ase.py index 7429b4c..7020d77 100644 --- a/aimnet2calc/aimnet2ase.py +++ b/aimnet2calc/aimnet2ase.py @@ -48,7 +48,7 @@ def calculate(self, atoms=None, properties=['energy'], system_changes=all_change self.uptade_tensors() if self.atoms.cell is not None and self.atoms.pbc.any(): - assert self.base_calc.cutoff_lr < float('inf'), 'Long-range cutoff must be finite for PBC' + #assert self.base_calc.cutoff_lr < float('inf'), 'Long-range cutoff must be finite for PBC' cell = self.atoms.cell.array else: cell = None diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index ad96128..1d5391c 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -49,8 +49,11 @@ def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): self._saved_for_grad = None # set flag of current Coulomb model coul_methods = set(getattr(mod, 'method', None) for mod in iter_lrcoulomb_mods(self.model)) - assert len(coul_methods) == 1, 'Multiple Coulomb methods found.' - self._coulomb_method = coul_methods.pop() + assert len(coul_methods) <= 1, 'Multiple Coulomb methods found.' + if len(coul_methods): + self._coulomb_method = coul_methods.pop() + else: + self._coulomb_method = None def __call__(self, *args, **kwargs): return self.eval(*args, **kwargs) From 619aebcf2a35dc102d395b6c5292027b7b53983d Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Wed, 8 May 2024 21:38:43 -0400 Subject: [PATCH 15/27] add qr model --- aimnet2calc/models.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/aimnet2calc/models.py b/aimnet2calc/models.py index 985f1d1..f8f3522 100644 --- a/aimnet2calc/models.py +++ b/aimnet2calc/models.py @@ -3,9 +3,10 @@ # model registry aliases model_registry_aliases = {} -model_registry_aliases['aimnet2'] = 'aimnet2_wb97m_0' +model_registry_aliases['aimnet2'] = 'aimnet2/aimnet2_wb97m_0' model_registry_aliases['aimnet2_wb97m'] = model_registry_aliases['aimnet2'] -model_registry_aliases['aimnet2_b973c'] = 'aimnet2_b973c_0' +model_registry_aliases['aimnet2_b973c'] = 'aimnet2/aimnet2_b973c_0' +model_registry_aliases['aimnet2-qr'] = 'aimnet2-qr/aimnet2-qr_b97md4_qzvp_2' def get_model_path(s: str): @@ -19,12 +20,13 @@ def get_model_path(s: str): # add jpt extension if not s.endswith('.jpt'): s = s + '.jpt' - os.makedirs(os.path.join(os.path.dirname(__file__), 'assets'), exist_ok=True) + sdir = os.path.dirname(s) + os.makedirs(os.path.join(os.path.dirname(__file__), 'assets', sdir), exist_ok=True) s_local = os.path.join(os.path.dirname(__file__), 'assets', s) if os.path.isfile(s_local): print('Found model file:', s_local) else: - url = f'https://github.com/zubatyuk/aimnet-model-zoo/raw/main/aimnet2/{s}' + url = f'https://github.com/zubatyuk/aimnet-model-zoo/raw/main/{s}' print('Downloading model file from', url) r = requests.get(url) r.raise_for_status() From 09f9dcccb3c2871c8dc4d96b3c2eb229a4cd1ec8 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Wed, 8 May 2024 23:41:55 -0400 Subject: [PATCH 16/27] add implemented_species --- aimnet2calc/aimnet2ase.py | 9 +++++++++ aimnet2calc/calculator.py | 1 + 2 files changed, 10 insertions(+) diff --git a/aimnet2calc/aimnet2ase.py b/aimnet2calc/aimnet2ase.py index 7020d77..33da652 100644 --- a/aimnet2calc/aimnet2ase.py +++ b/aimnet2calc/aimnet2ase.py @@ -2,6 +2,7 @@ from aimnet2calc import AIMNet2Calculator from typing import Union import torch +import numpy as np class AIMNet2ASE(Calculator): @@ -14,6 +15,12 @@ def __init__(self, base_calc: Union[AIMNet2Calculator, str] = 'aimnet2', charge= self.charge = charge self.mult = mult self.do_reset() + # list of implemented species + if hasattr(base_calc, 'implemented_species'): + self.implemented_species = base_calc.implemented_species.cpu().numpy() + else: + self.implemented_species = None + def do_reset(self): self._t_numbers = None @@ -24,6 +31,8 @@ def do_reset(self): self.mult = 1.0 def set_atoms(self, atoms): + if self.implemented_species is not None and not np.in1d(atoms.numbers, self.implemented_species).all(): + raise ValueError('Some species are not implemented in the AIMNet2Calculator') self.atoms = atoms self.do_reset() diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index 1d5391c..6a2a6b1 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -271,3 +271,4 @@ def iter_lrcoulomb_mods(model): for name, module in _named_children_rec(model): if name == 'lrcoulomb': yield module + From 038f6f67b425ee10b17bd501eab86c7eeda98cfe Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Wed, 8 May 2024 23:55:12 -0400 Subject: [PATCH 17/27] Update setup.py --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 6e23ee9..f311e43 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ 'numpy', 'numba', 'ase', - 'pysisyphus', + # 'pysisyphus', 'requests', # 'openbabel' ], @@ -27,4 +27,4 @@ 'aimnet2pysis=aimnet2calc.aimnet2pysis:run_pysis' ], }, -) \ No newline at end of file +) From acb6ba14ed7ef8116010bc56169d3076fd111459 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Tue, 14 May 2024 17:25:35 -0400 Subject: [PATCH 18/27] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 804220a..6ecafa3 100644 --- a/README.md +++ b/README.md @@ -125,9 +125,9 @@ results = calc(data, forces=False, stress=False, hessian=False) By default, Coulomb energy is calculated in O(N^2) manner, e.g. pair interaction between every pair of atoms in system. For very large or periodic systems, O(N) Dumped-Shifted Force Coulomb model could be employed [doi: 10.1063/1.2206581](https://doi.org/10.1063/1.2206581). With `AIMNetCalculator` interface, switch between standard and DSF Coulomb implementations im AIMNet2 models: ``` -# switch to O(N^2) -calc.set_lrcoulomb_method('sdf', dsf_cutoff=15.0, dsf_alpha=0.2) # switch to O(N) +calc.set_lrcoulomb_method('sdf', dsf_cutoff=15.0, dsf_alpha=0.2) +# switch to O(N^2), not suitable for PBC calc.set_lrcoulomb_method('simple') ``` From b62034e4095b4be884d4ad8179b704d6e13ef8f3 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Wed, 15 May 2024 15:30:30 -0400 Subject: [PATCH 19/27] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6ecafa3..02bc85b 100644 --- a/README.md +++ b/README.md @@ -71,18 +71,18 @@ calc: ### 3. Base calculator ``` -from aimnet2calc import AIMNetCalculator +from aimnet2calc import AIMNet2Calculator ``` #### Initialization ``` -calc = AIMNetCalculator('aimnet2') +calc = AIMNet2Calculator('aimnet2') ``` will load default AIMNet2 model aimnet2_wb97m_0.jpt as defined at `aimnet2calc/models.py` . If file does not exist on the machine, it will be downloaded from [aimnet-model-zoo](http://github.com/zubatyuk/aimnet-model-zoo) repository. ``` -calc = AIMNetCalculator('/path/to_a/model.jpt') +calc = AIMNet2Calculator('/path/to_a/model.jpt') ``` will load model from the file. @@ -122,7 +122,7 @@ results = calc(data, forces=False, stress=False, hessian=False) ### 4. Long range Coulomb model -By default, Coulomb energy is calculated in O(N^2) manner, e.g. pair interaction between every pair of atoms in system. For very large or periodic systems, O(N) Dumped-Shifted Force Coulomb model could be employed [doi: 10.1063/1.2206581](https://doi.org/10.1063/1.2206581). With `AIMNetCalculator` interface, switch between standard and DSF Coulomb implementations im AIMNet2 models: +By default, Coulomb energy is calculated in O(N^2) manner, e.g. pair interaction between every pair of atoms in system. For very large or periodic systems, O(N) Dumped-Shifted Force Coulomb model could be employed [doi: 10.1063/1.2206581](https://doi.org/10.1063/1.2206581). With `AIMNet2Calculator` interface, switch between standard and DSF Coulomb implementations im AIMNet2 models: ``` # switch to O(N) From 63a626683b37204bedc7a00563a6ddc50999398b Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Mon, 10 Jun 2024 07:22:09 -0400 Subject: [PATCH 20/27] lr fix --- README.md | 2 +- aimnet2calc/calculator.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 02bc85b..3aac3f8 100644 --- a/README.md +++ b/README.md @@ -126,7 +126,7 @@ By default, Coulomb energy is calculated in O(N^2) manner, e.g. pair interaction ``` # switch to O(N) -calc.set_lrcoulomb_method('sdf', dsf_cutoff=15.0, dsf_alpha=0.2) +calc.set_lrcoulomb_method('dsf', cutoff=15.0, dsf_alpha=0.2) # switch to O(N^2), not suitable for PBC calc.set_lrcoulomb_method('simple') ``` diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index 6a2a6b1..011a4b2 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -49,7 +49,7 @@ def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): self._saved_for_grad = None # set flag of current Coulomb model coul_methods = set(getattr(mod, 'method', None) for mod in iter_lrcoulomb_mods(self.model)) - assert len(coul_methods) <= 1, 'Multiple Coulomb methods found.' + assert len(coul_methods) <= 1, 'Multiple Coulomb modules found.' if len(coul_methods): self._coulomb_method = coul_methods.pop() else: @@ -58,7 +58,7 @@ def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): def __call__(self, *args, **kwargs): return self.eval(*args, **kwargs) - def set_lrcoulomb_method(self, method, dsf_cutoff=15.0, dsf_alpha=0.2): + def set_lrcoulomb_method(self, method, cutoff=15.0, dsf_alpha=0.2): assert method in ('simple', 'dsf', 'ewald'), f'Invalid method: {method}' if method == 'simple': for mod in iter_lrcoulomb_mods(self.model): @@ -67,11 +67,12 @@ def set_lrcoulomb_method(self, method, dsf_cutoff=15.0, dsf_alpha=0.2): elif method == 'dsf': for mod in iter_lrcoulomb_mods(self.model): mod.method = 'dsf' - self.cutoff_lr = dsf_cutoff + self.cutoff_lr = cutoff mod.dsf_alpha = dsf_alpha elif method == 'ewald': for mod in iter_lrcoulomb_mods(self.model): mod.method = 'ewald' + self.cutoff_lr = cutoff self._coulomb_method = method def eval(self, data: Dict[str, Any], forces=False, stress=False, hessian=False) -> Dict[str, Tensor]: From 625be194b143f160d24e0a2e0b99eedacf346109 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Mon, 24 Jun 2024 16:15:43 -0400 Subject: [PATCH 21/27] Update nblist.py --- aimnet2calc/nblist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py index 17f5015..183cc6d 100644 --- a/aimnet2calc/nblist.py +++ b/aimnet2calc/nblist.py @@ -31,7 +31,7 @@ def nblist_torch_cluster(coord: Tensor, cutoff: float, mol_idx: Optional[Tensor] assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' assert coord.shape[0] < 2147483646, 'Too many atoms, max supported is 2147483646' max_num_neighbors = max_nb - while max_num_neighbors == max_nb: + while max_num_neighbors <= max_nb: sparse_nb = radius_graph(coord, batch=mol_idx, r=cutoff, max_num_neighbors=max_nb).to(torch.int32) max_num_neighbors = torch.unique(sparse_nb[0], return_counts=True)[1].max().item() max_nb *= 2 From f581b635c40bc99d020e586018c9f6e3f899111e Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Mon, 8 Jul 2024 11:42:27 -0400 Subject: [PATCH 22/27] Update nblist.py --- aimnet2calc/nblist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py index 183cc6d..92015e9 100644 --- a/aimnet2calc/nblist.py +++ b/aimnet2calc/nblist.py @@ -31,9 +31,11 @@ def nblist_torch_cluster(coord: Tensor, cutoff: float, mol_idx: Optional[Tensor] assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' assert coord.shape[0] < 2147483646, 'Too many atoms, max supported is 2147483646' max_num_neighbors = max_nb - while max_num_neighbors <= max_nb: + while True: sparse_nb = radius_graph(coord, batch=mol_idx, r=cutoff, max_num_neighbors=max_nb).to(torch.int32) max_num_neighbors = torch.unique(sparse_nb[0], return_counts=True)[1].max().item() + if max_num_neighbors < max_nb: + break max_nb *= 2 #assert max_num_neighbors < max_nb, f'Increase max_nb in nblist_torch_cluster (current value {max_nb}, cutoff {cutoff})' sparse_nb_half = sparse_nb[:, sparse_nb[0] > sparse_nb[1]] From 93be682f2bfb817b93303f8c292bfc604366fed0 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Mon, 8 Jul 2024 12:07:52 -0400 Subject: [PATCH 23/27] Update aimnet2pysis.py --- aimnet2calc/aimnet2pysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aimnet2calc/aimnet2pysis.py b/aimnet2calc/aimnet2pysis.py index 07b31d7..b400a73 100644 --- a/aimnet2calc/aimnet2pysis.py +++ b/aimnet2calc/aimnet2pysis.py @@ -18,7 +18,7 @@ def __init__(self, model: Union[AIMNet2Calculator, str] = 'aimnet2', charge=0, m self.model = model def _prepere_input(self, atoms, coord): - device = self.base_calc.device + device = self.model.device numbers = torch.as_tensor([ATOMIC_NUMBERS[a.lower()] for a in atoms], device=device) coord = torch.as_tensor(coord, dtype=torch.float, device=device).view(-1, 3) * BOHR2ANG charge = torch.as_tensor([self.charge], dtype=torch.float, device=device) From a56ade16193daa0f0d5d89180538e15d94aee659 Mon Sep 17 00:00:00 2001 From: zubatyuk Date: Thu, 11 Jul 2024 09:44:08 -0400 Subject: [PATCH 24/27] Update nblist.py --- aimnet2calc/nblist.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py index 92015e9..60e0977 100644 --- a/aimnet2calc/nblist.py +++ b/aimnet2calc/nblist.py @@ -33,7 +33,10 @@ def nblist_torch_cluster(coord: Tensor, cutoff: float, mol_idx: Optional[Tensor] max_num_neighbors = max_nb while True: sparse_nb = radius_graph(coord, batch=mol_idx, r=cutoff, max_num_neighbors=max_nb).to(torch.int32) - max_num_neighbors = torch.unique(sparse_nb[0], return_counts=True)[1].max().item() + nnb = torch.unique(sparse_nb[0], return_counts=True)[1] + if nnb.numel() == 0: + break + max_num_neighbors = nnb.max().item() if max_num_neighbors < max_nb: break max_nb *= 2 From 6ac8dbe8cfebe7dd898e0b92436feeaacc5829e8 Mon Sep 17 00:00:00 2001 From: MorrowChem Date: Wed, 14 Aug 2024 11:54:17 -0400 Subject: [PATCH 25/27] merge --- README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/README.md b/README.md index 505cbdd..de8e415 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,7 @@ -<<<<<<< HEAD -======= **__ Update 6/10/24 __** We release new code, suaitable for large molecules and perioric calculations. Old code available in the **old** branch. Models were re-compiled and are not compatible with the new code. ->>>>>>> 90a0e3a3f43955a4ff14ec70ce84c48aab5430ce # AIMNet2 Calculator: Fast, Accurate Molecular Simulations This package integrates the powerful AIMNet2 neural network potential into your simulation workflows. AIMNet2 provides fast and reliable energy, force, and property calculations for molecules containing a diverse range of elements. From 401405ecbb20b1cd62043d667020a3715d409acd Mon Sep 17 00:00:00 2001 From: MorrowChem Date: Wed, 14 Aug 2024 12:15:33 -0400 Subject: [PATCH 26/27] alternative impl of neighbourlist building added --- aimnet2calc/calculator.py | 11 +++- aimnet2calc/nblist.py | 131 +++++++++++++++++++++++++++++++++----- 2 files changed, 123 insertions(+), 19 deletions(-) diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index 011a4b2..5c77e53 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -29,7 +29,7 @@ class AIMNet2Calculator: keys_out = ['energy', 'charges', 'forces', 'hessian', 'stress'] atom_feature_keys = ['coord', 'numbers', 'charges', 'forces'] - def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): + def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2', max_nb_lr_guess=1024): self.device = 'cuda' if torch.cuda.is_available() else 'cpu' if isinstance(model, str): p = get_model_path(model) @@ -42,6 +42,7 @@ def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): self.cutoff = self.model.cutoff self.lr = hasattr(self.model, 'cutoff_lr') self.cutoff_lr = getattr(self.model, 'cutoff_lr', float('inf')) + self.max_nb_lr_guess=max_nb_lr_guess # indicator if input was flattened self._batch = None @@ -149,12 +150,16 @@ def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: assert data['cell'].ndim == 2, 'Expected 2D tensor for cell' if 'nbmat' not in data: data['coord'] = move_coord_to_cell(data['coord'], data['cell']) - mat_idxj, mat_pad, mat_S = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff) + mat_idxj, mat_pad, mat_S = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff, max_nb=128) data['nbmat'], data['nb_pad_mask'], data['shifts'] = mat_idxj, mat_pad, mat_S if self.lr: if 'nbmat_lr' not in data: assert self.cutoff_lr < torch.inf, 'Long-range cutoff must be finite for PBC' - data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr) + data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr, + max_nb=self.max_nb_lr_guess) + max_nb = torch.sum(data['nb_pad_mask_lr'], axis=1).max() + if max_nb > self.max_nb_lr_guess: + self.max_nb_lr_guess = int(max_nb * 1.05) data['cutoff_lr'] = torch.tensor(self.cutoff_lr, device=self.device) else: if 'nbmat' not in data: diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py index 623ff6e..5754e35 100644 --- a/aimnet2calc/nblist.py +++ b/aimnet2calc/nblist.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from typing import Optional, Tuple -from torch_cluster import radius_graph +from torch_cluster import radius_graph, radius import numba try: # optionaly use numba cuda @@ -87,28 +87,138 @@ def _cuda_dense_nb_mat_sft(conn_matrix, mat_idxj, mat_pad, mat_S_idx): k += 1 -def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]: +def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float, max_nb=48) -> Tuple[Tensor, Tensor, Tensor]: """ Compute dense neighbor lists for periodic boundary conditions case. Coordinates must be in cartesian coordinates and be within the unit cell. Single crystal only, no support for batched coord or multiple unit cells. """ assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' - # non-PBC version device = coord.device - + reciprocal_cell = cell.inverse().t() inv_distances = reciprocal_cell.norm(2, -1) shifts = _calc_shifts(inv_distances, cutoff) + + if coord.shape[0] > 10e3 and cutoff < 20: # avoid making big NxMxS conn_mat for big systems + _fn = nblist_torch_cluster_pbc + mat_idxj, mat_pad, mat_S = _fn(coord, cell, shifts, cutoff, max_nb) + return mat_idxj, mat_pad, mat_S + d = torch.cdist(coord.unsqueeze(0), coord.unsqueeze(0) + (shifts @ cell).unsqueeze(1)) conn_mat = ((d < cutoff) & (d > 0.1)).transpose(0, 1).contiguous() + if device.type == 'cuda' and _numba_cuda_available: _fn = _nblist_pbc_cuda + mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts) else: _fn = _nblist_pbc_cpu - mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts) + mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts, device) return mat_idxj, mat_pad, mat_S +def nblist_torch_cluster_pbc(coord: Tensor, cell: Tensor, shifts: Tensor, + cutoff: float, max_nb: int) -> Tuple[Tensor, Tensor, Tensor]: + assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' + device = coord.device + + # put the zero shift first for convenience when removing self-interaction from torch_cluster.radius + ind = torch.argwhere(torch.all(shifts == 0, axis=1))[0,0] + shifts = shifts[torch.tensor([ind] + \ + list(range(ind)) + \ + list(range(ind+1, len(shifts))))].clone().detach().to(device).to(torch.int8) + + supercoord = torch.vstack( + [coord+(shift @ cell) for shift in shifts.to(torch.float)]) + + max_num_neighbors = max_nb + flag = True + while flag: + edges = radius(supercoord, coord, cutoff, max_num_neighbors=max_nb) + max_num_neighbors = torch.unique(edges[0], return_counts=True)[1].max().item() + flag = max_num_neighbors == max_nb + if flag: + max_nb = int(max_nb * 1.5) + + orig_len = coord.shape[0] + mat_idxj = torch.full((orig_len+1, max_nb), orig_len, dtype=torch.int).to(device) + mat_pad = torch.ones((orig_len+1, max_nb), dtype=torch.int8).to(device) + mat_S = torch.full((orig_len+1, max_nb, 3), -1, dtype=torch.int8).to(device) + + if device.type == 'cuda': + threadsperblock = 32 + blockspergrid = (edges.shape[1] + (threadsperblock - 1)) // threadsperblock + _nblist_sparse_pbc_cuda[blockspergrid, threadsperblock]( + orig_len, edges, shifts, + numba.cuda.as_cuda_array(mat_idxj), + numba.cuda.as_cuda_array(mat_pad), + numba.cuda.as_cuda_array(mat_S)) + return mat_idxj, mat_pad, mat_S + + else: + mat_idxj = mat_idxj.cpu().numpy(); mat_pad = mat_pad.cpu().numpy(); mat_S = mat_S.cpu().numpy() + _nblist_sparse_pbc(orig_len, edges.cpu().numpy(), shifts.cpu().numpy(), + mat_idxj, mat_pad, mat_S) + return torch.tensor(mat_idxj).to(device), torch.tensor(mat_pad).to(device), torch.tensor(mat_S).to(device) + +@numba.njit(cache=True) +def _nblist_sparse_pbc(orig_len, edges, shifts, nl, nl_pad, nl_shifts): + e0 = edges[0] + e1 = edges[1] + mask = e0 != e1 + e0 = e0[mask] + e1 = e1[mask] + e1r = e1 % orig_len + e1d = e1 // orig_len + prev = -1 + tmp = 0 + for ct, i in enumerate(e0): + if i == prev: + tmp += 1 + else: + tmp = 0 + prev = i + nl[i, tmp] = e1r[ct] + nl_pad[i, tmp] = 0 + nl_shifts[i, tmp] = shifts[e1d[ct]] + return nl, nl_pad, nl_shifts + + +@numba.cuda.jit(cache=True) +def _nblist_sparse_pbc_cuda(orig_len, edges, shifts, nl, nl_pad, nl_shifts): + e0 = edges[0] + e1 = edges[1] + gi = numba.cuda.grid(1) + if gi > e0[-1]: + return + + #initialise position in array + nn = e0.shape[0]//e0[-1] + start = gi*nn + if e0[start] < gi: + while e0[start] < gi: + start += 1 + else: + while start > 0: + if e0[start] >= gi: + start -= 1 + else: + start += 1 + break + + # start assigning + tmp = 0 + while e0[start] == gi and start Date: Wed, 14 Aug 2024 12:32:29 -0400 Subject: [PATCH 27/27] type hinting --- aimnet2calc/nblist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py index 5754e35..2b1ae0d 100644 --- a/aimnet2calc/nblist.py +++ b/aimnet2calc/nblist.py @@ -87,7 +87,7 @@ def _cuda_dense_nb_mat_sft(conn_matrix, mat_idxj, mat_pad, mat_S_idx): k += 1 -def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float, max_nb=48) -> Tuple[Tensor, Tensor, Tensor]: +def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float, max_nb: int=48) -> Tuple[Tensor, Tensor, Tensor]: """ Compute dense neighbor lists for periodic boundary conditions case. Coordinates must be in cartesian coordinates and be within the unit cell. Single crystal only, no support for batched coord or multiple unit cells.