diff --git a/aimnet2calc/calculator.py b/aimnet2calc/calculator.py index 011a4b2..5c77e53 100644 --- a/aimnet2calc/calculator.py +++ b/aimnet2calc/calculator.py @@ -29,7 +29,7 @@ class AIMNet2Calculator: keys_out = ['energy', 'charges', 'forces', 'hessian', 'stress'] atom_feature_keys = ['coord', 'numbers', 'charges', 'forces'] - def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): + def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2', max_nb_lr_guess=1024): self.device = 'cuda' if torch.cuda.is_available() else 'cpu' if isinstance(model, str): p = get_model_path(model) @@ -42,6 +42,7 @@ def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'): self.cutoff = self.model.cutoff self.lr = hasattr(self.model, 'cutoff_lr') self.cutoff_lr = getattr(self.model, 'cutoff_lr', float('inf')) + self.max_nb_lr_guess=max_nb_lr_guess # indicator if input was flattened self._batch = None @@ -149,12 +150,16 @@ def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]: assert data['cell'].ndim == 2, 'Expected 2D tensor for cell' if 'nbmat' not in data: data['coord'] = move_coord_to_cell(data['coord'], data['cell']) - mat_idxj, mat_pad, mat_S = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff) + mat_idxj, mat_pad, mat_S = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff, max_nb=128) data['nbmat'], data['nb_pad_mask'], data['shifts'] = mat_idxj, mat_pad, mat_S if self.lr: if 'nbmat_lr' not in data: assert self.cutoff_lr < torch.inf, 'Long-range cutoff must be finite for PBC' - data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr) + data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr, + max_nb=self.max_nb_lr_guess) + max_nb = torch.sum(data['nb_pad_mask_lr'], axis=1).max() + if max_nb > self.max_nb_lr_guess: + self.max_nb_lr_guess = int(max_nb * 1.05) data['cutoff_lr'] = torch.tensor(self.cutoff_lr, device=self.device) else: if 'nbmat' not in data: diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py index 623ff6e..2b1ae0d 100644 --- a/aimnet2calc/nblist.py +++ b/aimnet2calc/nblist.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from typing import Optional, Tuple -from torch_cluster import radius_graph +from torch_cluster import radius_graph, radius import numba try: # optionaly use numba cuda @@ -87,28 +87,138 @@ def _cuda_dense_nb_mat_sft(conn_matrix, mat_idxj, mat_pad, mat_S_idx): k += 1 -def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]: +def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float, max_nb: int=48) -> Tuple[Tensor, Tensor, Tensor]: """ Compute dense neighbor lists for periodic boundary conditions case. Coordinates must be in cartesian coordinates and be within the unit cell. Single crystal only, no support for batched coord or multiple unit cells. """ assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' - # non-PBC version device = coord.device - + reciprocal_cell = cell.inverse().t() inv_distances = reciprocal_cell.norm(2, -1) shifts = _calc_shifts(inv_distances, cutoff) + + if coord.shape[0] > 10e3 and cutoff < 20: # avoid making big NxMxS conn_mat for big systems + _fn = nblist_torch_cluster_pbc + mat_idxj, mat_pad, mat_S = _fn(coord, cell, shifts, cutoff, max_nb) + return mat_idxj, mat_pad, mat_S + d = torch.cdist(coord.unsqueeze(0), coord.unsqueeze(0) + (shifts @ cell).unsqueeze(1)) conn_mat = ((d < cutoff) & (d > 0.1)).transpose(0, 1).contiguous() + if device.type == 'cuda' and _numba_cuda_available: _fn = _nblist_pbc_cuda + mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts) else: _fn = _nblist_pbc_cpu - mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts) + mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts, device) return mat_idxj, mat_pad, mat_S +def nblist_torch_cluster_pbc(coord: Tensor, cell: Tensor, shifts: Tensor, + cutoff: float, max_nb: int) -> Tuple[Tensor, Tensor, Tensor]: + assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D' + device = coord.device + + # put the zero shift first for convenience when removing self-interaction from torch_cluster.radius + ind = torch.argwhere(torch.all(shifts == 0, axis=1))[0,0] + shifts = shifts[torch.tensor([ind] + \ + list(range(ind)) + \ + list(range(ind+1, len(shifts))))].clone().detach().to(device).to(torch.int8) + + supercoord = torch.vstack( + [coord+(shift @ cell) for shift in shifts.to(torch.float)]) + + max_num_neighbors = max_nb + flag = True + while flag: + edges = radius(supercoord, coord, cutoff, max_num_neighbors=max_nb) + max_num_neighbors = torch.unique(edges[0], return_counts=True)[1].max().item() + flag = max_num_neighbors == max_nb + if flag: + max_nb = int(max_nb * 1.5) + + orig_len = coord.shape[0] + mat_idxj = torch.full((orig_len+1, max_nb), orig_len, dtype=torch.int).to(device) + mat_pad = torch.ones((orig_len+1, max_nb), dtype=torch.int8).to(device) + mat_S = torch.full((orig_len+1, max_nb, 3), -1, dtype=torch.int8).to(device) + + if device.type == 'cuda': + threadsperblock = 32 + blockspergrid = (edges.shape[1] + (threadsperblock - 1)) // threadsperblock + _nblist_sparse_pbc_cuda[blockspergrid, threadsperblock]( + orig_len, edges, shifts, + numba.cuda.as_cuda_array(mat_idxj), + numba.cuda.as_cuda_array(mat_pad), + numba.cuda.as_cuda_array(mat_S)) + return mat_idxj, mat_pad, mat_S + + else: + mat_idxj = mat_idxj.cpu().numpy(); mat_pad = mat_pad.cpu().numpy(); mat_S = mat_S.cpu().numpy() + _nblist_sparse_pbc(orig_len, edges.cpu().numpy(), shifts.cpu().numpy(), + mat_idxj, mat_pad, mat_S) + return torch.tensor(mat_idxj).to(device), torch.tensor(mat_pad).to(device), torch.tensor(mat_S).to(device) + +@numba.njit(cache=True) +def _nblist_sparse_pbc(orig_len, edges, shifts, nl, nl_pad, nl_shifts): + e0 = edges[0] + e1 = edges[1] + mask = e0 != e1 + e0 = e0[mask] + e1 = e1[mask] + e1r = e1 % orig_len + e1d = e1 // orig_len + prev = -1 + tmp = 0 + for ct, i in enumerate(e0): + if i == prev: + tmp += 1 + else: + tmp = 0 + prev = i + nl[i, tmp] = e1r[ct] + nl_pad[i, tmp] = 0 + nl_shifts[i, tmp] = shifts[e1d[ct]] + return nl, nl_pad, nl_shifts + + +@numba.cuda.jit(cache=True) +def _nblist_sparse_pbc_cuda(orig_len, edges, shifts, nl, nl_pad, nl_shifts): + e0 = edges[0] + e1 = edges[1] + gi = numba.cuda.grid(1) + if gi > e0[-1]: + return + + #initialise position in array + nn = e0.shape[0]//e0[-1] + start = gi*nn + if e0[start] < gi: + while e0[start] < gi: + start += 1 + else: + while start > 0: + if e0[start] >= gi: + start -= 1 + else: + start += 1 + break + + # start assigning + tmp = 0 + while e0[start] == gi and start