Skip to content
This repository was archived by the owner on Apr 11, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions aimnet2calc/aimnet2ase.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from ase.calculators.calculator import Calculator, all_changes
from aimnet2calc import AIMNet2Calculator
from typing import Union
import torch

import numpy as np
import torch
from ase.calculators.calculator import Calculator, all_changes

from aimnet2calc.calculator import AIMNet2Calculator


class AIMNet2ASE(Calculator):
implemented_properties = ['energy', 'forces', 'free_energy', 'charges', 'stress']
implemented_properties = ['energy', 'forces', 'free_energy', 'charges', 'stress', 'dipole_moment']
def __init__(self, base_calc: Union[AIMNet2Calculator, str] = 'aimnet2', charge=0, mult=1):
super().__init__()
if isinstance(base_calc, str):
Expand Down Expand Up @@ -56,6 +58,12 @@ def update_tensors(self):
if self._t_mol_idx is None:
self.mol_idx = torch.zeros(len(self.atoms), dtype=torch.int64, device=self.base_calc.device)

def get_dipole_moment(self,atoms):
charges = self.get_charges()[:, np.newaxis]
positions = atoms.get_positions()
return np.sum(charges * positions, axis=0)


def calculate(self, atoms=None, properties=['energy'], system_changes=all_changes):
super().calculate(atoms, properties, system_changes)
self.update_tensors()
Expand All @@ -77,8 +85,10 @@ def calculate(self, atoms=None, properties=['energy'], system_changes=all_change
for k, v in results.items():
results[k] = v.detach().cpu().numpy()

self.results['energy'] = results['energy']
self.results['energy'] = results['energy'].item()
self.results['charges'] = results['charges']
self.results['dipole_moment'] = self.get_dipole_moment(self.atoms)

if 'forces' in properties:
self.results['forces'] = results['forces']
if 'stress' in properties:
Expand Down
8 changes: 4 additions & 4 deletions aimnet2calc/nblist.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ def _nblist_pbc_cuda(conn_mat, shifts):
return idx_j, mat_pad, shifts[S_idx]


def _nblist_pbc_cpu(conn_mat, shifts, device):
def _nblist_pbc_cpu(conn_mat, shifts):
conn_mat = conn_mat.cpu().numpy()
mat_idxj, mat_pad, mat_S_idx = _cpu_dense_nb_mat_sft(conn_mat)
mat_idxj = torch.from_numpy(mat_idxj).to(device)
mat_pad = torch.from_numpy(mat_pad).to(device)
mat_S_idx = torch.from_numpy(mat_S_idx).to(device)
mat_idxj = torch.from_numpy(mat_idxj)
mat_pad = torch.from_numpy(mat_pad)
mat_S_idx = torch.from_numpy(mat_S_idx)
mat_S = shifts[mat_S_idx]
return mat_idxj, mat_pad, mat_S

Expand Down
53 changes: 53 additions & 0 deletions test/test_ase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

import ase
import numpy as np

from aimnet2calc.aimnet2ase import AIMNet2ASE

MODELS = ('aimnet2', 'aimnet2_b973c')
DIR = os.path.dirname(__file__)


def _struct_pbc():
filename = os.path.join(DIR, '1008775.cif')
return ase.io.read(filename)


def _struct_list():
filename = os.path.join(DIR, 'mols_size_var.xyz')
return ase.io.read(filename, index=':')


def _stuct_batch():
filename = os.path.join(DIR, 'mols_size_36.xyz')
return ase.io.read(filename, index=':')


def _test_dipole(calc, atoms):
atoms.calc = calc
e =atoms.get_potential_energy()
assert isinstance(e, float)

assert hasattr(atoms, 'get_charges')
q = atoms.get_charges()
assert q.shape == (len(atoms),)

assert hasattr(atoms, 'get_dipole_moment')
dm = atoms.get_dipole_moment()
assert dm.shape == (3,)


def test_calculator():
for model in MODELS:
print('Testing model:', model)
calculator = AIMNet2ASE(model)
for atoms_list, runtype in zip((_stuct_batch(), _struct_list(), _struct_pbc()), ('batch', 'list', 'pbc')):
if runtype == 'batch' or runtype == 'list':
for atoms in atoms_list:
_test_dipole(calculator, atoms)

else:
_test_dipole(calculator, atoms_list)


5 changes: 3 additions & 2 deletions test/test_calculator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ase.io
from aimnet2calc.calculator import AIMNet2Calculator
import os

import ase.io
import numpy as np

from aimnet2calc.calculator import AIMNet2Calculator

MODELS = ('aimnet2', 'aimnet2_b973c')
DIR = os.path.dirname(__file__)
Expand Down