diff --git a/tests/models/test_dispersion.py b/tests/models/test_dispersion.py new file mode 100644 index 00000000..4dc5ddc7 --- /dev/null +++ b/tests/models/test_dispersion.py @@ -0,0 +1,72 @@ +"""Tests for the D3DispersionModel wrapper.""" + +import traceback # noqa: I001 + +import pytest +import torch + +from tests.conftest import DEVICE, DTYPE +from tests.models.conftest import make_validate_model_outputs_test + +try: + from nvalchemiops.torch.interactions.dispersion import D3Parameters + + from torch_sim.models.dispersion import D3DispersionModel +except (ImportError, OSError, RuntimeError): + pytest.skip( + f"nvalchemiops not installed: {traceback.format_exc()}", + allow_module_level=True, + ) + + +def _make_d3_params(device: torch.device = DEVICE) -> D3Parameters: + """Build minimal D3 reference parameters for testing (elements up to Fe=26).""" + max_z = 26 + mesh = 5 + return D3Parameters( + rcov=torch.rand(max_z + 1, device=device), + r4r2=torch.rand(max_z + 1, device=device), + c6ab=torch.rand(max_z + 1, max_z + 1, mesh, mesh, device=device), + cn_ref=torch.rand(max_z + 1, max_z + 1, mesh, mesh, device=device), + ) + + +# BJ damping parameters from +# https://github.com/dftd3/simple-dftd3/blob/main/assets/parameters.toml +PBE_BJ = {"a1": 0.4289, "s8": 0.7875, "a2": 4.4407, "s6": 1.0} +R2SCAN_BJ = {"a1": 0.49484001, "s8": 0.78981345, "a2": 5.73083694, "s6": 1.0} + + +@pytest.fixture +def d3_model_pbe() -> D3DispersionModel: + return D3DispersionModel( + **PBE_BJ, + d3_params=_make_d3_params(), + cutoff=12.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +@pytest.fixture +def d3_model_r2scan() -> D3DispersionModel: + return D3DispersionModel( + **R2SCAN_BJ, + d3_params=_make_d3_params(), + cutoff=12.0, + device=DEVICE, + dtype=DTYPE, + compute_forces=True, + compute_stress=True, + ) + + +test_d3_pbe_outputs = make_validate_model_outputs_test( + model_fixture_name="d3_model_pbe", device=DEVICE, dtype=DTYPE +) + +test_d3_r2scan_outputs = make_validate_model_outputs_test( + model_fixture_name="d3_model_r2scan", device=DEVICE, dtype=DTYPE +) diff --git a/tests/models/test_sum_model.py b/tests/models/test_sum_model.py index 63d51201..84d00ccd 100644 --- a/tests/models/test_sum_model.py +++ b/tests/models/test_sum_model.py @@ -5,7 +5,8 @@ import torch_sim as ts from tests.conftest import DEVICE, DTYPE -from torch_sim.models.interface import SumModel, validate_model_outputs +from tests.models.conftest import make_validate_model_outputs_test +from torch_sim.models.interface import SumModel from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.models.morse import MorseModel @@ -23,19 +24,6 @@ def lj_model_a() -> LennardJonesModel: ) -@pytest.fixture -def lj_model_b() -> LennardJonesModel: - return LennardJonesModel( - sigma=2.0, - epsilon=0.005, - cutoff=5.0, - device=DEVICE, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - - @pytest.fixture def morse_model() -> MorseModel: return MorseModel( @@ -55,6 +43,11 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode return SumModel(lj_model_a, morse_model) +test_sum_model_outputs = make_validate_model_outputs_test( + model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE +) + + def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None: with pytest.raises(ValueError, match="at least two"): SumModel(lj_model_a) @@ -75,27 +68,7 @@ def test_sum_model_dtype_mismatch() -> None: SumModel(m1, m2) -def test_sum_model_properties(sum_model: SumModel) -> None: - assert sum_model.device == DEVICE - assert sum_model.dtype == DTYPE - assert sum_model.compute_stress is True - assert sum_model.compute_forces is True - - -def test_sum_model_energy_summation( - lj_model_a: LennardJonesModel, - morse_model: MorseModel, - sum_model: SumModel, - si_sim_state: ts.SimState, -) -> None: - lj_out = lj_model_a(si_sim_state) - morse_out = morse_model(si_sim_state) - sum_out = sum_model(si_sim_state) - expected_energy = lj_out["energy"] + morse_out["energy"] - torch.testing.assert_close(sum_out["energy"], expected_energy) - - -def test_sum_model_forces_summation( +def test_sum_model_additivity( lj_model_a: LennardJonesModel, morse_model: MorseModel, sum_model: SumModel, @@ -104,94 +77,22 @@ def test_sum_model_forces_summation( lj_out = lj_model_a(si_sim_state) morse_out = morse_model(si_sim_state) sum_out = sum_model(si_sim_state) - expected_forces = lj_out["forces"] + morse_out["forces"] - torch.testing.assert_close(sum_out["forces"], expected_forces) - - -def test_sum_model_stress_summation( - lj_model_a: LennardJonesModel, - morse_model: MorseModel, - sum_model: SumModel, - si_sim_state: ts.SimState, -) -> None: - lj_out = lj_model_a(si_sim_state) - morse_out = morse_model(si_sim_state) - sum_out = sum_model(si_sim_state) - expected_stress = lj_out["stress"] + morse_out["stress"] - torch.testing.assert_close(sum_out["stress"], expected_stress) - - -def test_sum_model_batched( - lj_model_a: LennardJonesModel, - morse_model: MorseModel, - sum_model: SumModel, - si_double_sim_state: ts.SimState, -) -> None: - lj_out = lj_model_a(si_double_sim_state) - morse_out = morse_model(si_double_sim_state) - sum_out = sum_model(si_double_sim_state) torch.testing.assert_close(sum_out["energy"], lj_out["energy"] + morse_out["energy"]) torch.testing.assert_close(sum_out["forces"], lj_out["forces"] + morse_out["forces"]) torch.testing.assert_close(sum_out["stress"], lj_out["stress"] + morse_out["stress"]) -def test_sum_model_three_models( - lj_model_a: LennardJonesModel, - lj_model_b: LennardJonesModel, - morse_model: MorseModel, - si_sim_state: ts.SimState, -) -> None: - triple = SumModel(lj_model_a, lj_model_b, morse_model) - a_out = lj_model_a(si_sim_state) - b_out = lj_model_b(si_sim_state) - c_out = morse_model(si_sim_state) - sum_out = triple(si_sim_state) - torch.testing.assert_close( - sum_out["energy"], a_out["energy"] + b_out["energy"] + c_out["energy"] - ) - - -def test_sum_model_compute_stress_setter( +def test_sum_model_setters( lj_model_a: LennardJonesModel, morse_model: MorseModel ) -> None: sm = SumModel(lj_model_a, morse_model) assert sm.compute_stress is True sm.compute_stress = False assert sm.compute_stress is False - - -def test_sum_model_compute_forces_setter( - lj_model_a: LennardJonesModel, morse_model: MorseModel -) -> None: - sm = SumModel(lj_model_a, morse_model) sm.compute_forces = False assert sm.compute_forces is False -def test_sum_model_memory_scales_with( - lj_model_a: LennardJonesModel, morse_model: MorseModel -) -> None: - sm = SumModel(lj_model_a, morse_model) - assert sm.memory_scales_with == "n_atoms_x_density" - - -def test_sum_model_force_conservation( - sum_model: SumModel, si_double_sim_state: ts.SimState -) -> None: - results = sum_model(si_double_sim_state) - for sys_idx in range(si_double_sim_state.n_systems): - mask = si_double_sim_state.system_idx == sys_idx - assert torch.allclose( - results["forces"][mask].sum(dim=0), - torch.zeros(3, dtype=DTYPE), - atol=1e-10, - ) - - -def test_sum_model_validate_outputs(sum_model: SumModel) -> None: - validate_model_outputs(sum_model, DEVICE, DTYPE, check_detached=True) - - def test_sum_model_retain_graph( lj_model_a: LennardJonesModel, morse_model: MorseModel ) -> None: diff --git a/torch_sim/models/dispersion.py b/torch_sim/models/dispersion.py new file mode 100644 index 00000000..97230556 --- /dev/null +++ b/torch_sim/models/dispersion.py @@ -0,0 +1,183 @@ +"""DFT-D3(BJ) dispersion correction model. + +Wraps the ``nvalchemiops`` Warp-accelerated DFT-D3 implementation as a +:class:`~torch_sim.models.interface.ModelInterface`, with full PBC, stress +(virial), and batched system support. + +References: + - Grimme et al., J. Chem. Phys. 132, 154104 (2010). + https://doi.org/10.1063/1.3382344 + - Grimme et al., J. Comput. Chem. 32, 1456-1465 (2011). + https://doi.org/10.1002/jcc.21759 + - nvalchemi-toolkit-ops: https://github.com/NVIDIA/nvalchemi-toolkit-ops +""" + +from __future__ import annotations + +import traceback +import warnings +from typing import TYPE_CHECKING, Any + +import torch + +from torch_sim._duecredit import dcite +from torch_sim.models.interface import ModelInterface +from torch_sim.neighbors import torchsim_nl +from torch_sim.units import UnitConversion + + +try: + from nvalchemiops.torch.interactions.dispersion import D3Parameters + from nvalchemiops.torch.interactions.dispersion import dftd3 as nvalchemiops_dftd3 +except (ImportError, ModuleNotFoundError) as exc: + warnings.warn(f"nvalchemiops import failed: {traceback.format_exc()}", stacklevel=2) + + class D3Parameters: + """Placeholder when nvalchemiops is not installed.""" + + def __init__(self, *_a: Any, _err: Exception = exc, **_kw: Any) -> None: + """Raise the original import error.""" + raise _err + + def nvalchemiops_dftd3(*_a: Any, _err: Exception = exc, **_kw: Any) -> Any: + """Raise the original import error.""" + raise _err + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch_sim.state import SimState + +_FORCE_CONV = UnitConversion.Hartree_to_eV / UnitConversion.Bohr_to_Ang + + +class D3DispersionModel(ModelInterface): + """DFT-D3(BJ) dispersion correction as a :class:`ModelInterface`. + + Computes DFT-D3 energies, forces, and (optionally) stresses via the + ``nvalchemiops`` Warp GPU/CPU kernels. All user-facing quantities are in + metal units (Angstrom / eV); unit conversion to and from atomic units + (Bohr / Hartree) is handled internally. + + Functional-dependent BJ damping parameters (``a1``, ``a2``, ``s8``, ``s6``) + can be looked up from the canonical parameter table: + https://github.com/dftd3/simple-dftd3/blob/main/assets/parameters.toml + + Args: + a1: BJ damping parameter (dimensionless, functional-dependent). + a2: BJ damping parameter (in Bohr, functional-dependent). + s8: C8 scaling factor (dimensionless, functional-dependent). + s6: C6 scaling factor. Defaults to 1.0. + d3_params: Reference D3 parameters (rcov, r4r2, c6ab, cn_ref). + cutoff: Neighbor-list cutoff in **Angstrom**. + device: Compute device. Defaults to CUDA if available, else CPU. + dtype: Floating-point dtype. Defaults to ``torch.float64``. + compute_forces: Whether to return forces. Defaults to True. + compute_stress: Whether to return stress. Defaults to True. + neighbor_list_fn: Neighbor-list constructor. Defaults to ``torchsim_nl``. + + Example:: + + model = D3DispersionModel( + a1=0.4289, + a2=4.4407, + s8=0.7875, + d3_params=params, + cutoff=50.0, + ) + results = model(sim_state) + """ + + @dcite("10.1063/1.3382344") + @dcite("10.1002/jcc.21759") + def __init__( + self, + a1: float, + a2: float, + s8: float, + *, + s6: float = 1.0, + d3_params: D3Parameters | None = None, + cutoff: float = 95.0 * UnitConversion.Bohr_to_Ang, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + compute_forces: bool = True, + compute_stress: bool = True, + neighbor_list_fn: Callable = torchsim_nl, + ) -> None: + """Initialize the D3 dispersion model.""" + super().__init__() + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._memory_scales_with = "n_atoms_x_density" + self.neighbor_list_fn = neighbor_list_fn + self.cutoff = cutoff + self.a1 = a1 + self.a2 = a2 + self.s8 = s8 + self.s6 = s6 + self.d3_params = d3_params + + def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: + """Compute D3 dispersion energy, forces, and stress. + + Args: + state: Simulation state (positions in Angstrom, cell in Angstrom). + **_kwargs: Unused; accepted for interface compatibility. + + Returns: + dict with ``"energy"`` [n_systems], ``"forces"`` [n_atoms, 3], + and (if ``compute_stress``) ``"stress"`` [n_systems, 3, 3]. + """ + edge_index, _mapping_system, unit_shifts = self.neighbor_list_fn( + state.positions, + state.row_vector_cell, + state.pbc, + self.cutoff, + state.system_idx, + ) + n_atoms = state.positions.shape[0] + neighbor_ptr = torch.zeros( + n_atoms + 1, dtype=torch.int32, device=state.positions.device + ) + neighbor_ptr[1:] = ( + torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32) + ) + positions_bohr = state.positions * UnitConversion.Ang_to_Bohr + cell_bohr = state.row_vector_cell.contiguous() * UnitConversion.Ang_to_Bohr + numbers = state.atomic_numbers.to(torch.int32) + unit_shifts_int = unit_shifts.to(torch.int32) + edge_index_int = edge_index.to(torch.int32) + d3_out = nvalchemiops_dftd3( + positions=positions_bohr, + numbers=numbers, + a1=self.a1, + a2=self.a2, + s8=self.s8, + s6=self.s6, + d3_params=self.d3_params, + neighbor_list=edge_index_int, + neighbor_ptr=neighbor_ptr, + cell=cell_bohr, + unit_shifts=unit_shifts_int, + batch_idx=state.system_idx.to(torch.int32), + compute_virial=self._compute_stress, + num_systems=state.n_systems, + ) + + results: dict[str, torch.Tensor] = { + "energy": (d3_out[0] * UnitConversion.Hartree_to_eV).to(self._dtype).detach(), + "forces": (d3_out[1] * _FORCE_CONV).to(self._dtype).detach(), + } + if self._compute_stress: + # d3_out[3] is only defined if compute_virial is True + # we use [-1] to index it to avoid typing errors. + volumes = state.volume.unsqueeze(-1).unsqueeze(-1) + stress = (d3_out[-1] * UnitConversion.Hartree_to_eV) / volumes + results["stress"] = stress.to(self._dtype).detach() + return results diff --git a/torch_sim/units.py b/torch_sim/units.py index 74cffba4..e59dfff6 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -77,6 +77,12 @@ class UnitConversion: kcal_to_cal = 1e3 eV_to_J = bc.e + # Atomic-unit conversions (Bohr / Hartree <-> Angstrom / eV) + Bohr_to_Ang = 0.529177210903 + Ang_to_Bohr = 1.0 / Bohr_to_Ang + Hartree_to_eV = 27.211386245988 + eV_to_Hartree = 1.0 / Hartree_to_eV + uc = UnitConversion