diff --git a/aimnet2calc/aimnet2ase.py b/aimnet2calc/aimnet2ase.py index 1155cb0..67fe0ca 100644 --- a/aimnet2calc/aimnet2ase.py +++ b/aimnet2calc/aimnet2ase.py @@ -1,12 +1,14 @@ -from ase.calculators.calculator import Calculator, all_changes -from aimnet2calc import AIMNet2Calculator from typing import Union -import torch + import numpy as np +import torch +from ase.calculators.calculator import Calculator, all_changes + +from aimnet2calc.calculator import AIMNet2Calculator class AIMNet2ASE(Calculator): - implemented_properties = ['energy', 'forces', 'free_energy', 'charges', 'stress'] + implemented_properties = ['energy', 'forces', 'free_energy', 'charges', 'stress', 'dipole_moment'] def __init__(self, base_calc: Union[AIMNet2Calculator, str] = 'aimnet2', charge=0, mult=1): super().__init__() if isinstance(base_calc, str): @@ -56,6 +58,12 @@ def update_tensors(self): if self._t_mol_idx is None: self.mol_idx = torch.zeros(len(self.atoms), dtype=torch.int64, device=self.base_calc.device) + def get_dipole_moment(self,atoms): + charges = self.get_charges()[:, np.newaxis] + positions = atoms.get_positions() + return np.sum(charges * positions, axis=0) + + def calculate(self, atoms=None, properties=['energy'], system_changes=all_changes): super().calculate(atoms, properties, system_changes) self.update_tensors() @@ -77,8 +85,10 @@ def calculate(self, atoms=None, properties=['energy'], system_changes=all_change for k, v in results.items(): results[k] = v.detach().cpu().numpy() - self.results['energy'] = results['energy'] + self.results['energy'] = results['energy'].item() self.results['charges'] = results['charges'] + self.results['dipole_moment'] = self.get_dipole_moment(self.atoms) + if 'forces' in properties: self.results['forces'] = results['forces'] if 'stress' in properties: diff --git a/aimnet2calc/nblist.py b/aimnet2calc/nblist.py index 623ff6e..aba073b 100644 --- a/aimnet2calc/nblist.py +++ b/aimnet2calc/nblist.py @@ -134,12 +134,12 @@ def _nblist_pbc_cuda(conn_mat, shifts): return idx_j, mat_pad, shifts[S_idx] -def _nblist_pbc_cpu(conn_mat, shifts, device): +def _nblist_pbc_cpu(conn_mat, shifts): conn_mat = conn_mat.cpu().numpy() mat_idxj, mat_pad, mat_S_idx = _cpu_dense_nb_mat_sft(conn_mat) - mat_idxj = torch.from_numpy(mat_idxj).to(device) - mat_pad = torch.from_numpy(mat_pad).to(device) - mat_S_idx = torch.from_numpy(mat_S_idx).to(device) + mat_idxj = torch.from_numpy(mat_idxj) + mat_pad = torch.from_numpy(mat_pad) + mat_S_idx = torch.from_numpy(mat_S_idx) mat_S = shifts[mat_S_idx] return mat_idxj, mat_pad, mat_S diff --git a/test/test_ase.py b/test/test_ase.py new file mode 100644 index 0000000..4143814 --- /dev/null +++ b/test/test_ase.py @@ -0,0 +1,53 @@ +import os + +import ase +import numpy as np + +from aimnet2calc.aimnet2ase import AIMNet2ASE + +MODELS = ('aimnet2', 'aimnet2_b973c') +DIR = os.path.dirname(__file__) + + +def _struct_pbc(): + filename = os.path.join(DIR, '1008775.cif') + return ase.io.read(filename) + + +def _struct_list(): + filename = os.path.join(DIR, 'mols_size_var.xyz') + return ase.io.read(filename, index=':') + + +def _stuct_batch(): + filename = os.path.join(DIR, 'mols_size_36.xyz') + return ase.io.read(filename, index=':') + + +def _test_dipole(calc, atoms): + atoms.calc = calc + e =atoms.get_potential_energy() + assert isinstance(e, float) + + assert hasattr(atoms, 'get_charges') + q = atoms.get_charges() + assert q.shape == (len(atoms),) + + assert hasattr(atoms, 'get_dipole_moment') + dm = atoms.get_dipole_moment() + assert dm.shape == (3,) + + +def test_calculator(): + for model in MODELS: + print('Testing model:', model) + calculator = AIMNet2ASE(model) + for atoms_list, runtype in zip((_stuct_batch(), _struct_list(), _struct_pbc()), ('batch', 'list', 'pbc')): + if runtype == 'batch' or runtype == 'list': + for atoms in atoms_list: + _test_dipole(calculator, atoms) + + else: + _test_dipole(calculator, atoms_list) + + diff --git a/test/test_calculator.py b/test/test_calculator.py index 38ec8b7..9711a73 100644 --- a/test/test_calculator.py +++ b/test/test_calculator.py @@ -1,8 +1,9 @@ -import ase.io -from aimnet2calc.calculator import AIMNet2Calculator import os + +import ase.io import numpy as np +from aimnet2calc.calculator import AIMNet2Calculator MODELS = ('aimnet2', 'aimnet2_b973c') DIR = os.path.dirname(__file__)