diff --git a/src/flashmd/stepper.py b/src/flashmd/stepper.py index a8e46b9..9e74a9c 100644 --- a/src/flashmd/stepper.py +++ b/src/flashmd/stepper.py @@ -1,9 +1,9 @@ # from ..utils.pretrained import load_pretrained_models import ase.units import torch +import vesin.metatomic from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import AtomisticModel, ModelEvaluationOptions, ModelOutput, System -from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists from .constraints import enforce_physical_constraints @@ -35,9 +35,7 @@ def step(self, system: System): if system.positions.dtype != self.dtype: raise ValueError("System dtype does not match stepper dtype.") - system = get_system_with_neighbor_lists( - system, self.model.requested_neighbor_lists() - ) + vesin.metatomic.compute_requested_neighbors([system], "angstrom", self.model) masses = system.get_data("masses").block().values model_outputs = self.model( diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index 859503c..5d5e995 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -1,11 +1,11 @@ import ase.build -import ase.io import ase.units import torch -from ase.md import VelocityVerlet +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from flashmd import get_pretrained from flashmd.ase import EnergyCalculator +from flashmd.ase.velocity_verlet import VelocityVerlet def test_isolated_atom(monkeypatch, tmp_path): @@ -13,14 +13,20 @@ def test_isolated_atom(monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) atoms = ase.Atoms("O", positions=[[0, 0, 0]]) + MaxwellBoltzmannDistribution(atoms, temperature_K=300) - time_step = 64 + time_step = 8 device = "cuda" if torch.cuda.is_available() else "cpu" - energy_model, _ = get_pretrained("pet-omatpes-v2", time_step) + energy_model, flashmd_model = get_pretrained("pet-omatpes-v2", time_step) calculator = EnergyCalculator(energy_model, device=device) atoms.calc = calculator - dyn = VelocityVerlet(atoms=atoms, timestep=time_step * ase.units.fs) + dyn = VelocityVerlet( + atoms=atoms, + timestep=time_step * ase.units.fs, + model=flashmd_model, + device=device, + ) dyn.run(10) @@ -32,12 +38,18 @@ def test_slab_plus_isolated_atom(monkeypatch, tmp_path): slab = ase.build.fcc111("Al", size=(2, 2, 3), vacuum=10) isolated_atom = ase.Atoms("O", positions=[[0, 0, 24]]) atoms = slab + isolated_atom + MaxwellBoltzmannDistribution(atoms, temperature_K=300) - time_step = 64 + time_step = 8 device = "cuda" if torch.cuda.is_available() else "cpu" - energy_model, _ = get_pretrained("pet-omatpes-v2", time_step) + energy_model, flashmd_model = get_pretrained("pet-omatpes-v2", time_step) calculator = EnergyCalculator(energy_model, device=device) atoms.calc = calculator - dyn = VelocityVerlet(atoms=atoms, timestep=time_step * ase.units.fs) + dyn = VelocityVerlet( + atoms=atoms, + timestep=time_step * ase.units.fs, + model=flashmd_model, + device=device, + ) dyn.run(10)