Skip to content
This repository was archived by the owner on Apr 11, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions aimnet2calc/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class AIMNet2Calculator:
keys_out = ['energy', 'charges', 'forces', 'hessian', 'stress']
atom_feature_keys = ['coord', 'numbers', 'charges', 'forces']

def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'):
def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2', max_nb_lr_guess=1024):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if isinstance(model, str):
p = get_model_path(model)
Expand All @@ -42,6 +42,7 @@ def __init__(self, model: Union[str, torch.nn.Module] = 'aimnet2'):
self.cutoff = self.model.cutoff
self.lr = hasattr(self.model, 'cutoff_lr')
self.cutoff_lr = getattr(self.model, 'cutoff_lr', float('inf'))
self.max_nb_lr_guess=max_nb_lr_guess

# indicator if input was flattened
self._batch = None
Expand Down Expand Up @@ -149,12 +150,16 @@ def make_nbmat(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
assert data['cell'].ndim == 2, 'Expected 2D tensor for cell'
if 'nbmat' not in data:
data['coord'] = move_coord_to_cell(data['coord'], data['cell'])
mat_idxj, mat_pad, mat_S = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff)
mat_idxj, mat_pad, mat_S = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff, max_nb=128)
data['nbmat'], data['nb_pad_mask'], data['shifts'] = mat_idxj, mat_pad, mat_S
if self.lr:
if 'nbmat_lr' not in data:
assert self.cutoff_lr < torch.inf, 'Long-range cutoff must be finite for PBC'
data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr)
data['nbmat_lr'], data['nb_pad_mask_lr'], data['shifts_lr'] = nblists_torch_pbc(data['coord'], data['cell'], self.cutoff_lr,
max_nb=self.max_nb_lr_guess)
max_nb = torch.sum(data['nb_pad_mask_lr'], axis=1).max()
if max_nb > self.max_nb_lr_guess:
self.max_nb_lr_guess = int(max_nb * 1.05)
data['cutoff_lr'] = torch.tensor(self.cutoff_lr, device=self.device)
else:
if 'nbmat' not in data:
Expand Down
131 changes: 115 additions & 16 deletions aimnet2calc/nblist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch import Tensor
from typing import Optional, Tuple
from torch_cluster import radius_graph
from torch_cluster import radius_graph, radius
import numba
try:
# optionaly use numba cuda
Expand Down Expand Up @@ -87,28 +87,138 @@ def _cuda_dense_nb_mat_sft(conn_matrix, mat_idxj, mat_pad, mat_S_idx):
k += 1


def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
def nblists_torch_pbc(coord: Tensor, cell: Tensor, cutoff: float, max_nb: int=48) -> Tuple[Tensor, Tensor, Tensor]:
""" Compute dense neighbor lists for periodic boundary conditions case.
Coordinates must be in cartesian coordinates and be within the unit cell.
Single crystal only, no support for batched coord or multiple unit cells.
"""
assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D'
# non-PBC version
device = coord.device

reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1)
shifts = _calc_shifts(inv_distances, cutoff)

if coord.shape[0] > 10e3 and cutoff < 20: # avoid making big NxMxS conn_mat for big systems
_fn = nblist_torch_cluster_pbc
mat_idxj, mat_pad, mat_S = _fn(coord, cell, shifts, cutoff, max_nb)
return mat_idxj, mat_pad, mat_S

d = torch.cdist(coord.unsqueeze(0), coord.unsqueeze(0) + (shifts @ cell).unsqueeze(1))
conn_mat = ((d < cutoff) & (d > 0.1)).transpose(0, 1).contiguous()

if device.type == 'cuda' and _numba_cuda_available:
_fn = _nblist_pbc_cuda
mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts)
else:
_fn = _nblist_pbc_cpu
mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts)
mat_idxj, mat_pad, mat_S = _fn(conn_mat, shifts, device)
return mat_idxj, mat_pad, mat_S


def nblist_torch_cluster_pbc(coord: Tensor, cell: Tensor, shifts: Tensor,
cutoff: float, max_nb: int) -> Tuple[Tensor, Tensor, Tensor]:
assert coord.ndim == 2, 'Expected 2D tensor for coord, got {coord.ndim}D'
device = coord.device

# put the zero shift first for convenience when removing self-interaction from torch_cluster.radius
ind = torch.argwhere(torch.all(shifts == 0, axis=1))[0,0]
shifts = shifts[torch.tensor([ind] + \
list(range(ind)) + \
list(range(ind+1, len(shifts))))].clone().detach().to(device).to(torch.int8)

supercoord = torch.vstack(
[coord+(shift @ cell) for shift in shifts.to(torch.float)])

max_num_neighbors = max_nb
flag = True
while flag:
edges = radius(supercoord, coord, cutoff, max_num_neighbors=max_nb)
max_num_neighbors = torch.unique(edges[0], return_counts=True)[1].max().item()
flag = max_num_neighbors == max_nb
if flag:
max_nb = int(max_nb * 1.5)

orig_len = coord.shape[0]
mat_idxj = torch.full((orig_len+1, max_nb), orig_len, dtype=torch.int).to(device)
mat_pad = torch.ones((orig_len+1, max_nb), dtype=torch.int8).to(device)
mat_S = torch.full((orig_len+1, max_nb, 3), -1, dtype=torch.int8).to(device)

if device.type == 'cuda':
threadsperblock = 32
blockspergrid = (edges.shape[1] + (threadsperblock - 1)) // threadsperblock
_nblist_sparse_pbc_cuda[blockspergrid, threadsperblock](
orig_len, edges, shifts,
numba.cuda.as_cuda_array(mat_idxj),
numba.cuda.as_cuda_array(mat_pad),
numba.cuda.as_cuda_array(mat_S))
return mat_idxj, mat_pad, mat_S

else:
mat_idxj = mat_idxj.cpu().numpy(); mat_pad = mat_pad.cpu().numpy(); mat_S = mat_S.cpu().numpy()
_nblist_sparse_pbc(orig_len, edges.cpu().numpy(), shifts.cpu().numpy(),
mat_idxj, mat_pad, mat_S)
return torch.tensor(mat_idxj).to(device), torch.tensor(mat_pad).to(device), torch.tensor(mat_S).to(device)

@numba.njit(cache=True)
def _nblist_sparse_pbc(orig_len, edges, shifts, nl, nl_pad, nl_shifts):
e0 = edges[0]
e1 = edges[1]
mask = e0 != e1
e0 = e0[mask]
e1 = e1[mask]
e1r = e1 % orig_len
e1d = e1 // orig_len
prev = -1
tmp = 0
for ct, i in enumerate(e0):
if i == prev:
tmp += 1
else:
tmp = 0
prev = i
nl[i, tmp] = e1r[ct]
nl_pad[i, tmp] = 0
nl_shifts[i, tmp] = shifts[e1d[ct]]
return nl, nl_pad, nl_shifts


@numba.cuda.jit(cache=True)
def _nblist_sparse_pbc_cuda(orig_len, edges, shifts, nl, nl_pad, nl_shifts):
e0 = edges[0]
e1 = edges[1]
gi = numba.cuda.grid(1)
if gi > e0[-1]:
return

#initialise position in array
nn = e0.shape[0]//e0[-1]
start = gi*nn
if e0[start] < gi:
while e0[start] < gi:
start += 1
else:
while start > 0:
if e0[start] >= gi:
start -= 1
else:
start += 1
break

# start assigning
tmp = 0
while e0[start] == gi and start<e0.shape[0]:
if e0[start] == e1[start]:
start += 1
continue
nl[gi, tmp] = e1[start] % orig_len
nl_pad[gi, tmp] = 0
for j in range(3):
nl_shifts[gi, tmp, j] = shifts[e1[start] // orig_len, j]
start += 1
tmp += 1


def _calc_shifts(inv_distances, cutoff):
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
dc = [torch.arange(-num_repeats[i], num_repeats[i] + 1, device=inv_distances.device) for i in range(len(num_repeats))]
Expand Down Expand Up @@ -142,14 +252,3 @@ def _nblist_pbc_cpu(conn_mat, shifts, device):
mat_S_idx = torch.from_numpy(mat_S_idx).to(device)
mat_S = shifts[mat_S_idx]
return mat_idxj, mat_pad, mat_S