From dc8fab7a4ff4e9dad9ec5dd3c69186b3753121c8 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 09:56:00 -0400 Subject: [PATCH 1/6] fea: add sum model interface --- tests/models/test_sum_model.py | 203 +++++++++++++++++++++++++++++++++ torch_sim/models/interface.py | 135 +++++++++++++++++++++- 2 files changed, 336 insertions(+), 2 deletions(-) create mode 100644 tests/models/test_sum_model.py diff --git a/tests/models/test_sum_model.py b/tests/models/test_sum_model.py new file mode 100644 index 00000000..63d51201 --- /dev/null +++ b/tests/models/test_sum_model.py @@ -0,0 +1,203 @@ +"""Tests for the SumModel composite model.""" + +import pytest +import torch + +import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from torch_sim.models.interface import SumModel, validate_model_outputs +from torch_sim.models.lennard_jones import LennardJonesModel +from torch_sim.models.morse import MorseModel + + +@pytest.fixture +def lj_model_a() -> LennardJonesModel: + return LennardJonesModel( + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def lj_model_b() -> LennardJonesModel: + return LennardJonesModel( + sigma=2.0, + epsilon=0.005, + cutoff=5.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def morse_model() -> MorseModel: + return MorseModel( + sigma=2.55, + epsilon=0.436, + alpha=1.359, + cutoff=6.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumModel: + return SumModel(lj_model_a, morse_model) + + +def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None: + with pytest.raises(ValueError, match="at least two"): + SumModel(lj_model_a) + + +def test_sum_model_device_mismatch() -> None: + m1 = LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, device=torch.device("cpu")) + m2 = LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, device=torch.device("cpu")) + object.__setattr__(m2, "_device", torch.device("meta")) + with pytest.raises(ValueError, match="Device mismatch"): + SumModel(m1, m2) + + +def test_sum_model_dtype_mismatch() -> None: + m1 = LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, dtype=torch.float64) + m2 = LennardJonesModel(sigma=1.0, epsilon=1.0, cutoff=2.5, dtype=torch.float32) + with pytest.raises(ValueError, match="Dtype mismatch"): + SumModel(m1, m2) + + +def test_sum_model_properties(sum_model: SumModel) -> None: + assert sum_model.device == DEVICE + assert sum_model.dtype == DTYPE + assert sum_model.compute_stress is True + assert sum_model.compute_forces is True + + +def test_sum_model_energy_summation( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_sim_state) + morse_out = morse_model(si_sim_state) + sum_out = sum_model(si_sim_state) + expected_energy = lj_out["energy"] + morse_out["energy"] + torch.testing.assert_close(sum_out["energy"], expected_energy) + + +def test_sum_model_forces_summation( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_sim_state) + morse_out = morse_model(si_sim_state) + sum_out = sum_model(si_sim_state) + expected_forces = lj_out["forces"] + morse_out["forces"] + torch.testing.assert_close(sum_out["forces"], expected_forces) + + +def test_sum_model_stress_summation( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_sim_state) + morse_out = morse_model(si_sim_state) + sum_out = sum_model(si_sim_state) + expected_stress = lj_out["stress"] + morse_out["stress"] + torch.testing.assert_close(sum_out["stress"], expected_stress) + + +def test_sum_model_batched( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + sum_model: SumModel, + si_double_sim_state: ts.SimState, +) -> None: + lj_out = lj_model_a(si_double_sim_state) + morse_out = morse_model(si_double_sim_state) + sum_out = sum_model(si_double_sim_state) + torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + morse_out["energy"]) + torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + morse_out["forces"]) + torch.testing.assert_close(sum_out["stress"], lj_out["stress"] + morse_out["stress"]) + + +def test_sum_model_three_models( + lj_model_a: LennardJonesModel, + lj_model_b: LennardJonesModel, + morse_model: MorseModel, + si_sim_state: ts.SimState, +) -> None: + triple = SumModel(lj_model_a, lj_model_b, morse_model) + a_out = lj_model_a(si_sim_state) + b_out = lj_model_b(si_sim_state) + c_out = morse_model(si_sim_state) + sum_out = triple(si_sim_state) + torch.testing.assert_close( + sum_out["energy"], a_out["energy"] + b_out["energy"] + c_out["energy"] + ) + + +def test_sum_model_compute_stress_setter( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + assert sm.compute_stress is True + sm.compute_stress = False + assert sm.compute_stress is False + + +def test_sum_model_compute_forces_setter( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + sm.compute_forces = False + assert sm.compute_forces is False + + +def test_sum_model_memory_scales_with( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + assert sm.memory_scales_with == "n_atoms_x_density" + + +def test_sum_model_force_conservation( + sum_model: SumModel, si_double_sim_state: ts.SimState +) -> None: + results = sum_model(si_double_sim_state) + for sys_idx in range(si_double_sim_state.n_systems): + mask = si_double_sim_state.system_idx == sys_idx + assert torch.allclose( + results["forces"][mask].sum(dim=0), + torch.zeros(3, dtype=DTYPE), + atol=1e-10, + ) + + +def test_sum_model_validate_outputs(sum_model: SumModel) -> None: + validate_model_outputs(sum_model, DEVICE, DTYPE, check_detached=True) + + +def test_sum_model_retain_graph( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> None: + sm = SumModel(lj_model_a, morse_model) + assert sm.retain_graph is False + sm.retain_graph = True + assert lj_model_a.retain_graph is True + assert morse_model.retain_graph is True + assert sm.retain_graph is True diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 8aa6bb5e..a073ab98 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -26,17 +26,29 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): compute_stress property, as some integrators require stress calculations. """ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import torch import torch_sim as ts -from torch_sim.state import SimState -from torch_sim.typing import MemoryScaling + + +if TYPE_CHECKING: + from torch_sim.state import SimState + from torch_sim.typing import MemoryScaling VALIDATE_ATOL = 1e-4 +_MEMORY_SCALING_PRIORITY: dict[MemoryScaling, int] = { + "n_atoms": 0, + "n_atoms_x_density": 1, + "n_edges": 2, +} + class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in TorchSim. @@ -171,6 +183,125 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """ +class SumModel(ModelInterface): + """Additive composition of multiple :class:`ModelInterface` models. + + Calls each child model's :meth:`forward` and sums the output tensors + key-by-key, so energies, forces, and stresses are combined additively. + This is the standard way to layer a dispersion correction (e.g. DFT-D3), + an Ewald electrostatic term, or a local pair potential on top of a primary + machine-learning potential. + + Args: + models: Two or more :class:`ModelInterface` instances that share the + same ``device`` and ``dtype``. + + Raises: + ValueError: If fewer than two models are given or if ``device``/``dtype`` + do not match across all models. + + Examples: + ```py + sum_model = SumModel(mace_model, d3_model) + output = sum_model(sim_state) + ``` + """ + + def __init__(self, *models: ModelInterface) -> None: + """Initialize the sum model. + + Args: + models: Two or more :class:`ModelInterface` instances. All must + share the same ``device`` and ``dtype``. + """ + super().__init__() + if len(models) < 2: + raise ValueError("SumModel requires at least two child models") + first = models[0] + for i, m in enumerate(models[1:], start=1): + if m.device != first.device: + raise ValueError( + f"Device mismatch: model 0 has {first.device}, " + f"model {i} has {m.device}" + ) + if m.dtype != first.dtype: + raise ValueError( + f"Dtype mismatch: model 0 has {first.dtype}, model {i} has {m.dtype}" + ) + self.models = torch.nn.ModuleList(models) + self._device = first.device + self._dtype = first.dtype + self._compute_stress = all(m.compute_stress for m in models) + self._compute_forces = all(m.compute_forces for m in models) + + @ModelInterface.compute_stress.setter + def compute_stress(self, value: bool) -> None: # noqa: FBT001 + """Propagate ``compute_stress`` to all child models that support it.""" + for m in self.models: + try: + m.compute_stress = value + except NotImplementedError: + if value: + raise + self._compute_stress = value + + @ModelInterface.compute_forces.setter + def compute_forces(self, value: bool) -> None: # noqa: FBT001 + """Propagate ``compute_forces`` to all child models that support it.""" + for m in self.models: + try: + m.compute_forces = value + except NotImplementedError: + if value: + raise + self._compute_forces = value + + @property + def retain_graph(self) -> bool: + """Whether any child model retains the computation graph.""" + return any(getattr(m, "retain_graph", False) for m in self.models) + + @retain_graph.setter + def retain_graph(self, value: bool) -> None: + for m in self.models: + if hasattr(m, "retain_graph"): + m.retain_graph = value + + @property + def memory_scales_with(self) -> MemoryScaling: + """Most conservative memory-scaling among all child models.""" + best: MemoryScaling = "n_atoms" + for m in self.models: + scaling = m.memory_scales_with + if _MEMORY_SCALING_PRIORITY[scaling] > _MEMORY_SCALING_PRIORITY[best]: + best = scaling + return best + + def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: + """Sum the outputs of all child models. + + Each child model is called with the same ``state`` and ``**kwargs``. + Output tensors that appear in multiple children are summed element-wise; + keys unique to a single child are passed through unchanged. + + Args: + state: Simulation state (see :class:`ModelInterface`). + **kwargs: Forwarded to every child model. + + Returns: + Combined output dictionary with summed tensors. + """ + combined: dict[str, torch.Tensor] = {} + for model in self.models: + output = model(state, **kwargs) + for key, tensor in output.items(): + if key in combined: + combined[key] = combined[key] + tensor + else: + combined[key] = tensor + return combined + + def _check_output_detached( output: dict[str, torch.Tensor], model: ModelInterface ) -> None: From 38e31631fd5a58d29a6b07cc361457fbfb960c82 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 14:44:10 -0400 Subject: [PATCH 2/6] lint --- torch_sim/models/interface.py | 18 +++++++++++------- torch_sim/neighbors/vesin.py | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index a073ab98..885627b1 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -234,10 +234,14 @@ def __init__(self, *models: ModelInterface) -> None: self._compute_stress = all(m.compute_stress for m in models) self._compute_forces = all(m.compute_forces for m in models) + def _children(self) -> list[ModelInterface]: + """Return child models with proper typing for static analysis.""" + return list(self.models.children()) # type: ignore[return-value] + @ModelInterface.compute_stress.setter def compute_stress(self, value: bool) -> None: # noqa: FBT001 """Propagate ``compute_stress`` to all child models that support it.""" - for m in self.models: + for m in self._children(): try: m.compute_stress = value except NotImplementedError: @@ -248,7 +252,7 @@ def compute_stress(self, value: bool) -> None: # noqa: FBT001 @ModelInterface.compute_forces.setter def compute_forces(self, value: bool) -> None: # noqa: FBT001 """Propagate ``compute_forces`` to all child models that support it.""" - for m in self.models: + for m in self._children(): try: m.compute_forces = value except NotImplementedError: @@ -259,19 +263,19 @@ def compute_forces(self, value: bool) -> None: # noqa: FBT001 @property def retain_graph(self) -> bool: """Whether any child model retains the computation graph.""" - return any(getattr(m, "retain_graph", False) for m in self.models) + return any(getattr(m, "retain_graph", False) for m in self._children()) @retain_graph.setter def retain_graph(self, value: bool) -> None: - for m in self.models: + for m in self._children(): if hasattr(m, "retain_graph"): - m.retain_graph = value + m.retain_graph = value # type: ignore[union-attr] @property def memory_scales_with(self) -> MemoryScaling: """Most conservative memory-scaling among all child models.""" best: MemoryScaling = "n_atoms" - for m in self.models: + for m in self._children(): scaling = m.memory_scales_with if _MEMORY_SCALING_PRIORITY[scaling] > _MEMORY_SCALING_PRIORITY[best]: best = scaling @@ -292,7 +296,7 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: Combined output dictionary with summed tensors. """ combined: dict[str, torch.Tensor] = {} - for model in self.models: + for model in self._children(): output = model(state, **kwargs) for key, tensor in output.items(): if key in combined: diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index 009fe9bb..16648950 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -12,13 +12,13 @@ try: from vesin import NeighborList as VesinNeighborList except ImportError: - VesinNeighborList = None # type: ignore[assignment] + VesinNeighborList = None try: from vesin.torch import NeighborList as VesinNeighborListTorch except ImportError: - VesinNeighborListTorch = None # ty:ignore[invalid-assignment] + VesinNeighborListTorch = None VESIN_AVAILABLE = VesinNeighborList is not None VESIN_TORCHSCRIPT_AVAILABLE = VesinNeighborListTorch is not None From cd5b0be7cb9d37198d6c72548cf39cd0657a777b Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 15:31:04 -0400 Subject: [PATCH 3/6] fea: add alchemiops d3 --- tests/models/test_d3.py | 50 +++++++++ tests/models/test_sum_model.py | 117 ++------------------- torch_sim/models/dispersion.py | 181 +++++++++++++++++++++++++++++++++ torch_sim/units.py | 6 ++ 4 files changed, 246 insertions(+), 108 deletions(-) create mode 100644 tests/models/test_d3.py create mode 100644 torch_sim/models/dispersion.py diff --git a/tests/models/test_d3.py b/tests/models/test_d3.py new file mode 100644 index 00000000..a871c510 --- /dev/null +++ b/tests/models/test_d3.py @@ -0,0 +1,50 @@ +"""Tests for the D3DispersionModel wrapper.""" + +import traceback # noqa: I001 + +import pytest +import torch + +from tests.conftest import DEVICE, DTYPE +from tests.models.conftest import make_validate_model_outputs_test + +try: + from nvalchemiops.torch.interactions.dispersion import D3Parameters + from torch_sim.models.d3 import D3DispersionModel +except (ImportError, OSError, RuntimeError): + pytest.skip( + f"nvalchemiops not installed: {traceback.format_exc()}", + allow_module_level=True, + ) + + +def _make_d3_params(device: torch.device = DEVICE) -> D3Parameters: + """Build minimal D3 reference parameters for testing (elements up to Fe=26).""" + max_z = 26 + mesh = 5 + return D3Parameters( + rcov=torch.rand(max_z + 1, device=device), + r4r2=torch.rand(max_z + 1, device=device), + c6ab=torch.rand(max_z + 1, max_z + 1, mesh, mesh, device=device), + cn_ref=torch.rand(max_z + 1, max_z + 1, mesh, mesh, device=device), + ) + + +@pytest.fixture +def d3_model() -> D3DispersionModel: + return D3DispersionModel( + a1=0.4289, + a2=4.4407, + s8=0.7875, + d3_params=_make_d3_params(), + cutoff=12.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +test_d3_model_outputs = make_validate_model_outputs_test( + model_fixture_name="d3_model", device=DEVICE, dtype=DTYPE +) diff --git a/tests/models/test_sum_model.py b/tests/models/test_sum_model.py index 63d51201..84d00ccd 100644 --- a/tests/models/test_sum_model.py +++ b/tests/models/test_sum_model.py @@ -5,7 +5,8 @@ import torch_sim as ts from tests.conftest import DEVICE, DTYPE -from torch_sim.models.interface import SumModel, validate_model_outputs +from tests.models.conftest import make_validate_model_outputs_test +from torch_sim.models.interface import SumModel from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.models.morse import MorseModel @@ -23,19 +24,6 @@ def lj_model_a() -> LennardJonesModel: ) -@pytest.fixture -def lj_model_b() -> LennardJonesModel: - return LennardJonesModel( - sigma=2.0, - epsilon=0.005, - cutoff=5.0, - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - - @pytest.fixture def morse_model() -> MorseModel: return MorseModel( @@ -55,6 +43,11 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode return SumModel(lj_model_a, morse_model) +test_sum_model_outputs = make_validate_model_outputs_test( + model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE +) + + def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None: with pytest.raises(ValueError, match="at least two"): SumModel(lj_model_a) @@ -75,27 +68,7 @@ def test_sum_model_dtype_mismatch() -> None: SumModel(m1, m2) -def test_sum_model_properties(sum_model: SumModel) -> None: - assert sum_model.device == DEVICE - assert sum_model.dtype == DTYPE - assert sum_model.compute_stress is True - assert sum_model.compute_forces is True - - -def test_sum_model_energy_summation( - lj_model_a: LennardJonesModel, - morse_model: MorseModel, - sum_model: SumModel, - si_sim_state: ts.SimState, -) -> None: - lj_out = lj_model_a(si_sim_state) - morse_out = morse_model(si_sim_state) - sum_out = sum_model(si_sim_state) - expected_energy = lj_out["energy"] + morse_out["energy"] - torch.testing.assert_close(sum_out["energy"], expected_energy) - - -def test_sum_model_forces_summation( +def test_sum_model_additivity( lj_model_a: LennardJonesModel, morse_model: MorseModel, sum_model: SumModel, @@ -104,94 +77,22 @@ def test_sum_model_forces_summation( lj_out = lj_model_a(si_sim_state) morse_out = morse_model(si_sim_state) sum_out = sum_model(si_sim_state) - expected_forces = lj_out["forces"] + morse_out["forces"] - torch.testing.assert_close(sum_out["forces"], expected_forces) - - -def test_sum_model_stress_summation( - lj_model_a: LennardJonesModel, - morse_model: MorseModel, - sum_model: SumModel, - si_sim_state: ts.SimState, -) -> None: - lj_out = lj_model_a(si_sim_state) - morse_out = morse_model(si_sim_state) - sum_out = sum_model(si_sim_state) - expected_stress = lj_out["stress"] + morse_out["stress"] - torch.testing.assert_close(sum_out["stress"], expected_stress) - - -def test_sum_model_batched( - lj_model_a: LennardJonesModel, - morse_model: MorseModel, - sum_model: SumModel, - si_double_sim_state: ts.SimState, -) -> None: - lj_out = lj_model_a(si_double_sim_state) - morse_out = morse_model(si_double_sim_state) - sum_out = sum_model(si_double_sim_state) torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + morse_out["energy"]) torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + morse_out["forces"]) torch.testing.assert_close(sum_out["stress"], lj_out["stress"] + morse_out["stress"]) -def test_sum_model_three_models( - lj_model_a: LennardJonesModel, - lj_model_b: LennardJonesModel, - morse_model: MorseModel, - si_sim_state: ts.SimState, -) -> None: - triple = SumModel(lj_model_a, lj_model_b, morse_model) - a_out = lj_model_a(si_sim_state) - b_out = lj_model_b(si_sim_state) - c_out = morse_model(si_sim_state) - sum_out = triple(si_sim_state) - torch.testing.assert_close( - sum_out["energy"], a_out["energy"] + b_out["energy"] + c_out["energy"] - ) - - -def test_sum_model_compute_stress_setter( +def test_sum_model_setters( lj_model_a: LennardJonesModel, morse_model: MorseModel ) -> None: sm = SumModel(lj_model_a, morse_model) assert sm.compute_stress is True sm.compute_stress = False assert sm.compute_stress is False - - -def test_sum_model_compute_forces_setter( - lj_model_a: LennardJonesModel, morse_model: MorseModel -) -> None: - sm = SumModel(lj_model_a, morse_model) sm.compute_forces = False assert sm.compute_forces is False -def test_sum_model_memory_scales_with( - lj_model_a: LennardJonesModel, morse_model: MorseModel -) -> None: - sm = SumModel(lj_model_a, morse_model) - assert sm.memory_scales_with == "n_atoms_x_density" - - -def test_sum_model_force_conservation( - sum_model: SumModel, si_double_sim_state: ts.SimState -) -> None: - results = sum_model(si_double_sim_state) - for sys_idx in range(si_double_sim_state.n_systems): - mask = si_double_sim_state.system_idx == sys_idx - assert torch.allclose( - results["forces"][mask].sum(dim=0), - torch.zeros(3, dtype=DTYPE), - atol=1e-10, - ) - - -def test_sum_model_validate_outputs(sum_model: SumModel) -> None: - validate_model_outputs(sum_model, DEVICE, DTYPE, check_detached=True) - - def test_sum_model_retain_graph( lj_model_a: LennardJonesModel, morse_model: MorseModel ) -> None: diff --git a/torch_sim/models/dispersion.py b/torch_sim/models/dispersion.py new file mode 100644 index 00000000..0ff8a419 --- /dev/null +++ b/torch_sim/models/dispersion.py @@ -0,0 +1,181 @@ +"""DFT-D3(BJ) dispersion correction model. + +Wraps the ``nvalchemiops`` Warp-accelerated DFT-D3 implementation as a +:class:`~torch_sim.models.interface.ModelInterface`, with full PBC, stress +(virial), and batched system support. + +References: + - Grimme et al., J. Chem. Phys. 132, 154104 (2010). + https://doi.org/10.1063/1.3382344 + - Grimme et al., J. Comput. Chem. 32, 1456-1465 (2011). + https://doi.org/10.1002/jcc.21759 + - nvalchemi-toolkit-ops: https://github.com/NVIDIA/nvalchemi-toolkit-ops +""" + +from __future__ import annotations + +import traceback +import warnings +from typing import TYPE_CHECKING, Any + +import torch + +from torch_sim._duecredit import dcite +from torch_sim.models.interface import ModelInterface +from torch_sim.neighbors import torchsim_nl +from torch_sim.units import UnitConversion + + +try: + from nvalchemiops.torch.interactions.dispersion import D3Parameters + from nvalchemiops.torch.interactions.dispersion import dftd3 as nvalchemiops_dftd3 +except (ImportError, ModuleNotFoundError) as exc: + warnings.warn(f"nvalchemiops import failed: {traceback.format_exc()}", stacklevel=2) + + class D3Parameters: + """Placeholder when nvalchemiops is not installed.""" + + def __init__(self, *_a: Any, _err: Exception = exc, **_kw: Any) -> None: + """Raise the original import error.""" + raise _err + + def nvalchemiops_dftd3(*_a: Any, _err: Exception = exc, **_kw: Any) -> Any: + """Raise the original import error.""" + raise _err + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch_sim.state import SimState + +_BOHR_TO_ANG = UnitConversion.Bohr_to_Ang +_ANG_TO_BOHR = UnitConversion.Ang_to_Bohr +_HARTREE_TO_EV = UnitConversion.Hartree_to_eV +_FORCE_CONV = _HARTREE_TO_EV / _BOHR_TO_ANG # Hartree/Bohr -> eV/Ang + + +class D3DispersionModel(ModelInterface): + """DFT-D3(BJ) dispersion correction as a :class:`ModelInterface`. + + Computes DFT-D3 energies, forces, and (optionally) stresses via the + ``nvalchemiops`` Warp GPU/CPU kernels. All user-facing quantities are in + metal units (Angstrom / eV); unit conversion to and from atomic units + (Bohr / Hartree) is handled internally. + + Args: + a1: BJ damping parameter (dimensionless, functional-dependent). + a2: BJ damping parameter (in Bohr, functional-dependent). + s8: C8 scaling factor (dimensionless, functional-dependent). + s6: C6 scaling factor. Defaults to 1.0. + d3_params: Reference D3 parameters (rcov, r4r2, c6ab, cn_ref). + cutoff: Neighbor-list cutoff in **Angstrom**. + device: Compute device. Defaults to CUDA if available, else CPU. + dtype: Floating-point dtype. Defaults to ``torch.float64``. + compute_forces: Whether to return forces. Defaults to True. + compute_stress: Whether to return stress. Defaults to True. + neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``. + + Example:: + + model = D3DispersionModel( + a1=0.4289, + a2=4.4407, + s8=0.7875, + d3_params=params, + cutoff=50.0, + ) + results = model(sim_state) + """ + + @dcite("10.1063/1.3382344") + @dcite("10.1002/jcc.21759") + def __init__( + self, + a1: float, + a2: float, + s8: float, + *, + s6: float = 1.0, + d3_params: D3Parameters | None = None, + cutoff: float = 95.0 * _BOHR_TO_ANG, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + compute_forces: bool = True, + compute_stress: bool = True, + neighbor_list_fn: Callable = torchsim_nl, + ) -> None: + """Initialize the D3 dispersion model.""" + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + self.neighbor_list_fn = neighbor_list_fn + self.cutoff = cutoff + self.a1 = a1 + self.a2 = a2 + self.s8 = s8 + self.s6 = s6 + self.d3_params = d3_params + + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: + """Compute D3 dispersion energy, forces, and stress. + + Args: + state: Simulation state (positions in Angstrom, cell in Angstrom). + **_kwargs: Unused; accepted for interface compatibility. + + Returns: + dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3], + and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3]. + """ + edge_index, _mapping_system, unit_shifts = self.neighbor_list_fn( + state.positions, + state.row_vector_cell, + state.pbc, + self.cutoff, + state.system_idx, + ) + n_atoms = state.positions.shape[0] + neighbor_ptr = torch.zeros( + n_atoms + 1, dtype=torch.int32, device=state.positions.device + ) + neighbor_ptr[1:] = ( + torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32) + ) + positions_bohr = state.positions * _ANG_TO_BOHR + cell_bohr = state.row_vector_cell.contiguous() * _ANG_TO_BOHR + numbers = state.atomic_numbers.to(torch.int32) + unit_shifts_int = unit_shifts.to(torch.int32) + edge_index_int = edge_index.to(torch.int32) + d3_out = nvalchemiops_dftd3( + positions=positions_bohr, + numbers=numbers, + a1=self.a1, + a2=self.a2, + s8=self.s8, + s6=self.s6, + d3_params=self.d3_params, # type: ignore[arg-type] + neighbor_list=edge_index_int, + neighbor_ptr=neighbor_ptr, + cell=cell_bohr, + unit_shifts=unit_shifts_int, + batch_idx=state.system_idx.to(torch.int32), + compute_virial=self._compute_stress, + num_systems=state.n_systems, + ) + + results: dict[str, torch.Tensor] = { + "energy": (d3_out[0] * _HARTREE_TO_EV).to(self._dtype).detach(), + "forces": (d3_out[1] * _FORCE_CONV).to(self._dtype).detach(), + } + if self._compute_stress: + # d3_out[3] is only defined if compute_virial is True + volumes = state.volume.unsqueeze(-1).unsqueeze(-1) + stress = (d3_out[3] * _HARTREE_TO_EV) / volumes # type: ignore[index] + results["stress"] = stress.to(self._dtype).detach() + return results diff --git a/torch_sim/units.py b/torch_sim/units.py index 74cffba4..e59dfff6 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -77,6 +77,12 @@ class UnitConversion: kcal_to_cal = 1e3 eV_to_J = bc.e + # Atomic-unit conversions (Bohr / Hartree <-> Angstrom / eV) + Bohr_to_Ang = 0.529177210903 + Ang_to_Bohr = 1.0 / Bohr_to_Ang + Hartree_to_eV = 27.211386245988 + eV_to_Hartree = 1.0 / Hartree_to_eV + uc = UnitConversion From 7cae3979b46e952fbe1429de41c0cea54fb5c432 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 16:36:26 -0400 Subject: [PATCH 4/6] use less constants [skip-ci] --- torch_sim/models/dispersion.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/torch_sim/models/dispersion.py b/torch_sim/models/dispersion.py index 0ff8a419..cf56d38c 100644 --- a/torch_sim/models/dispersion.py +++ b/torch_sim/models/dispersion.py @@ -49,10 +49,7 @@ def nvalchemiops_dftd3(*_a: Any, _err: Exception = exc, **_kw: Any) -> Any: from torch_sim.state import SimState -_BOHR_TO_ANG = UnitConversion.Bohr_to_Ang -_ANG_TO_BOHR = UnitConversion.Ang_to_Bohr -_HARTREE_TO_EV = UnitConversion.Hartree_to_eV -_FORCE_CONV = _HARTREE_TO_EV / _BOHR_TO_ANG # Hartree/Bohr -> eV/Ang +_FORCE_CONV = UnitConversion.Hartree_to_eV / UnitConversion.Bohr_to_Ang class D3DispersionModel(ModelInterface): @@ -98,7 +95,7 @@ def __init__( *, s6: float = 1.0, d3_params: D3Parameters | None = None, - cutoff: float = 95.0 * _BOHR_TO_ANG, + cutoff: float = 95.0 * UnitConversion.Bohr_to_Ang, device: torch.device | None = None, dtype: torch.dtype = torch.float64, compute_forces: bool = True, @@ -147,8 +144,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] neighbor_ptr[1:] = ( torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32) ) - positions_bohr = state.positions * _ANG_TO_BOHR - cell_bohr = state.row_vector_cell.contiguous() * _ANG_TO_BOHR + positions_bohr = state.positions * UnitConversion.Ang_to_Bohr + cell_bohr = state.row_vector_cell.contiguous() * UnitConversion.Ang_to_Bohr numbers = state.atomic_numbers.to(torch.int32) unit_shifts_int = unit_shifts.to(torch.int32) edge_index_int = edge_index.to(torch.int32) @@ -170,12 +167,12 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] ) results: dict[str, torch.Tensor] = { - "energy": (d3_out[0] * _HARTREE_TO_EV).to(self._dtype).detach(), + "energy": (d3_out[0] * UnitConversion.Hartree_to_eV).to(self._dtype).detach(), "forces": (d3_out[1] * _FORCE_CONV).to(self._dtype).detach(), } if self._compute_stress: # d3_out[3] is only defined if compute_virial is True volumes = state.volume.unsqueeze(-1).unsqueeze(-1) - stress = (d3_out[3] * _HARTREE_TO_EV) / volumes # type: ignore[index] + stress = (d3_out[3] * UnitConversion.Hartree_to_eV) / volumes # type: ignore[index] results["stress"] = stress.to(self._dtype).detach() return results From ce38dc64ff062c027364ce070d057a2cb59fa1ff Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 17:15:01 -0400 Subject: [PATCH 5/6] fix dispersion naming --- tests/models/{test_d3.py => test_dispersion.py} | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) rename tests/models/{test_d3.py => test_dispersion.py} (95%) diff --git a/tests/models/test_d3.py b/tests/models/test_dispersion.py similarity index 95% rename from tests/models/test_d3.py rename to tests/models/test_dispersion.py index a871c510..b12f2e4e 100644 --- a/tests/models/test_d3.py +++ b/tests/models/test_dispersion.py @@ -10,7 +10,8 @@ try: from nvalchemiops.torch.interactions.dispersion import D3Parameters - from torch_sim.models.d3 import D3DispersionModel + + from torch_sim.models.dispersion import D3DispersionModel except (ImportError, OSError, RuntimeError): pytest.skip( f"nvalchemiops not installed: {traceback.format_exc()}", From d528ef2b6896e039a38cf1166c43fa7052f1f56b Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 21:16:27 -0400 Subject: [PATCH 6/6] fea: tests with r2scan defaults --- tests/models/test_dispersion.py | 33 +++++++++++++++++++++++++++------ torch_sim/models/dispersion.py | 9 +++++++-- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/tests/models/test_dispersion.py b/tests/models/test_dispersion.py index b12f2e4e..4dc5ddc7 100644 --- a/tests/models/test_dispersion.py +++ b/tests/models/test_dispersion.py @@ -31,12 +31,16 @@ def _make_d3_params(device: torch.device = DEVICE) -> D3Parameters: ) +# BJ damping parameters from +# https://github.com/dftd3/simple-dftd3/blob/main/assets/parameters.toml +PBE_BJ = {"a1": 0.4289, "s8": 0.7875, "a2": 4.4407, "s6": 1.0} +R2SCAN_BJ = {"a1": 0.49484001, "s8": 0.78981345, "a2": 5.73083694, "s6": 1.0} + + @pytest.fixture -def d3_model() -> D3DispersionModel: +def d3_model_pbe() -> D3DispersionModel: return D3DispersionModel( - a1=0.4289, - a2=4.4407, - s8=0.7875, + **PBE_BJ, d3_params=_make_d3_params(), cutoff=12.0, device=DEVICE, @@ -46,6 +50,23 @@ def d3_model() -> D3DispersionModel: ) -test_d3_model_outputs = make_validate_model_outputs_test( - model_fixture_name="d3_model", device=DEVICE, dtype=DTYPE +@pytest.fixture +def d3_model_r2scan() -> D3DispersionModel: + return D3DispersionModel( + **R2SCAN_BJ, + d3_params=_make_d3_params(), + cutoff=12.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +test_d3_pbe_outputs = make_validate_model_outputs_test( + model_fixture_name="d3_model_pbe", device=DEVICE, dtype=DTYPE +) + +test_d3_r2scan_outputs = make_validate_model_outputs_test( + model_fixture_name="d3_model_r2scan", device=DEVICE, dtype=DTYPE ) diff --git a/torch_sim/models/dispersion.py b/torch_sim/models/dispersion.py index cf56d38c..97230556 100644 --- a/torch_sim/models/dispersion.py +++ b/torch_sim/models/dispersion.py @@ -60,6 +60,10 @@ class D3DispersionModel(ModelInterface): metal units (Angstrom / eV); unit conversion to and from atomic units (Bohr / Hartree) is handled internally. + Functional-dependent BJ damping parameters (``a1``, ``a2``, ``s8``, ``s6``) + can be looked up from the canonical parameter table: + https://github.com/dftd3/simple-dftd3/blob/main/assets/parameters.toml + Args: a1: BJ damping parameter (dimensionless, functional-dependent). a2: BJ damping parameter (in Bohr, functional-dependent). @@ -156,7 +160,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] a2=self.a2, s8=self.s8, s6=self.s6, - d3_params=self.d3_params, # type: ignore[arg-type] + d3_params=self.d3_params, neighbor_list=edge_index_int, neighbor_ptr=neighbor_ptr, cell=cell_bohr, @@ -172,7 +176,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] } if self._compute_stress: # d3_out[3] is only defined if compute_virial is True + # we use [-1] to index it to avoid typing errors. volumes = state.volume.unsqueeze(-1).unsqueeze(-1) - stress = (d3_out[3] * UnitConversion.Hartree_to_eV) / volumes # type: ignore[index] + stress = (d3_out[-1] * UnitConversion.Hartree_to_eV) / volumes results["stress"] = stress.to(self._dtype).detach() return results