From 44ee8d76c0b60111c90115f468fdbb6c4137a9d3 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 5 Apr 2026 18:58:10 -0400 Subject: [PATCH] fea: add electrostatics using alchemiops --- pyproject.toml | 4 +- tests/models/test_electrostatics.py | 115 +++++++++ torch_sim/models/electrostatics.py | 387 ++++++++++++++++++++++++++++ 3 files changed, 503 insertions(+), 3 deletions(-) create mode 100644 tests/models/test_electrostatics.py create mode 100644 torch_sim/models/electrostatics.py diff --git a/pyproject.toml b/pyproject.toml index 9ebc3752..9816a6aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,14 +41,12 @@ test = [ "torch-sim-atomistic[io,symmetry,vesin]", "platformdirs>=4.0.0", "psutil>=7.0.0", - "pymatgen>=2025.6.14", "pytest-cov>=6", "pytest>=8", "spglib>=2.6", - "vesin[torch]>=0.5.3", ] vesin = ["vesin[torch]>=0.5.3"] -io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] +io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2026.3.23"] symmetry = ["moyopy>=0.7.8"] mace = ["mace-torch>=0.3.15"] mattersim = ["mattersim>=1.2.2"] diff --git a/tests/models/test_electrostatics.py b/tests/models/test_electrostatics.py new file mode 100644 index 00000000..e4e3feaa --- /dev/null +++ b/tests/models/test_electrostatics.py @@ -0,0 +1,115 @@ +"""Tests for the electrostatics ModelInterface wrappers.""" + +import traceback # noqa: I001 + +import pytest +import torch +from ase.build import bulk + +import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from tests.models.conftest import make_validate_model_outputs_test + +try: + from torch_sim.models.electrostatics import DSFCoulombModel, EwaldModel, PMEModel +except (ImportError, OSError, RuntimeError): + pytest.skip( + f"nvalchemiops not installed: {traceback.format_exc()}", + allow_module_level=True, + ) + + +def _make_charged_state( + device: torch.device = DEVICE, + dtype: torch.dtype = DTYPE, +) -> ts.SimState: + """Build a small NaCl-like state with alternating +1/-1 site charges.""" + atoms = bulk("NaCl", crystalstructure="rocksalt", a=5.64, cubic=True) + state = ts.io.atoms_to_state(atoms, device, dtype) + n = state.n_atoms + charges = torch.empty(n, dtype=dtype, device=device) + charges[::2] = 1.0 + charges[1::2] = -1.0 + state._atom_extras["partial_charges"] = charges # noqa: SLF001 + return state + + +@pytest.fixture +def dsf_model() -> DSFCoulombModel: + return DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE) + + +@pytest.fixture +def ewald_model() -> EwaldModel: + return EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE) + + +@pytest.fixture +def pme_model() -> PMEModel: + return PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE) + + +def _add_partial_charges(state: ts.SimState) -> ts.SimState: + """Inject alternating +/-0.5 site charges into a state.""" + n = state.n_atoms + charges = torch.zeros(n, dtype=state.positions.dtype, device=state.device) + charges[::2] = 0.5 + charges[1::2] = -0.5 + state._atom_extras["partial_charges"] = charges # noqa: SLF001 + return state + + +test_dsf_model_outputs = make_validate_model_outputs_test( + model_fixture_name="dsf_model", + device=DEVICE, + dtype=DTYPE, + state_modifiers=[_add_partial_charges], +) +test_ewald_model_outputs = make_validate_model_outputs_test( + model_fixture_name="ewald_model", + device=DEVICE, + dtype=DTYPE, + state_modifiers=[_add_partial_charges], +) +test_pme_model_outputs = make_validate_model_outputs_test( + model_fixture_name="pme_model", + device=DEVICE, + dtype=DTYPE, + state_modifiers=[_add_partial_charges], +) + + +def test_dsf_nonzero_energy() -> None: + """Charged system should produce nonzero electrostatic energy.""" + model = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE) + state = _make_charged_state() + out = model(state) + assert out["energy"].abs().item() > 0 + + +def test_ewald_pme_energy_agreement() -> None: + """Ewald and PME should give the same converged Coulomb energy.""" + state = _make_charged_state() + ewald = EwaldModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE) + pme = PMEModel(cutoff=8.0, accuracy=1e-6, device=DEVICE, dtype=DTYPE) + torch.testing.assert_close( + ewald(state)["energy"], pme(state)["energy"], atol=1e-3, rtol=1e-3 + ) + + +def test_sum_model_lj_plus_dsf() -> None: + """LJ + DSF should be additive through SumModel.""" + from torch_sim.models.interface import SumModel + from torch_sim.models.lennard_jones import LennardJonesModel + + lj = LennardJonesModel( + sigma=2.8, epsilon=0.01, cutoff=7.0, device=DEVICE, dtype=DTYPE + ) + dsf = DSFCoulombModel(cutoff=8.0, alpha=0.2, device=DEVICE, dtype=DTYPE) + combined = SumModel(lj, dsf) + state = _make_charged_state() + lj_out = lj(state) + dsf_out = dsf(state) + sum_out = combined(state) + torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + dsf_out["energy"]) + torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + dsf_out["forces"]) diff --git a/torch_sim/models/electrostatics.py b/torch_sim/models/electrostatics.py new file mode 100644 index 00000000..85fb9fd6 --- /dev/null +++ b/torch_sim/models/electrostatics.py @@ -0,0 +1,387 @@ +"""Electrostatics models: DSF, Ewald, and PME. + +Wraps the ``nvalchemiops`` Warp-accelerated electrostatics implementations as +:class:`~torch_sim.models.interface.ModelInterface` subclasses, with full PBC, +stress (virial), and batched system support. Per-atom partial charges are read +from ``state.partial_charges`` (a SimState atom extra). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from nvalchemiops.torch.interactions.electrostatics import ( + dsf_coulomb, + ewald_summation, + particle_mesh_ewald, +) + +from torch_sim._duecredit import dcite +from torch_sim.models.interface import ModelInterface +from torch_sim.neighbors import torchsim_nl +from torch_sim.units import UnitConversion + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch_sim.state import SimState + + +def _zero_result( + state: SimState, + dtype: torch.dtype, + compute_forces: bool, # noqa: FBT001 + compute_stress: bool, # noqa: FBT001 +) -> dict[str, torch.Tensor]: + """Return zero energy / forces / stress for non-periodic states.""" + dev = state.positions.device + results: dict[str, torch.Tensor] = { + "energy": torch.zeros(state.n_systems, dtype=dtype, device=dev), + } + if compute_forces: + results["forces"] = torch.zeros(state.n_atoms, 3, dtype=dtype, device=dev) + if compute_stress: + results["stress"] = torch.zeros(state.n_systems, 3, 3, dtype=dtype, device=dev) + return results + + +def _build_csr( + state: SimState, + cutoff: float, + neighbor_list_fn: Callable, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build a CSR neighbor list and integer unit-shift tensor.""" + edge_index, _mapping, unit_shifts = neighbor_list_fn( + state.positions, + state.row_vector_cell, + state.pbc, + cutoff, + state.system_idx, + ) + n_atoms = state.positions.shape[0] + dev = state.positions.device + neighbor_ptr = torch.zeros(n_atoms + 1, dtype=torch.int32, device=dev) + neighbor_ptr[1:] = ( + torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32) + ) + return ( + edge_index.to(torch.int32), + neighbor_ptr, + unit_shifts.to(torch.int32), + ) + + +class DSFCoulombModel(ModelInterface): + """Damped Shifted Force electrostatics as a :class:`ModelInterface`. + + Uses the ``nvalchemiops`` DSF kernel for O(N) electrostatic energy, + forces, and (optionally) stress. All user-facing quantities are in + metal units (Angstrom / eV); the Coulomb constant ``ke`` is baked in. + + Per-atom partial charges are read from ``state.partial_charges``. + + Args: + cutoff: Real-space cutoff in Angstrom. + alpha: DSF damping parameter. 0.0 gives shifted-force bare Coulomb. + device: Compute device. Defaults to CUDA if available, else CPU. + dtype: Floating-point dtype. Defaults to ``torch.float64``. + compute_forces: Whether to return forces. Defaults to True. + compute_stress: Whether to return stress. Defaults to True. + neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``. + """ + + @dcite("10.1063/1.2206581", description="Fennell & Gezelter DSF method") + def __init__( + self, + cutoff: float = 10.0, + *, + alpha: float = 0.2, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + compute_forces: bool = True, + compute_stress: bool = True, + neighbor_list_fn: Callable = torchsim_nl, + ) -> None: + """Initialize the DSF Coulomb model.""" + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + self.neighbor_list_fn = neighbor_list_fn + self.cutoff = cutoff + self.alpha = alpha + + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: + """Compute DSF electrostatic energy, forces, and stress. + + Args: + state: Simulation state with ``partial_charges`` set as an + atom extra (shape ``[n_atoms]``). + **_kwargs: Unused; accepted for interface compatibility. + + Returns: + dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3], + and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3]. + """ + if not state.has_extras("partial_charges"): + raise ValueError("Partial charges are required for DSF Coulomb summation.") + + charges = state.partial_charges + edge_index, neighbor_ptr, unit_shifts = _build_csr( + state, self.cutoff, self.neighbor_list_fn + ) + cell = state.row_vector_cell.contiguous() + dsf_args: dict = dict( + positions=state.positions, + charges=charges, + cutoff=self.cutoff, + alpha=self.alpha, + neighbor_list=edge_index, + neighbor_ptr=neighbor_ptr, + unit_shifts=unit_shifts, + cell=cell, + batch_idx=state.system_idx.to(torch.int32), + compute_forces=self._compute_forces, + compute_virial=self._compute_stress, + num_systems=state.n_systems, + ) + out = dsf_coulomb(**dsf_args) + if not isinstance(out, tuple): + out = (out,) + energy = (out[0] * UnitConversion.e2_per_Ang_to_eV).to(self._dtype).detach() + results: dict[str, torch.Tensor] = {"energy": energy} + if self._compute_forces: + forces = out[1] * UnitConversion.e2_per_Ang_to_eV # type: ignore[index] + results["forces"] = forces.to(self._dtype).detach() + if self._compute_stress: + volumes = state.volume.unsqueeze(-1).unsqueeze(-1) + stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes + results["stress"] = stress.to(self._dtype).detach() + return results + + +class EwaldModel(ModelInterface): + """Classical Ewald summation as a :class:`ModelInterface`. + + Uses the ``nvalchemiops`` Ewald kernel for exact periodic electrostatics. + Returns per-atom energies that are aggregated to per-system. All + user-facing quantities are in metal units (Angstrom / eV). + + Per-atom partial charges are read from ``state.partial_charges``. + + Requires periodic boundary conditions. + + Args: + cutoff: Real-space cutoff in Angstrom. + accuracy: Target accuracy for auto-estimated Ewald parameters. + device: Compute device. Defaults to CUDA if available, else CPU. + dtype: Floating-point dtype. Defaults to ``torch.float64``. + compute_forces: Whether to return forces. Defaults to True. + compute_stress: Whether to return stress. Defaults to True. + neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``. + """ + + @dcite("10.1002/andp.19213690304", description="Ewald summation method") + def __init__( + self, + cutoff: float = 10.0, + *, + accuracy: float = 1e-6, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + compute_forces: bool = True, + compute_stress: bool = True, + neighbor_list_fn: Callable = torchsim_nl, + ) -> None: + """Initialize the Ewald model.""" + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + self.neighbor_list_fn = neighbor_list_fn + self.cutoff = cutoff + self.accuracy = accuracy + + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: + """Compute Ewald electrostatic energy, forces, and stress. + + Args: + state: Simulation state with ``partial_charges`` set as an + atom extra (shape ``[n_atoms]``). Returns zeros for + non-periodic states. + **_kwargs: Unused; accepted for interface compatibility. + + Returns: + dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3], + and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3]. + """ + if not state.has_extras("partial_charges"): + raise ValueError("Partial charges are required for Ewald summation.") + + if not state.pbc.any(): + return _zero_result( + state, self._dtype, self._compute_forces, self._compute_stress + ) + charges = state.partial_charges + edge_index, neighbor_ptr, unit_shifts = _build_csr( + state, self.cutoff, self.neighbor_list_fn + ) + cell = state.row_vector_cell.contiguous() + out = ewald_summation( + positions=state.positions, + charges=charges, + cell=cell, + neighbor_list=edge_index, + neighbor_ptr=neighbor_ptr, + neighbor_shifts=unit_shifts, + batch_idx=state.system_idx.to(torch.int32), + compute_forces=self._compute_forces, + compute_virial=self._compute_stress, + accuracy=self.accuracy, + ) + if not isinstance(out, tuple): + out = (out,) + per_atom_energy = out[0] * UnitConversion.e2_per_Ang_to_eV + dev = state.positions.device + energy = torch.zeros(state.n_systems, dtype=torch.float64, device=dev) + energy.scatter_add_(0, state.system_idx.long(), per_atom_energy) + results: dict[str, torch.Tensor] = { + "energy": energy.to(self._dtype).detach(), + } + if self._compute_forces: + forces = out[1] * UnitConversion.e2_per_Ang_to_eV # type: ignore[index] + results["forces"] = forces.to(self._dtype).detach() + if self._compute_stress: + volumes = state.volume.unsqueeze(-1).unsqueeze(-1) + stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes + results["stress"] = stress.to(self._dtype).detach() + return results + + +class PMEModel(ModelInterface): + """Particle Mesh Ewald electrostatics as a :class:`ModelInterface`. + + Uses the ``nvalchemiops`` PME kernel for O(N log N) periodic + electrostatics. Returns per-atom energies that are aggregated to + per-system. All user-facing quantities are in metal units + (Angstrom / eV). + + Per-atom partial charges are read from ``state.partial_charges``. + + Requires periodic boundary conditions. + + Args: + cutoff: Real-space cutoff in Angstrom. + accuracy: Target accuracy for auto-estimated parameters. + mesh_spacing: Optional mesh spacing (Angstrom) for automatic mesh sizing. + mesh_dimensions: Explicit FFT mesh dimensions ``(nx, ny, nz)``. + spline_order: B-spline interpolation order. Defaults to 4. + device: Compute device. Defaults to CUDA if available, else CPU. + dtype: Floating-point dtype. Defaults to ``torch.float64``. + compute_forces: Whether to return forces. Defaults to True. + compute_stress: Whether to return stress. Defaults to True. + neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``. + """ + + @dcite("10.1063/1.464397", description="Darden et al. PME method") + def __init__( + self, + cutoff: float = 10.0, + *, + accuracy: float = 1e-6, + mesh_spacing: float | None = None, + mesh_dimensions: tuple[int, int, int] | None = None, + spline_order: int = 4, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + compute_forces: bool = True, + compute_stress: bool = True, + neighbor_list_fn: Callable = torchsim_nl, + ) -> None: + """Initialize the PME model.""" + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + self.neighbor_list_fn = neighbor_list_fn + self.cutoff = cutoff + self.accuracy = accuracy + self.mesh_spacing = mesh_spacing + self.mesh_dimensions = mesh_dimensions + self.spline_order = spline_order + + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: + """Compute PME electrostatic energy, forces, and stress. + + Args: + state: Simulation state with ``partial_charges`` set as an + atom extra (shape ``[n_atoms]``). Returns zeros for + non-periodic states. + **_kwargs: Unused; accepted for interface compatibility. + + Returns: + dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3], + and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3]. + """ + if not state.has_extras("partial_charges"): + raise ValueError("Partial charges are required for PME summation.") + + if not state.pbc.any(): + return _zero_result( + state, self._dtype, self._compute_forces, self._compute_stress + ) + charges = state.partial_charges + edge_index, neighbor_ptr, unit_shifts = _build_csr( + state, self.cutoff, self.neighbor_list_fn + ) + cell = state.row_vector_cell.contiguous() + batch_idx = state.system_idx.to(torch.int32) if state.n_systems > 1 else None + pme_kwargs: dict = dict( + positions=state.positions, + charges=charges, + cell=cell, + neighbor_list=edge_index, + neighbor_ptr=neighbor_ptr, + neighbor_shifts=unit_shifts, + batch_idx=batch_idx, + compute_forces=self._compute_forces, + compute_virial=self._compute_stress, + accuracy=self.accuracy, + spline_order=self.spline_order, + ) + if self.mesh_spacing is not None: + pme_kwargs["mesh_spacing"] = self.mesh_spacing + if self.mesh_dimensions is not None: + pme_kwargs["mesh_dimensions"] = self.mesh_dimensions + out = particle_mesh_ewald(**pme_kwargs) + if not isinstance(out, tuple): + out = (out,) + per_atom_energy = out[0] * UnitConversion.e2_per_Ang_to_eV + dev = state.positions.device + energy = torch.zeros(state.n_systems, dtype=torch.float64, device=dev) + energy.scatter_add_(0, state.system_idx.long(), per_atom_energy) + results: dict[str, torch.Tensor] = { + "energy": energy.to(self._dtype).detach(), + } + if self._compute_forces: + forces = out[1] * UnitConversion.e2_per_Ang_to_eV # type: ignore[index] + results["forces"] = forces.to(self._dtype).detach() + if self._compute_stress: + volumes = state.volume.unsqueeze(-1).unsqueeze(-1) + stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes + results["stress"] = stress.to(self._dtype).detach() + return results