Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions tests/models/test_dispersion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Tests for the D3DispersionModel wrapper."""

import traceback # noqa: I001

import pytest
import torch

from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test

try:
from nvalchemiops.torch.interactions.dispersion import D3Parameters

from torch_sim.models.dispersion import D3DispersionModel
except (ImportError, OSError, RuntimeError):
pytest.skip(
f"nvalchemiops not installed: {traceback.format_exc()}",
allow_module_level=True,
)


def _make_d3_params(device: torch.device = DEVICE) -> D3Parameters:
"""Build minimal D3 reference parameters for testing (elements up to Fe=26)."""
max_z = 26
mesh = 5
return D3Parameters(
rcov=torch.rand(max_z + 1, device=device),
r4r2=torch.rand(max_z + 1, device=device),
c6ab=torch.rand(max_z + 1, max_z + 1, mesh, mesh, device=device),
cn_ref=torch.rand(max_z + 1, max_z + 1, mesh, mesh, device=device),
)


# BJ damping parameters from
# https://github.com/dftd3/simple-dftd3/blob/main/assets/parameters.toml
PBE_BJ = {"a1": 0.4289, "s8": 0.7875, "a2": 4.4407, "s6": 1.0}
R2SCAN_BJ = {"a1": 0.49484001, "s8": 0.78981345, "a2": 5.73083694, "s6": 1.0}


@pytest.fixture
def d3_model_pbe() -> D3DispersionModel:
return D3DispersionModel(
**PBE_BJ,
d3_params=_make_d3_params(),
cutoff=12.0,
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)


@pytest.fixture
def d3_model_r2scan() -> D3DispersionModel:
return D3DispersionModel(
**R2SCAN_BJ,
d3_params=_make_d3_params(),
cutoff=12.0,
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)


test_d3_pbe_outputs = make_validate_model_outputs_test(
model_fixture_name="d3_model_pbe", device=DEVICE, dtype=DTYPE
)

test_d3_r2scan_outputs = make_validate_model_outputs_test(
model_fixture_name="d3_model_r2scan", device=DEVICE, dtype=DTYPE
)
117 changes: 9 additions & 108 deletions tests/models/test_sum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from torch_sim.models.interface import SumModel, validate_model_outputs
from tests.models.conftest import make_validate_model_outputs_test
from torch_sim.models.interface import SumModel
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.morse import MorseModel

Expand All @@ -23,19 +24,6 @@ def lj_model_a() -> LennardJonesModel:
)


@pytest.fixture
def lj_model_b() -> LennardJonesModel:
return LennardJonesModel(
sigma=2.0,
epsilon=0.005,
cutoff=5.0,
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)


@pytest.fixture
def morse_model() -> MorseModel:
return MorseModel(
Expand All @@ -55,6 +43,11 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode
return SumModel(lj_model_a, morse_model)


test_sum_model_outputs = make_validate_model_outputs_test(
model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE
)


def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None:
with pytest.raises(ValueError, match="at least two"):
SumModel(lj_model_a)
Expand All @@ -75,27 +68,7 @@ def test_sum_model_dtype_mismatch() -> None:
SumModel(m1, m2)


def test_sum_model_properties(sum_model: SumModel) -> None:
assert sum_model.device == DEVICE
assert sum_model.dtype == DTYPE
assert sum_model.compute_stress is True
assert sum_model.compute_forces is True


def test_sum_model_energy_summation(
lj_model_a: LennardJonesModel,
morse_model: MorseModel,
sum_model: SumModel,
si_sim_state: ts.SimState,
) -> None:
lj_out = lj_model_a(si_sim_state)
morse_out = morse_model(si_sim_state)
sum_out = sum_model(si_sim_state)
expected_energy = lj_out["energy"] + morse_out["energy"]
torch.testing.assert_close(sum_out["energy"], expected_energy)


def test_sum_model_forces_summation(
def test_sum_model_additivity(
lj_model_a: LennardJonesModel,
morse_model: MorseModel,
sum_model: SumModel,
Expand All @@ -104,94 +77,22 @@ def test_sum_model_forces_summation(
lj_out = lj_model_a(si_sim_state)
morse_out = morse_model(si_sim_state)
sum_out = sum_model(si_sim_state)
expected_forces = lj_out["forces"] + morse_out["forces"]
torch.testing.assert_close(sum_out["forces"], expected_forces)


def test_sum_model_stress_summation(
lj_model_a: LennardJonesModel,
morse_model: MorseModel,
sum_model: SumModel,
si_sim_state: ts.SimState,
) -> None:
lj_out = lj_model_a(si_sim_state)
morse_out = morse_model(si_sim_state)
sum_out = sum_model(si_sim_state)
expected_stress = lj_out["stress"] + morse_out["stress"]
torch.testing.assert_close(sum_out["stress"], expected_stress)


def test_sum_model_batched(
lj_model_a: LennardJonesModel,
morse_model: MorseModel,
sum_model: SumModel,
si_double_sim_state: ts.SimState,
) -> None:
lj_out = lj_model_a(si_double_sim_state)
morse_out = morse_model(si_double_sim_state)
sum_out = sum_model(si_double_sim_state)
torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + morse_out["energy"])
torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + morse_out["forces"])
torch.testing.assert_close(sum_out["stress"], lj_out["stress"] + morse_out["stress"])


def test_sum_model_three_models(
lj_model_a: LennardJonesModel,
lj_model_b: LennardJonesModel,
morse_model: MorseModel,
si_sim_state: ts.SimState,
) -> None:
triple = SumModel(lj_model_a, lj_model_b, morse_model)
a_out = lj_model_a(si_sim_state)
b_out = lj_model_b(si_sim_state)
c_out = morse_model(si_sim_state)
sum_out = triple(si_sim_state)
torch.testing.assert_close(
sum_out["energy"], a_out["energy"] + b_out["energy"] + c_out["energy"]
)


def test_sum_model_compute_stress_setter(
def test_sum_model_setters(
lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
sm = SumModel(lj_model_a, morse_model)
assert sm.compute_stress is True
sm.compute_stress = False
assert sm.compute_stress is False


def test_sum_model_compute_forces_setter(
lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
sm = SumModel(lj_model_a, morse_model)
sm.compute_forces = False
assert sm.compute_forces is False


def test_sum_model_memory_scales_with(
lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
sm = SumModel(lj_model_a, morse_model)
assert sm.memory_scales_with == "n_atoms_x_density"


def test_sum_model_force_conservation(
sum_model: SumModel, si_double_sim_state: ts.SimState
) -> None:
results = sum_model(si_double_sim_state)
for sys_idx in range(si_double_sim_state.n_systems):
mask = si_double_sim_state.system_idx == sys_idx
assert torch.allclose(
results["forces"][mask].sum(dim=0),
torch.zeros(3, dtype=DTYPE),
atol=1e-10,
)


def test_sum_model_validate_outputs(sum_model: SumModel) -> None:
validate_model_outputs(sum_model, DEVICE, DTYPE, check_detached=True)


def test_sum_model_retain_graph(
lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> None:
Expand Down
Loading
Loading