Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ test = [
"torch-sim-atomistic[io,symmetry,vesin]",
"platformdirs>=4.0.0",
"psutil>=7.0.0",
"pymatgen>=2025.6.14",
"pytest-cov>=6",
"pytest>=8",
"spglib>=2.6",
"vesin[torch]>=0.5.3",
]
vesin = ["vesin[torch]>=0.5.3"]
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"]
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2026.3.23"]
symmetry = ["moyopy>=0.7.8"]
mace = ["mace-torch>=0.3.15"]
mattersim = ["mattersim>=1.2.2"]
Expand Down
115 changes: 115 additions & 0 deletions tests/models/test_electrostatics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for the electrostatics ModelInterface wrappers."""

import traceback # noqa: I001

import pytest
import torch
from ase.build import bulk

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test

try:
from torch_sim.models.electrostatics import DSFCoulombModel, EwaldModel, PMEModel
except (ImportError, OSError, RuntimeError):
pytest.skip(
f"nvalchemiops not installed: {traceback.format_exc()}",
allow_module_level=True,
)


def _make_charged_state(
device: torch.device = DEVICE,
dtype: torch.dtype = DTYPE,
) -> ts.SimState:
"""Build a small NaCl-like state with alternating +1/-1 site charges."""
atoms = bulk("NaCl", crystalstructure="rocksalt", a=5.64, cubic=True)
state = ts.io.atoms_to_state(atoms, device, dtype)
n = state.n_atoms
charges = torch.empty(n, dtype=dtype, device=device)
charges[::2] = 1.0
charges[1::2] = -1.0
state._atom_extras["partial_charges"] = charges # noqa: SLF001
return state


@pytest.fixture
def dsf_model() -> DSFCoulombModel:
return DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)


@pytest.fixture
def ewald_model() -> EwaldModel:
return EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)


@pytest.fixture
def pme_model() -> PMEModel:
return PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)


def _add_partial_charges(state: ts.SimState) -> ts.SimState:
"""Inject alternating +/-0.5 site charges into a state."""
n = state.n_atoms
charges = torch.zeros(n, dtype=state.positions.dtype, device=state.device)
charges[::2] = 0.5
charges[1::2] = -0.5
state._atom_extras["partial_charges"] = charges # noqa: SLF001
return state


test_dsf_model_outputs = make_validate_model_outputs_test(
model_fixture_name="dsf_model",
device=DEVICE,
dtype=DTYPE,
state_modifiers=[_add_partial_charges],
)
test_ewald_model_outputs = make_validate_model_outputs_test(
model_fixture_name="ewald_model",
device=DEVICE,
dtype=DTYPE,
state_modifiers=[_add_partial_charges],
)
test_pme_model_outputs = make_validate_model_outputs_test(
model_fixture_name="pme_model",
device=DEVICE,
dtype=DTYPE,
state_modifiers=[_add_partial_charges],
)


def test_dsf_nonzero_energy() -> None:
"""Charged system should produce nonzero electrostatic energy."""
model = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)
state = _make_charged_state()
out = model(state)
assert out["energy"].abs().item() > 0


def test_ewald_pme_energy_agreement() -> None:
"""Ewald and PME should give the same converged Coulomb energy."""
state = _make_charged_state()
ewald = EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
pme = PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE)
torch.testing.assert_close(
ewald(state)["energy"], pme(state)["energy"], atol=1e-3, rtol=1e-3
)


def test_sum_model_lj_plus_dsf() -> None:
"""LJ + DSF should be additive through SumModel."""
from torch_sim.models.interface import SumModel
from torch_sim.models.lennard_jones import LennardJonesModel

lj = LennardJonesModel(
sigma=2.8, epsilon=0.01, cutoff=7.0, device=DEVICE, dtype=DTYPE
)
dsf = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE)
combined = SumModel(lj, dsf)
state = _make_charged_state()
lj_out = lj(state)
dsf_out = dsf(state)
sum_out = combined(state)
torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + dsf_out["energy"])
torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + dsf_out["forces"])
Loading
Loading