Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
593 changes: 593 additions & 0 deletions examples/tutorials/nudged_elastic_band.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ include = ["docs/**/*.py", "docs/**/*.ipynb", "examples/**/*.py"]
[tool.ty.overrides.rules]
invalid-argument-type = "ignore"
invalid-assignment = "ignore"
invalid-attribute-override = "ignore"
not-iterable = "ignore"
not-subscriptable = "ignore"
unresolved-attribute = "ignore"
Expand Down
29 changes: 29 additions & 0 deletions tests/test_autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,35 @@ def test_in_flight_max_iterations(
assert batcher.iteration_count[idx] == max_iterations


def test_in_flight_max_iterations_completes_whole_group(
si_double_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
grouped_state = si_double_sim_state.clone()
grouped_state.group_idx = torch.zeros(
grouped_state.n_systems, device=grouped_state.device, dtype=torch.long
)
batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=800.0,
max_iterations=1,
)
batcher.load_states(grouped_state)

state, [] = batcher.next_batch(None, None)
assert state is not None
assert state.n_systems == grouped_state.n_systems
assert state.n_groups == 1

convergence_tensor = torch.zeros(state.n_systems, dtype=torch.bool)
next_state, completed_states = batcher.next_batch(state, convergence_tensor)

assert next_state is None
assert len(completed_states) == 1
assert completed_states[0].n_systems == grouped_state.n_systems


@pytest.mark.parametrize(
"num_steps_per_batch",
[
Expand Down
30 changes: 30 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,36 @@ def test_fire_optimization(
)


@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor))
def test_fire_uses_group_scoped_adaptive_state(
ar_double_sim_state: SimState, lj_model: ModelInterface, fire_flavor: FireFlavor
) -> None:
ar_double_sim_state.group_idx = torch.zeros(
ar_double_sim_state.n_systems,
device=ar_double_sim_state.device,
dtype=torch.int64,
)

state = ts.fire_init(
ar_double_sim_state,
lj_model,
fire_flavor=fire_flavor,
dt_start=0.1,
alpha_start=0.1,
)

assert state.n_groups == 1
assert state.dt.shape == (1,)
assert state.alpha.shape == (1,)
assert state.n_pos.shape == (1,)

updated = ts.fire_step(state=state, model=lj_model, dt_max=0.3)

assert updated.dt.shape == (1,)
assert updated.alpha.shape == (1,)
assert updated.n_pos.shape == (1,)


def test_bfgs_optimization(
ar_supercell_sim_state: SimState, lj_model: ModelInterface
) -> None:
Expand Down
27 changes: 26 additions & 1 deletion tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,36 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None:
per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom"))
assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"}
per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system"))
assert set(per_system_attrs) == {"cell"}
assert set(per_system_attrs) == {"cell", "group_idx"}
per_group_attrs = dict(get_attrs_for_scope(si_sim_state, "per-group"))
assert set(per_group_attrs) == set()
global_attrs = dict(get_attrs_for_scope(si_sim_state, "global"))
assert set(global_attrs) == {"pbc", "_rng"}


def test_group_idx_defaults_to_one_group_per_system(
si_double_sim_state: SimState,
) -> None:
assert torch.equal(
si_double_sim_state.group_idx,
torch.arange(si_double_sim_state.n_systems, device=si_double_sim_state.device),
)
assert si_double_sim_state.n_groups == si_double_sim_state.n_systems


def test_slice_remaps_group_idx(si_double_sim_state: SimState) -> None:
state = si_double_sim_state.clone()
state.group_idx = torch.zeros(state.n_systems, device=state.device, dtype=torch.int64)

sliced = _slice_state(state, [1, 0])

assert torch.equal(
sliced.group_idx,
torch.zeros(sliced.n_systems, device=sliced.device, dtype=torch.int64),
)
assert sliced.n_groups == 1


def test_all_attributes_must_be_specified_in_scopes() -> None:
"""Test that an error is raised when we forget to specify the scope
for an attribute in a child SimState class."""
Expand Down
201 changes: 201 additions & 0 deletions tests/workflows/test_neb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import numpy as np
import torch
from ase import Atoms
from ase.mep import NEB as ASENEB
from ase.mep.neb import ImprovedTangentMethod, NEBState

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from torch_sim.models.interface import ModelInterface
from torch_sim.workflows.neb import (
NEB,
assemble_path,
calculate_neb_forces,
interpolate_path,
)


class HarmonicModel(ModelInterface):
def __init__(self) -> None:
super().__init__()
self._device = DEVICE
self._dtype = DTYPE
self._compute_forces = True
self._compute_stress = True

def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
del kwargs
per_atom_energy = 0.5 * (state.positions**2).sum(dim=1)
energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
energy.scatter_add_(0, state.system_idx, per_atom_energy)
return {
"energy": energy,
"forces": -state.positions,
"stress": torch.zeros(
state.n_systems, 3, 3, device=state.device, dtype=state.dtype
),
}


class GroupIndexedModel(HarmonicModel):
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
if state.group_idx.max() >= 1:
raise AssertionError("NEB endpoint/path assembly should preserve one group.")
return super().forward(state, **kwargs)


def _single_atom_state(position: float) -> ts.SimState:
return ts.SimState(
positions=torch.tensor([[position, 0.0, 0.0]], device=DEVICE, dtype=DTYPE),
masses=torch.ones(1, device=DEVICE, dtype=DTYPE),
cell=torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0) * 10.0,
pbc=False,
atomic_numbers=torch.tensor([18], device=DEVICE),
system_idx=torch.zeros(1, device=DEVICE, dtype=torch.long),
)


def test_assemble_path_preserves_one_optimizer_group() -> None:
initial = _single_atom_state(0.0)
final = _single_atom_state(1.0)
movable = interpolate_path(initial, final, n_images=3)

path = assemble_path(initial, movable, final)

assert path.n_systems == 5
assert path.n_groups == 1
assert torch.equal(path.group_idx, torch.zeros(5, device=DEVICE, dtype=torch.long))


def test_interpolate_path_uses_movable_images_only() -> None:
initial = _single_atom_state(0.0)
final = _single_atom_state(1.0)

path = interpolate_path(initial, final, n_images=3)

assert path.n_systems == 3
assert path.n_groups == 1
assert torch.equal(path.group_idx, torch.zeros(3, device=DEVICE, dtype=torch.long))
assert torch.allclose(
path.positions[:, 0],
torch.tensor([0.25, 0.5, 0.75], device=DEVICE, dtype=DTYPE),
)


def test_calculate_neb_forces_matches_ase_step0_components() -> None:
n_images = 5
n_atoms = 2
spring_constant = 0.1
positions = torch.tensor(
[
[[0.0, 0.0, 0.0], [0.0, 0.8, 0.0]],
[[0.2, 0.1, 0.0], [0.1, 0.9, 0.0]],
[[0.5, 0.25, 0.0], [0.2, 1.0, 0.1]],
[[0.8, 0.35, 0.0], [0.25, 1.05, 0.2]],
[[1.0, 0.5, 0.0], [0.4, 1.2, 0.3]],
],
device=DEVICE,
dtype=DTYPE,
)
energies = torch.tensor([0.0, 0.2, 0.7, 0.4, 0.1], device=DEVICE, dtype=DTYPE)
true_forces = torch.tensor(
[
[[0.1, -0.2, 0.0], [0.0, 0.3, -0.1]],
[[-0.2, 0.1, 0.2], [0.2, -0.1, 0.0]],
[[0.3, 0.0, -0.1], [-0.1, 0.2, 0.1]],
],
device=DEVICE,
dtype=DTYPE,
)
path_state = ts.SimState(
positions=positions.reshape(-1, 3),
masses=torch.ones(n_images * n_atoms, device=DEVICE, dtype=DTYPE),
cell=torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(n_images, 1, 1)
* 10.0,
pbc=False,
atomic_numbers=torch.tensor([18, 18], device=DEVICE).repeat(n_images),
system_idx=torch.repeat_interleave(
torch.arange(n_images, device=DEVICE), repeats=n_atoms
),
)

torch_forces = calculate_neb_forces(
path_state,
true_forces.reshape(-1, 3),
energies[1:-1],
energies[0],
energies[-1],
spring_constant=spring_constant,
use_climbing_image=True,
).reshape(n_images - 2, n_atoms, 3)

ase_images = [
Atoms(
"Ar2",
positions=image_positions.detach().cpu().numpy(),
cell=np.eye(3) * 10.0,
pbc=False,
)
for image_positions in positions
]
ase_neb = ASENEB(ase_images, k=spring_constant, climb=True, method="improvedtangent")
ase_state = NEBState(ase_neb, ase_neb.images, energies.detach().cpu().numpy())
tangent_method = ImprovedTangentMethod(ase_neb)
ase_forces = []
true_forces_np = true_forces.detach().cpu().numpy()
for image_index in range(1, n_images - 1):
spring1 = ase_state.spring(image_index - 1)
spring2 = ase_state.spring(image_index)
tangent = tangent_method.get_tangent(ase_state, spring1, spring2, image_index)
tangent_norm = np.linalg.norm(tangent)
if tangent_norm > 1e-15:
tangent = tangent / tangent_norm
force = true_forces_np[image_index - 1]
force_dot_tangent = np.vdot(force, tangent)
if ase_neb.climb and image_index == ase_state.imax:
ase_forces.append(force - 2 * force_dot_tangent * tangent)
else:
spring_force = (spring2.nt * spring2.k - spring1.nt * spring1.k) * tangent
ase_forces.append(force - force_dot_tangent * tangent + spring_force)

assert torch.allclose(
torch_forces,
torch.tensor(np.array(ase_forces), device=DEVICE, dtype=DTYPE),
atol=1e-12,
rtol=1e-12,
)


def test_neb_run_uses_single_chain_optimize_without_moving_endpoints() -> None:
initial = _single_atom_state(0.0)
final = _single_atom_state(1.0)
neb = NEB(
model=HarmonicModel(),
n_images=1,
optimizer_type="gd",
optimizer_params={"pos_lr": 0.1},
)

result = neb.run(initial, final, max_steps=3, fmax=1e-12)

assert result.n_systems == 3
assert torch.allclose(
result.positions[:, 0],
torch.tensor([0.0, 0.5, 1.0], device=DEVICE, dtype=DTYPE),
)
assert result.n_groups == 1


def test_neb_run_does_not_offset_endpoint_groups() -> None:
initial = _single_atom_state(0.0)
final = _single_atom_state(1.0)
neb = NEB(
model=GroupIndexedModel(),
n_images=1,
optimizer_type="gd",
optimizer_params={"pos_lr": 0.1},
)

result = neb.run(initial, final, max_steps=1, fmax=1e-12)

assert result.n_groups == 1
Loading
Loading