diff --git a/tests/models/test_polarization.py b/tests/models/test_polarization.py
new file mode 100644
index 00000000..2fc71f96
--- /dev/null
+++ b/tests/models/test_polarization.py
@@ -0,0 +1,209 @@
+"""Tests for the polarization electric-field correction model."""
+
+import pytest
+import torch
+
+import torch_sim as ts
+from tests.conftest import DEVICE, DTYPE
+from torch_sim.models.interface import ModelInterface, SerialSumModel
+from torch_sim.models.polarization import UniformPolarizationModel
+
+
+class DummyPolarResponseModel(ModelInterface):
+    def __init__(
+        self,
+        *,
+        include_born_effective_charges: bool = True,
+        include_polarizability: bool = True,
+        include_total_polarization: bool = True,
+        device: torch.device = DEVICE,
+        dtype: torch.dtype = DTYPE,
+    ) -> None:
+        super().__init__()
+        self.include_born_effective_charges = include_born_effective_charges
+        self.include_polarizability = include_polarizability
+        self.include_total_polarization = include_total_polarization
+        self._device = device
+        self._dtype = dtype
+        self._compute_forces = True
+        self._compute_stress = True
+
+    def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
+        del kwargs
+        energy = torch.arange(
+            1, state.n_systems + 1, device=state.device, dtype=state.dtype
+        )
+        forces = (
+            torch.arange(state.n_atoms * 3, device=state.device, dtype=state.dtype)
+            .reshape(state.n_atoms, 3)
+            .div(10.0)
+        )
+        stress = (
+            torch.arange(state.n_systems * 9, device=state.device, dtype=state.dtype)
+            .reshape(state.n_systems, 3, 3)
+            .div(100.0)
+        )
+        polarization = (
+            torch.arange(state.n_systems * 3, device=state.device, dtype=state.dtype)
+            .reshape(state.n_systems, 3)
+            .add(0.5)
+        )
+        output: dict[str, torch.Tensor] = {
+            "energy": energy,
+            "forces": forces,
+            "stress": stress,
+        }
+        if self.include_total_polarization:
+            output["total_polarization"] = polarization
+        if self.include_polarizability:
+            diag = torch.tensor([1.0, 2.0, 3.0], device=state.device, dtype=state.dtype)
+            output["polarizability"] = torch.diag_embed(diag.repeat(state.n_systems, 1))
+        if self.include_born_effective_charges:
+            born_effective_charges = torch.zeros(
+                state.n_atoms, 3, 3, device=state.device, dtype=state.dtype
+            )
+            born_effective_charges[:, 0, 0] = 1.0
+            born_effective_charges[:, 1, 1] = 2.0
+            born_effective_charges[:, 2, 2] = 3.0
+            output["born_effective_charges"] = born_effective_charges
+        return output
+
+
+def test_polarization_model_requires_external_e_field(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    base_model = DummyPolarResponseModel()
+    combined_model = SerialSumModel(
+        base_model,
+        UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
+    )
+
+    with pytest.raises(ValueError, match="external_E_field"):
+        combined_model(si_double_sim_state)
+
+
+def test_polarization_model_applies_linear_response_corrections(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    base_model = DummyPolarResponseModel()
+    combined_model = SerialSumModel(
+        base_model,
+        UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
+    )
+    field = torch.tensor(
+        [[0.2, -0.1, 0.05], [-0.3, 0.4, 0.1]],
+        device=DEVICE,
+        dtype=DTYPE,
+    )
+    state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)
+
+    base_output = base_model(state)
+    combined_output = combined_model(state)
+    expected_polarization = base_output["total_polarization"] + torch.einsum(
+        "sij,sj->si", base_output["polarizability"], field
+    )
+    expected_energy = base_output["energy"] - torch.einsum(
+        "si,si->s", field, base_output["total_polarization"]
+    )
+    expected_energy = expected_energy - 0.5 * torch.einsum(
+        "si,sij,sj->s", field, base_output["polarizability"], field
+    )
+    expected_forces = base_output["forces"] + torch.einsum(
+        "imn,im->in",
+        base_output["born_effective_charges"],
+        field[state.system_idx],
+    )
+
+    torch.testing.assert_close(combined_output["energy"], expected_energy)
+    torch.testing.assert_close(combined_output["forces"], expected_forces)
+    torch.testing.assert_close(
+        combined_output["total_polarization"], expected_polarization
+    )
+    torch.testing.assert_close(combined_output["stress"], base_output["stress"])
+
+
+def test_polarization_model_returns_additive_total_polarization_delta(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    base_model = DummyPolarResponseModel()
+    combined_model = SerialSumModel(
+        base_model,
+        UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
+    )
+    field = torch.tensor([[0.1, 0.0, 0.0], [0.0, -0.2, 0.3]], device=DEVICE, dtype=DTYPE)
+    state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)
+
+    base_output = base_model(state)
+    combined_output = combined_model(state)
+    expected_total_polarization = base_output["total_polarization"] + torch.einsum(
+        "sij,sj->si", base_output["polarizability"], field
+    )
+
+    torch.testing.assert_close(
+        combined_output["total_polarization"], expected_total_polarization
+    )
+    serialized_state = state.clone()
+    serialized_state.store_model_extras(base_output)
+    correction_output = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)(
+        serialized_state
+    )
+    torch.testing.assert_close(
+        correction_output["total_polarization"],
+        expected_total_polarization,
+    )
+
+
+def test_polarization_model_requires_born_effective_charges_for_force_correction(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    base_model = DummyPolarResponseModel(include_born_effective_charges=False)
+    combined_model = SerialSumModel(
+        base_model,
+        UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
+    )
+    state = ts.SimState.from_state(
+        si_double_sim_state,
+        external_E_field=torch.ones(
+            si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
+        ),
+    )
+
+    with pytest.raises(ValueError, match="born_effective_charges"):
+        combined_model(state)
+
+
+def test_polarization_model_requires_total_polarization(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    base_model = DummyPolarResponseModel(include_total_polarization=False)
+    combined_model = SerialSumModel(
+        base_model,
+        UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
+    )
+    state = ts.SimState.from_state(
+        si_double_sim_state,
+        external_E_field=torch.ones(
+            si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
+        ),
+    )
+
+    with pytest.raises(ValueError, match="total_polarization"):
+        combined_model(state)
+
+
+def test_polarization_model_rejects_non_uniform_field_shape(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    state = ts.SimState.from_state(
+        si_double_sim_state,
+        external_E_field=torch.zeros(
+            si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
+        ),
+    )
+    state._system_extras["external_E_field"] = torch.zeros(  # noqa: SLF001
+        state.n_atoms, 3, device=DEVICE, dtype=DTYPE
+    )
+    model = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)
+
+    with pytest.raises(ValueError, match="shape \\(n_systems, 3\\)"):
+        model(state)
diff --git a/tests/models/test_sum_model.py b/tests/models/test_sum_model.py
index 84d00ccd..4fb8d2ca 100644
--- a/tests/models/test_sum_model.py
+++ b/tests/models/test_sum_model.py
@@ -6,11 +6,79 @@
 import torch_sim as ts
 from tests.conftest import DEVICE, DTYPE
 from tests.models.conftest import make_validate_model_outputs_test
-from torch_sim.models.interface import SumModel
+from torch_sim.models.interface import ModelInterface, SerialSumModel, SumModel
 from torch_sim.models.lennard_jones import LennardJonesModel
 from torch_sim.models.morse import MorseModel
 
 
+class ExtraProducerModel(ModelInterface):
+    def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
+        super().__init__()
+        self._device = device
+        self._dtype = dtype
+        self._compute_stress = False
+        self._compute_forces = False
+
+    def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
+        del kwargs
+        latent = state.positions[:, 0] + 2.0
+        return {
+            "energy": torch.ones(state.n_systems, device=state.device, dtype=state.dtype),
+            "latent": latent,
+        }
+
+
+class ExtraConsumerModel(ModelInterface):
+    seen_latent: torch.Tensor | None
+
+    def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
+        super().__init__()
+        self._device = device
+        self._dtype = dtype
+        self._compute_stress = False
+        self._compute_forces = False
+        self.seen_latent = None
+
+    def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
+        del kwargs
+        self.seen_latent = state.latent.clone()
+        energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
+        energy.scatter_add_(0, state.system_idx, state.latent)
+        return {"energy": energy}
+
+
+class OverwriteExtrasModel(ModelInterface):
+    def __init__(
+        self,
+        value: float,
+        device: torch.device = DEVICE,
+        dtype: torch.dtype = DTYPE,
+    ) -> None:
+        super().__init__()
+        self.value = value
+        self._device = device
+        self._dtype = dtype
+        self._compute_stress = False
+        self._compute_forces = False
+
+    def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
+        del kwargs
+        return {
+            "energy": torch.full(
+                (state.n_systems,),
+                self.value,
+                device=state.device,
+                dtype=state.dtype,
+            ),
+            "label": torch.full(
+                (state.n_systems, 3),
+                self.value,
+                device=state.device,
+                dtype=state.dtype,
+            ),
+        }
+
+
 @pytest.fixture
 def lj_model_a() -> LennardJonesModel:
     return LennardJonesModel(
@@ -43,9 +111,19 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode
     return SumModel(lj_model_a, morse_model)
 
 
+@pytest.fixture
+def serial_sum_model(
+    lj_model_a: LennardJonesModel, morse_model: MorseModel
+) -> SerialSumModel:
+    return SerialSumModel(lj_model_a, morse_model)
+
+
 test_sum_model_outputs = make_validate_model_outputs_test(
     model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE
 )
+test_serial_sum_model_outputs = make_validate_model_outputs_test(
+    model_fixture_name="serial_sum_model", device=DEVICE, dtype=DTYPE
+)
 
 
 def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None:
@@ -102,3 +180,66 @@ def test_sum_model_retain_graph(
     assert lj_model_a.retain_graph is True
     assert morse_model.retain_graph is True
     assert sm.retain_graph is True
+
+
+def test_serial_sum_model_matches_parallel_sum_for_independent_models(
+    lj_model_a: LennardJonesModel,
+    morse_model: MorseModel,
+    si_sim_state: ts.SimState,
+) -> None:
+    sum_out = SumModel(lj_model_a, morse_model)(si_sim_state)
+    serial_out = SerialSumModel(lj_model_a, morse_model)(si_sim_state)
+    torch.testing.assert_close(serial_out["energy"], sum_out["energy"])
+    torch.testing.assert_close(serial_out["forces"], sum_out["forces"])
+    torch.testing.assert_close(serial_out["stress"], sum_out["stress"])
+
+
+def test_serial_sum_model_exposes_extras_to_later_models(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    producer = ExtraProducerModel()
+    consumer = ExtraConsumerModel()
+    serial_model = SerialSumModel(producer, consumer)
+    state = si_double_sim_state.clone()
+    expected_latent = state.positions[:, 0] + 2.0
+    expected_energy = torch.ones(state.n_systems, device=state.device, dtype=state.dtype)
+    expected_energy = expected_energy.scatter_add(
+        0,
+        state.system_idx,
+        expected_latent,
+    )
+
+    output = serial_model(state)
+
+    assert consumer.seen_latent is not None
+    torch.testing.assert_close(consumer.seen_latent, expected_latent)
+    torch.testing.assert_close(output["latent"], expected_latent)
+    torch.testing.assert_close(output["energy"], expected_energy)
+    assert not state.has_extras("latent")
+
+
+def test_serial_sum_model_overwrites_noncanonical_outputs(
+    si_double_sim_state: ts.SimState,
+) -> None:
+    model = SerialSumModel(OverwriteExtrasModel(1.0), OverwriteExtrasModel(2.0))
+
+    output = model(si_double_sim_state)
+
+    torch.testing.assert_close(
+        output["energy"],
+        torch.full(
+            (si_double_sim_state.n_systems,),
+            3.0,
+            device=si_double_sim_state.device,
+            dtype=si_double_sim_state.dtype,
+        ),
+    )
+    torch.testing.assert_close(
+        output["label"],
+        torch.full(
+            (si_double_sim_state.n_systems, 3),
+            2.0,
+            device=si_double_sim_state.device,
+            dtype=si_double_sim_state.dtype,
+        ),
+    )
diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py
index 98d28433..6741de34 100644
--- a/torch_sim/models/interface.py
+++ b/torch_sim/models/interface.py
@@ -34,6 +34,7 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs):
 import torch
 
 import torch_sim as ts
+from torch_sim.state import _CANONICAL_MODEL_KEYS
 
 
 if TYPE_CHECKING:
@@ -52,6 +53,21 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs):
 }
 
 
+def _accumulate_model_output(
+    combined: dict[str, torch.Tensor], output: dict[str, torch.Tensor]
+) -> None:
+    """Accumulate one model output into a combined output dict.
+
+    Canonical mechanical outputs are additive. Other outputs are treated as
+    full updated values, so later models replace earlier ones.
+    """
+    for key, tensor in output.items():
+        if key in combined and key in _CANONICAL_MODEL_KEYS:
+            combined[key] = combined[key] + tensor
+        else:
+            combined[key] = tensor
+
+
 class ModelInterface(torch.nn.Module, ABC):
     """Abstract base class for all simulation models in TorchSim.
 
@@ -188,11 +204,12 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
 class SumModel(ModelInterface):
     """Additive composition of multiple :class:`ModelInterface` models.
 
-    Calls each child model's :meth:`forward` and sums the output tensors
-    key-by-key, so energies, forces, and stresses are combined additively.
-    This is the standard way to layer a dispersion correction (e.g. DFT-D3),
-    an Ewald electrostatic term, or a local pair potential on top of a primary
-    machine-learning potential.
+    Calls each child model's :meth:`forward`. Canonical mechanical outputs
+    (energy, forces, stress) are combined additively, while non-canonical
+    outputs are treated as full updated values and later models replace
+    earlier ones. This is the standard way to layer a dispersion correction
+    (e.g. DFT-D3), an Ewald electrostatic term, or a local pair potential on
+    top of a primary machine-learning potential.
 
     Args:
         models: Two or more :class:`ModelInterface` instances that share the
@@ -287,8 +304,9 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
         """Sum the outputs of all child models.
 
         Each child model is called with the same ``state`` and ``**kwargs``.
-        Output tensors that appear in multiple children are summed element-wise;
-        keys unique to a single child are passed through unchanged.
+        Canonical mechanical outputs that appear in multiple children are
+        summed element-wise. Non-canonical outputs are replaced by later
+        models so they behave like full state updates rather than deltas.
 
         Args:
             state: Simulation state (see :class:`ModelInterface`).
@@ -300,11 +318,36 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
         combined: dict[str, torch.Tensor] = {}
         for model in self._children():
             output = model(state, **kwargs)
-            for key, tensor in output.items():
-                if key in combined:
-                    combined[key] = combined[key] + tensor
-                else:
-                    combined[key] = tensor
+            _accumulate_model_output(combined, output)
+        return combined
+
+
+class SerialSumModel(SumModel):
+    """Serial additive composition of multiple :class:`ModelInterface` models.
+
+    Unlike :class:`SumModel`, child models do not all see the same input state.
+    Instead, each child runs after the previous child's non-canonical outputs have
+    been stored into a cloned :class:`~torch_sim.state.SimState` via
+    :meth:`torch_sim.state.SimState.store_model_extras`. This lets earlier models
+    expose per-atom or per-system features that later models can consume.
+    Energies, forces, and stresses remain additive, while repeated auxiliary
+    outputs are treated as full updated values from the latest stage.
+
+    Examples:
+        ```py
+        serial_model = SerialSumModel(polarization_model, dispersion_model)
+        output = serial_model(sim_state)
+        ```
+    """
+
+    def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
+        """Run child models serially, exposing extras from earlier models."""
+        combined: dict[str, torch.Tensor] = {}
+        serial_state = state.clone()
+        for model in self._children():
+            output = model(serial_state, **kwargs)
+            _accumulate_model_output(combined, output)
+            serial_state.store_model_extras(output)
         return combined
 
 
diff --git a/torch_sim/models/polarization.py b/torch_sim/models/polarization.py
new file mode 100644
index 00000000..8b86086b
--- /dev/null
+++ b/torch_sim/models/polarization.py
@@ -0,0 +1,194 @@
+"""Electric-field corrections for polarization-aware models."""
+
+import torch
+
+from torch_sim.models.interface import ModelInterface
+from torch_sim.state import SimState
+from torch_sim.typing import AtomExtras, SystemExtras
+
+
+class UniformPolarizationModel(ModelInterface):
+    """Linear-response energy/force correction for a uniform external electric field.
+
+    Given a per-system field ``E``, returns ``-E·P0 - 1/2 E·alpha·E`` as an energy
+    term, ``Z*·E`` as a force term (when ``compute_forces`` is enabled), an
+    all-zero stress (when ``compute_stress`` is enabled), and the field-updated
+    ``total_polarization = P0 + alpha·E``.
+
+    This model is intended to run after an upstream model inside
+    :class:`~torch_sim.models.interface.SerialSumModel`.
+
+    Required state extras:
+
+    * ``external_E_field`` with shape ``(n_systems, 3)``; always required
+    * ``total_polarization`` when the field has any non-zero entry
+    * ``polarizability`` when the field has any non-zero entry
+    * ``born_effective_charges`` when the field has any non-zero entry and
+      ``compute_forces`` is enabled
+    """
+
+    def __init__(
+        self,
+        device: torch.device | None = None,
+        dtype: torch.dtype = torch.float64,
+        *,
+        compute_forces: bool = True,
+        compute_stress: bool = True,
+        retain_graph: bool = False,
+    ) -> None:
+        """Initialize a uniform-field polarization correction model.
+
+        Args:
+            device: Device stored on the model; falls back to CPU when ``None``.
+            dtype: Dtype stored on the model.
+            compute_forces: Whether to return force corrections.
+            compute_stress: Whether to return the (all-zero) stress tensor.
+            retain_graph: Whether outputs stay attached to the autograd graph.
+        """
+        super().__init__()
+        self._device = device or torch.device("cpu")
+        self._dtype = dtype
+        self._compute_forces = compute_forces
+        self._compute_stress = compute_stress
+        self._retain_graph = retain_graph
+        self._memory_scales_with = "n_atoms"
+
+    @ModelInterface.compute_stress.setter
+    def compute_stress(self, value: bool) -> None:  # noqa: FBT001
+        """Set whether the model returns an additive stress tensor."""
+        self._compute_stress = value
+
+    @ModelInterface.compute_forces.setter
+    def compute_forces(self, value: bool) -> None:  # noqa: FBT001
+        """Set whether the model returns additive force corrections."""
+        self._compute_forces = value
+
+    @property
+    def retain_graph(self) -> bool:
+        """Whether outputs should remain attached to the autograd graph."""
+        return self._retain_graph
+
+    @retain_graph.setter
+    def retain_graph(self, value: bool) -> None:
+        """Set whether outputs should remain attached to the autograd graph."""
+        self._retain_graph = value
+
+    def _finalize_output(
+        self, output: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
+        """Detach outputs unless graph retention is enabled."""
+        if self.retain_graph:
+            return output
+        return {
+            key: val.detach() if isinstance(val, torch.Tensor) else val
+            for key, val in output.items()
+        }
+
+    def _apply_nonzero_field(
+        self,
+        state: SimState,
+        output: dict[str, torch.Tensor],
+        field: torch.Tensor,
+    ) -> None:
+        """Apply constant-field linear-response corrections in place on ``output``.
+
+        Overwrites ``output["energy"]`` and, when ``compute_forces`` is enabled,
+        ``output["forces"]`` with additive deltas, and stores the full updated
+        ``total_polarization`` (a replacement value, not a delta):
+        - delta_energy = -E·P0 - 1/2 E·alpha·E
+        - total_polarization = P0 + alpha·E
+        - delta_forces = Z*·E
+
+        Args:
+            state: State providing the ``total_polarization``, ``polarizability``
+                and (for forces) ``born_effective_charges`` extras.
+            output: Output dict to update in place.
+            field: External electric field, one 3-vector per system.
+
+        Raises:
+            ValueError: If any required extra is missing from ``state``.
+        """
+        required_keys = [
+            SystemExtras.TOTAL_POLARIZATION.value,
+            SystemExtras.POLARIZABILITY.value,
+        ]
+        if self.compute_forces:
+            required_keys.append(AtomExtras.BORN_EFFECTIVE_CHARGES.value)
+
+        missing_keys = [key for key in required_keys if not state.has_extras(key)]
+        if missing_keys:
+            missing = ", ".join(f"'{key}'" for key in missing_keys)
+            raise ValueError(
+                f"UniformPolarizationModel requires {missing} on the state "
+                "when external_E_field is non-zero"
+            )
+
+        dipole_coupling = torch.einsum("si,si->s", field, state.total_polarization)
+        polarization_response = torch.einsum(
+            "si,sij,sj->s", field, state.polarizability, field
+        )
+        output["energy"] = -dipole_coupling - 0.5 * polarization_response
+        output[SystemExtras.TOTAL_POLARIZATION.value] = (
+            torch.einsum(
+                "sij,sj->si",
+                state.polarizability,
+                field,
+            )
+            + state.total_polarization
+        )
+        if self.compute_forces:
+            output["forces"] = torch.einsum(
+                "imn,im->in",
+                state.born_effective_charges,
+                field[state.system_idx],
+            )
+
+    def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
+        """Return additive uniform-field corrections for a polarization model.
+
+        Args:
+            state: State carrying ``external_E_field`` with shape ``(n_systems, 3)``.
+            **kwargs: Ignored.
+
+        Returns:
+            Dict with ``energy``, plus ``forces`` and ``stress`` when enabled, and
+            (for a non-zero field) the updated ``total_polarization``. ``stress``
+            is always zero; ``energy`` and ``forces`` are zero for a zero field.
+            Tensors are detached unless ``retain_graph`` is set.
+
+        Raises:
+            ValueError: If ``external_E_field`` is missing or has the wrong shape,
+                or if a non-zero field is given without the required extras.
+        """
+        del kwargs
+        output: dict[str, torch.Tensor] = {
+            "energy": torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
+        }
+        if self.compute_forces:
+            output["forces"] = torch.zeros_like(state.positions)
+        if self.compute_stress:
+            # V1 intentionally applies no field-induced stress correction.
+            output["stress"] = torch.zeros(
+                state.n_systems, 3, 3, device=state.device, dtype=state.dtype
+            )
+
+        if not state.has_extras(SystemExtras.EXTERNAL_E_FIELD.value):
+            raise ValueError(
+                "UniformPolarizationModel requires 'external_E_field' on the state"
+            )
+
+        field = getattr(state, SystemExtras.EXTERNAL_E_FIELD.value)
+        if field.shape != (state.n_systems, 3):
+            raise ValueError(
+                "UniformPolarizationModel requires external_E_field to have shape "
+                "(n_systems, 3)"
+            )
+        if not torch.any(field != 0):
+            # Zero field: every correction vanishes, so return the zero-filled output.
+            return self._finalize_output(output)
+
+        self._apply_nonzero_field(state, output, field)
+        return self._finalize_output(output)
+
+
+__all__ = ["UniformPolarizationModel"]