From bcc58c25259c564e2188b125d97e95bbc0589f07 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 8 Apr 2026 16:30:35 -0400 Subject: [PATCH 1/3] fea: add polarization model --- tests/models/test_polarization.py | 194 ++++++++++++++++++++++++++++ tests/models/test_sum_model.py | 84 +++++++++++- torch_sim/models/interface.py | 46 ++++++- torch_sim/models/polarization.py | 204 ++++++++++++++++++++++++++++++ 4 files changed, 522 insertions(+), 6 deletions(-) create mode 100644 tests/models/test_polarization.py create mode 100644 torch_sim/models/polarization.py diff --git a/tests/models/test_polarization.py b/tests/models/test_polarization.py new file mode 100644 index 00000000..9f1333e4 --- /dev/null +++ b/tests/models/test_polarization.py @@ -0,0 +1,194 @@ +"""Tests for the polarization electric-field correction model.""" + +import pytest +import torch + +import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from torch_sim.models.interface import ModelInterface, SerialSumModel +from torch_sim.models.polarization import UniformPolarizationModel + + +class DummyPolarResponseModel(ModelInterface): + def __init__( + self, + *, + polarization_key: str = "polarization", + include_born_effective_charges: bool = True, + include_polarizability: bool = True, + device: torch.device = DEVICE, + dtype: torch.dtype = DTYPE, + ) -> None: + super().__init__() + self.polarization_key = polarization_key + self.include_born_effective_charges = include_born_effective_charges + self.include_polarizability = include_polarizability + self._device = device + self._dtype = dtype + self._compute_forces = True + self._compute_stress = True + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + energy = torch.arange( + 1, state.n_systems + 1, device=state.device, dtype=state.dtype + ) + forces = ( + torch.arange(state.n_atoms * 3, device=state.device, dtype=state.dtype) + .reshape(state.n_atoms, 3) + .div(10.0) + ) + stress = ( + torch.arange(state.n_systems * 9, device=state.device, dtype=state.dtype) + .reshape(state.n_systems, 3, 3) + .div(100.0) + ) + polarization = ( + torch.arange(state.n_systems * 3, device=state.device, dtype=state.dtype) + .reshape(state.n_systems, 3) + .add(0.5) + ) + output: dict[str, torch.Tensor] = { + "energy": energy, + "forces": forces, + "stress": stress, + self.polarization_key: polarization, + } + if self.include_polarizability: + diag = torch.tensor([1.0, 2.0, 3.0], device=state.device, dtype=state.dtype) + output["polarizability"] = torch.diag_embed(diag.repeat(state.n_systems, 1)) + if self.include_born_effective_charges: + born_effective_charges = torch.zeros( + state.n_atoms, 3, 3, device=state.device, dtype=state.dtype + ) + born_effective_charges[:, 0, 0] = 1.0 + born_effective_charges[:, 1, 1] = 2.0 + born_effective_charges[:, 2, 2] = 3.0 + output["born_effective_charges"] = born_effective_charges + return output + + +def test_polarization_model_normalizes_raw_key_without_field( + si_double_sim_state: ts.SimState, +) -> None: + base_model = DummyPolarResponseModel() + combined_model = SerialSumModel( + base_model, + UniformPolarizationModel(device=DEVICE, dtype=DTYPE), + ) + + base_output = base_model(si_double_sim_state) + combined_output = combined_model(si_double_sim_state) + + torch.testing.assert_close(combined_output["energy"], base_output["energy"]) + torch.testing.assert_close(combined_output["forces"], base_output["forces"]) + torch.testing.assert_close(combined_output["stress"], base_output["stress"]) + torch.testing.assert_close( + combined_output["total_polarization"], base_output["polarization"] + ) + torch.testing.assert_close( + combined_output["polarization"], base_output["polarization"] + ) + + +def test_polarization_model_applies_linear_response_corrections( + si_double_sim_state: ts.SimState, +) -> None: + base_model = DummyPolarResponseModel() + combined_model = SerialSumModel( + base_model, + UniformPolarizationModel(device=DEVICE, dtype=DTYPE), + ) + field = torch.tensor( + [[0.2, -0.1, 0.05], [-0.3, 0.4, 0.1]], + device=DEVICE, + dtype=DTYPE, + ) + state = ts.SimState.from_state(si_double_sim_state, external_E_field=field) + + base_output = base_model(state) + combined_output = combined_model(state) + expected_polarization = base_output["polarization"] + torch.einsum( + "sij,sj->si", base_output["polarizability"], field + ) + expected_energy = base_output["energy"] - torch.einsum( + "si,si->s", field, base_output["polarization"] + ) + expected_energy = expected_energy - 0.5 * torch.einsum( + "si,sij,sj->s", field, base_output["polarizability"], field + ) + expected_forces = base_output["forces"] + torch.einsum( + "imn,im->in", + base_output["born_effective_charges"], + field[state.system_idx], + ) + + torch.testing.assert_close(combined_output["energy"], expected_energy) + torch.testing.assert_close(combined_output["forces"], expected_forces) + torch.testing.assert_close( + combined_output["total_polarization"], expected_polarization + ) + torch.testing.assert_close( + combined_output["polarization"], base_output["polarization"] + ) + torch.testing.assert_close(combined_output["stress"], base_output["stress"]) + + +def test_polarization_model_adds_only_delta_for_blessed_name( + si_double_sim_state: ts.SimState, +) -> None: + base_model = DummyPolarResponseModel(polarization_key="total_polarization") + combined_model = SerialSumModel( + base_model, + UniformPolarizationModel(device=DEVICE, dtype=DTYPE), + ) + field = torch.tensor([[0.1, 0.0, 0.0], [0.0, -0.2, 0.3]], device=DEVICE, dtype=DTYPE) + state = ts.SimState.from_state(si_double_sim_state, external_E_field=field) + + base_output = base_model(state) + combined_output = combined_model(state) + expected_total_polarization = base_output["total_polarization"] + torch.einsum( + "sij,sj->si", base_output["polarizability"], field + ) + + torch.testing.assert_close( + combined_output["total_polarization"], expected_total_polarization + ) + assert "polarization" not in combined_output + + +def test_polarization_model_requires_born_effective_charges_for_force_correction( + si_double_sim_state: ts.SimState, +) -> None: + base_model = DummyPolarResponseModel(include_born_effective_charges=False) + combined_model = SerialSumModel( + base_model, + UniformPolarizationModel(device=DEVICE, dtype=DTYPE), + ) + state = ts.SimState.from_state( + si_double_sim_state, + external_E_field=torch.ones( + si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE + ), + ) + + with pytest.raises(ValueError, match="born_effective_charges"): + combined_model(state) + + +def test_polarization_model_rejects_non_uniform_field_shape( + si_double_sim_state: ts.SimState, +) -> None: + state = ts.SimState.from_state( + si_double_sim_state, + external_E_field=torch.zeros( + si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE + ), + ) + state._system_extras["external_E_field"] = torch.zeros( # noqa: SLF001 + state.n_atoms, 3, device=DEVICE, dtype=DTYPE + ) + model = UniformPolarizationModel(device=DEVICE, dtype=DTYPE) + + with pytest.raises(ValueError, match="shape \\(n_systems, 3\\)"): + model(state) diff --git a/tests/models/test_sum_model.py b/tests/models/test_sum_model.py index 84d00ccd..ae575ad4 100644 --- a/tests/models/test_sum_model.py +++ b/tests/models/test_sum_model.py @@ -6,11 +6,47 @@ import torch_sim as ts from tests.conftest import DEVICE, DTYPE from tests.models.conftest import make_validate_model_outputs_test -from torch_sim.models.interface import SumModel +from torch_sim.models.interface import ModelInterface, SerialSumModel, SumModel from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.models.morse import MorseModel +class ExtraProducerModel(ModelInterface): + def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None: + super().__init__() + self._device = device + self._dtype = dtype + self._compute_stress = False + self._compute_forces = False + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + latent = state.positions[:, 0] + 2.0 + return { + "energy": torch.ones(state.n_systems, device=state.device, dtype=state.dtype), + "latent": latent, + } + + +class ExtraConsumerModel(ModelInterface): + seen_latent: torch.Tensor | None + + def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None: + super().__init__() + self._device = device + self._dtype = dtype + self._compute_stress = False + self._compute_forces = False + self.seen_latent = None + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + self.seen_latent = state.latent.clone() + energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + energy.scatter_add_(0, state.system_idx, state.latent) + return {"energy": energy} + + @pytest.fixture def lj_model_a() -> LennardJonesModel: return LennardJonesModel( @@ -43,9 +79,19 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode return SumModel(lj_model_a, morse_model) +@pytest.fixture +def serial_sum_model( + lj_model_a: LennardJonesModel, morse_model: MorseModel +) -> SerialSumModel: + return SerialSumModel(lj_model_a, morse_model) + + test_sum_model_outputs = make_validate_model_outputs_test( model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE ) +test_serial_sum_model_outputs = make_validate_model_outputs_test( + model_fixture_name="serial_sum_model", device=DEVICE, dtype=DTYPE +) def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None: @@ -102,3 +148,39 @@ def test_sum_model_retain_graph( assert lj_model_a.retain_graph is True assert morse_model.retain_graph is True assert sm.retain_graph is True + + +def test_serial_sum_model_matches_parallel_sum_for_independent_models( + lj_model_a: LennardJonesModel, + morse_model: MorseModel, + si_sim_state: ts.SimState, +) -> None: + sum_out = SumModel(lj_model_a, morse_model)(si_sim_state) + serial_out = SerialSumModel(lj_model_a, morse_model)(si_sim_state) + torch.testing.assert_close(serial_out["energy"], sum_out["energy"]) + torch.testing.assert_close(serial_out["forces"], sum_out["forces"]) + torch.testing.assert_close(serial_out["stress"], sum_out["stress"]) + + +def test_serial_sum_model_exposes_extras_to_later_models( + si_double_sim_state: ts.SimState, +) -> None: + producer = ExtraProducerModel() + consumer = ExtraConsumerModel() + serial_model = SerialSumModel(producer, consumer) + state = si_double_sim_state.clone() + expected_latent = state.positions[:, 0] + 2.0 + expected_energy = torch.ones(state.n_systems, device=state.device, dtype=state.dtype) + expected_energy = expected_energy.scatter_add( + 0, + state.system_idx, + expected_latent, + ) + + output = serial_model(state) + + assert consumer.seen_latent is not None + torch.testing.assert_close(consumer.seen_latent, expected_latent) + torch.testing.assert_close(output["latent"], expected_latent) + torch.testing.assert_close(output["energy"], expected_energy) + assert not state.has_extras("latent") diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 98d28433..999152c8 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -52,6 +52,17 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): } +def _accumulate_model_output( + combined: dict[str, torch.Tensor], output: dict[str, torch.Tensor] +) -> None: + """Accumulate one model output into a combined output dict.""" + for key, tensor in output.items(): + if key in combined: + combined[key] = combined[key] + tensor + else: + combined[key] = tensor + + class ModelInterface(torch.nn.Module, ABC): """Abstract base class for all simulation models in TorchSim. @@ -300,11 +311,36 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: combined: dict[str, torch.Tensor] = {} for model in self._children(): output = model(state, **kwargs) - for key, tensor in output.items(): - if key in combined: - combined[key] = combined[key] + tensor - else: - combined[key] = tensor + _accumulate_model_output(combined, output) + return combined + + +class SerialSumModel(SumModel): + """Serial additive composition of multiple :class:`ModelInterface` models. + + Unlike :class:`SumModel`, child models do not all see the same input state. + Instead, each child runs after the previous child's non-canonical outputs have + been stored into a cloned :class:`~torch_sim.state.SimState` via + :meth:`torch_sim.state.SimState.store_model_extras`. This lets earlier models + expose per-atom or per-system features that later models can consume, while + energies, forces, stresses, and any repeated auxiliary outputs are still summed + key-by-key. + + Examples: + ```py + serial_model = SerialSumModel(polarization_model, dispersion_model) + output = serial_model(sim_state) + ``` + """ + + def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: + """Run child models serially, exposing extras from earlier models.""" + combined: dict[str, torch.Tensor] = {} + serial_state = state.clone() + for model in self._children(): + output = model(serial_state, **kwargs) + _accumulate_model_output(combined, output) + serial_state.store_model_extras(output) return combined diff --git a/torch_sim/models/polarization.py b/torch_sim/models/polarization.py new file mode 100644 index 00000000..caa22392 --- /dev/null +++ b/torch_sim/models/polarization.py @@ -0,0 +1,204 @@ +"""Electric-field corrections for polarization-aware models.""" + +from typing import Protocol + +import torch + +from torch_sim.models.interface import ModelInterface +from torch_sim.typing import AtomExtras, SystemExtras + + +RAW_POLARIZATION_KEY = "polarization" + + +class _PolarizationState(Protocol): + n_systems: int + device: torch.device + dtype: torch.dtype + positions: torch.Tensor + system_idx: torch.Tensor + + +class UniformPolarizationModel(ModelInterface): + """Apply constant electric-field corrections from polarization response tensors. + + This model is designed to run after an upstream model inside a + :class:`~torch_sim.models.interface.SerialSumModel`. The upstream model is + expected to populate zero-field response tensors on the state via + ``store_model_extras()``, such as ``polarization``/``total_polarization``, + ``polarizability``, and ``born_effective_charges``. + + The returned canonical outputs are additive corrections: + + * ``energy``: ``-E·P0 - 0.5 E·alpha·E`` + * ``forces``: ``e Z*·E`` + * ``stress``: zero tensor, so upstream stress is preserved unchanged + + The model also emits ``total_polarization`` using TorchSim's blessed naming. + When the upstream model already emitted ``total_polarization``, this model + returns only the field-induced increment so that additive composition yields + the final corrected polarization. When the upstream model only emitted the + legacy ``polarization`` key, this model returns the full corrected + ``total_polarization`` tensor. + """ + + def __init__( + self, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + *, + compute_forces: bool = True, + compute_stress: bool = True, + retain_graph: bool = False, + ) -> None: + """Initialize a uniform-field polarization correction model.""" + super().__init__() + self._device = device or torch.device("cpu") + self._dtype = dtype + self._compute_forces = compute_forces + self._compute_stress = compute_stress + self._retain_graph = retain_graph + self._memory_scales_with = "n_atoms" + + @ModelInterface.compute_stress.setter + def compute_stress(self, value: bool) -> None: # noqa: FBT001 + """Set whether the model returns an additive stress tensor.""" + self._compute_stress = value + + @ModelInterface.compute_forces.setter + def compute_forces(self, value: bool) -> None: # noqa: FBT001 + """Set whether the model returns additive force corrections.""" + self._compute_forces = value + + @property + def retain_graph(self) -> bool: + """Whether outputs should remain attached to the autograd graph.""" + return self._retain_graph + + @retain_graph.setter + def retain_graph(self, value: bool) -> None: + """Set whether outputs should remain attached to the autograd graph.""" + self._retain_graph = value + + @staticmethod + def _polarization_from_state( + state: _PolarizationState, + ) -> tuple[torch.Tensor | None, bool]: + """Read polarization from blessed or legacy state extras.""" + total_polarization = getattr(state, SystemExtras.TOTAL_POLARIZATION.value, None) + if total_polarization is not None: + return total_polarization, True + raw_polarization = getattr(state, RAW_POLARIZATION_KEY, None) + if raw_polarization is not None: + return raw_polarization, False + return None, False + + def _build_output(self, state: _PolarizationState) -> dict[str, torch.Tensor]: + """Allocate additive outputs for energy and optional corrections.""" + output: dict[str, torch.Tensor] = { + "energy": torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + } + if self.compute_forces: + output["forces"] = torch.zeros_like(state.positions) + if self.compute_stress: + # V1 intentionally applies no field-induced stress correction. + output["stress"] = torch.zeros( + state.n_systems, 3, 3, device=state.device, dtype=state.dtype + ) + return output + + def _apply_field_corrections( + self, + state: _PolarizationState, + output: dict[str, torch.Tensor], + field: torch.Tensor, + base_polarization: torch.Tensor | None, + *, + using_blessed_name: bool, + ) -> None: + """Apply nonzero uniform-field energy, force, and polarization updates.""" + if base_polarization is None: + required_key = ( + SystemExtras.TOTAL_POLARIZATION.value + if using_blessed_name + else RAW_POLARIZATION_KEY + ) + raise ValueError( + f"UniformPolarizationModel requires '{required_key}' on the state " + "when external_E_field is non-zero" + ) + polarizability = getattr(state, SystemExtras.POLARIZABILITY.value, None) + if polarizability is None: + raise ValueError( + "UniformPolarizationModel requires 'polarizability' on the state " + "when external_E_field is non-zero" + ) + output["energy"] = output["energy"] - torch.einsum( + "si,si->s", field, base_polarization + ) + output["energy"] = output["energy"] - 0.5 * torch.einsum( + "si,sij,sj->s", field, polarizability, field + ) + polarization_delta = torch.einsum("sij,sj->si", polarizability, field) + if using_blessed_name: + output[SystemExtras.TOTAL_POLARIZATION.value] = polarization_delta + else: + output[SystemExtras.TOTAL_POLARIZATION.value] = ( + base_polarization + polarization_delta + ) + if not self.compute_forces: + return + born_effective_charges = getattr( + state, AtomExtras.BORN_EFFECTIVE_CHARGES.value, None + ) + if born_effective_charges is None: + raise ValueError( + "UniformPolarizationModel requires 'born_effective_charges' on " + "the state when external_E_field is non-zero" + ) + output["forces"] = output["forces"] + torch.einsum( + "imn,im->in", born_effective_charges, field[state.system_idx] + ) + + def forward( + self, + state: _PolarizationState, + **kwargs, + ) -> dict[str, torch.Tensor]: + """Return additive uniform-field corrections for a polarization model.""" + del kwargs + output = self._build_output(state) + + field = getattr(state, SystemExtras.EXTERNAL_E_FIELD.value, None) + base_polarization, using_blessed_name = self._polarization_from_state(state) + if field is None: + if base_polarization is not None and not using_blessed_name: + output[SystemExtras.TOTAL_POLARIZATION.value] = base_polarization + else: + if field.shape != (state.n_systems, 3): + raise ValueError( + "UniformPolarizationModel requires external_E_field to have shape " + "(n_systems, 3)" + ) + if base_polarization is not None and not using_blessed_name: + output[SystemExtras.TOTAL_POLARIZATION.value] = base_polarization + if torch.any(field != 0): + self._apply_field_corrections( + state, + output, + field, + base_polarization, + using_blessed_name=using_blessed_name, + ) + + if not self.retain_graph: + output = { + key: val.detach() if isinstance(val, torch.Tensor) else val + for key, val in output.items() + } + return output + + +# TODO: Add a spatially-varying polarization model. + +__all__ = ["UniformPolarizationModel"] From d1dbf2eed6af297e382060283fd32618ef64651a Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 8 Apr 2026 17:11:49 -0400 Subject: [PATCH 2/3] wip: the interface doesn't super clean potentially needs more work. --- tests/models/test_polarization.py | 59 ++++---- tests/models/test_sum_model.py | 59 ++++++++ torch_sim/models/interface.py | 32 +++-- torch_sim/models/polarization.py | 220 ++++++++++++------------------ 4 files changed, 199 insertions(+), 171 deletions(-) diff --git a/tests/models/test_polarization.py b/tests/models/test_polarization.py index 9f1333e4..ee1e9991 100644 --- a/tests/models/test_polarization.py +++ b/tests/models/test_polarization.py @@ -13,16 +13,16 @@ class DummyPolarResponseModel(ModelInterface): def __init__( self, *, - polarization_key: str = "polarization", include_born_effective_charges: bool = True, include_polarizability: bool = True, + include_total_polarization: bool = True, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE, ) -> None: super().__init__() - self.polarization_key = polarization_key self.include_born_effective_charges = include_born_effective_charges self.include_polarizability = include_polarizability + self.include_total_polarization = include_total_polarization self._device = device self._dtype = dtype self._compute_forces = True @@ -52,8 +52,9 @@ def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tenso "energy": energy, "forces": forces, "stress": stress, - self.polarization_key: polarization, } + if self.include_total_polarization: + output["total_polarization"] = polarization if self.include_polarizability: diag = torch.tensor([1.0, 2.0, 3.0], device=state.device, dtype=state.dtype) output["polarizability"] = torch.diag_embed(diag.repeat(state.n_systems, 1)) @@ -68,7 +69,7 @@ def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tenso return output -def test_polarization_model_normalizes_raw_key_without_field( +def test_polarization_model_requires_external_e_field( si_double_sim_state: ts.SimState, ) -> None: base_model = DummyPolarResponseModel() @@ -77,18 +78,8 @@ def test_polarization_model_normalizes_raw_key_without_field( UniformPolarizationModel(device=DEVICE, dtype=DTYPE), ) - base_output = base_model(si_double_sim_state) - combined_output = combined_model(si_double_sim_state) - - torch.testing.assert_close(combined_output["energy"], base_output["energy"]) - torch.testing.assert_close(combined_output["forces"], base_output["forces"]) - torch.testing.assert_close(combined_output["stress"], base_output["stress"]) - torch.testing.assert_close( - combined_output["total_polarization"], base_output["polarization"] - ) - torch.testing.assert_close( - combined_output["polarization"], base_output["polarization"] - ) + with pytest.raises(ValueError, match="external_E_field"): + combined_model(si_double_sim_state) def test_polarization_model_applies_linear_response_corrections( @@ -108,11 +99,11 @@ def test_polarization_model_applies_linear_response_corrections( base_output = base_model(state) combined_output = combined_model(state) - expected_polarization = base_output["polarization"] + torch.einsum( + expected_polarization = base_output["total_polarization"] + torch.einsum( "sij,sj->si", base_output["polarizability"], field ) expected_energy = base_output["energy"] - torch.einsum( - "si,si->s", field, base_output["polarization"] + "si,si->s", field, base_output["total_polarization"] ) expected_energy = expected_energy - 0.5 * torch.einsum( "si,sij,sj->s", field, base_output["polarizability"], field @@ -128,16 +119,13 @@ def test_polarization_model_applies_linear_response_corrections( torch.testing.assert_close( combined_output["total_polarization"], expected_polarization ) - torch.testing.assert_close( - combined_output["polarization"], base_output["polarization"] - ) torch.testing.assert_close(combined_output["stress"], base_output["stress"]) -def test_polarization_model_adds_only_delta_for_blessed_name( +def test_polarization_model_returns_additive_total_polarization_delta( si_double_sim_state: ts.SimState, ) -> None: - base_model = DummyPolarResponseModel(polarization_key="total_polarization") + base_model = DummyPolarResponseModel() combined_model = SerialSumModel( base_model, UniformPolarizationModel(device=DEVICE, dtype=DTYPE), @@ -154,7 +142,11 @@ def test_polarization_model_adds_only_delta_for_blessed_name( torch.testing.assert_close( combined_output["total_polarization"], expected_total_polarization ) - assert "polarization" not in combined_output + correction_output = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)(state) + torch.testing.assert_close( + correction_output["total_polarization"], + expected_total_polarization, + ) def test_polarization_model_requires_born_effective_charges_for_force_correction( @@ -176,6 +168,25 @@ def test_polarization_model_requires_born_effective_charges_for_force_correction combined_model(state) +def test_polarization_model_requires_total_polarization( + si_double_sim_state: ts.SimState, +) -> None: + base_model = DummyPolarResponseModel(include_total_polarization=False) + combined_model = SerialSumModel( + base_model, + UniformPolarizationModel(device=DEVICE, dtype=DTYPE), + ) + state = ts.SimState.from_state( + si_double_sim_state, + external_E_field=torch.ones( + si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE + ), + ) + + with pytest.raises(ValueError, match="total_polarization"): + combined_model(state) + + def test_polarization_model_rejects_non_uniform_field_shape( si_double_sim_state: ts.SimState, ) -> None: diff --git a/tests/models/test_sum_model.py b/tests/models/test_sum_model.py index ae575ad4..4fb8d2ca 100644 --- a/tests/models/test_sum_model.py +++ b/tests/models/test_sum_model.py @@ -47,6 +47,38 @@ def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tenso return {"energy": energy} +class OverwriteExtrasModel(ModelInterface): + def __init__( + self, + value: float, + device: torch.device = DEVICE, + dtype: torch.dtype = DTYPE, + ) -> None: + super().__init__() + self.value = value + self._device = device + self._dtype = dtype + self._compute_stress = False + self._compute_forces = False + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + return { + "energy": torch.full( + (state.n_systems,), + self.value, + device=state.device, + dtype=state.dtype, + ), + "label": torch.full( + (state.n_systems, 3), + self.value, + device=state.device, + dtype=state.dtype, + ), + } + + @pytest.fixture def lj_model_a() -> LennardJonesModel: return LennardJonesModel( @@ -184,3 +216,30 @@ def test_serial_sum_model_exposes_extras_to_later_models( torch.testing.assert_close(output["latent"], expected_latent) torch.testing.assert_close(output["energy"], expected_energy) assert not state.has_extras("latent") + + +def test_serial_sum_model_overwrites_noncanonical_outputs( + si_double_sim_state: ts.SimState, +) -> None: + model = SerialSumModel(OverwriteExtrasModel(1.0), OverwriteExtrasModel(2.0)) + + output = model(si_double_sim_state) + + torch.testing.assert_close( + output["energy"], + torch.full( + (si_double_sim_state.n_systems,), + 3.0, + device=si_double_sim_state.device, + dtype=si_double_sim_state.dtype, + ), + ) + torch.testing.assert_close( + output["label"], + torch.full( + (si_double_sim_state.n_systems, 3), + 2.0, + device=si_double_sim_state.device, + dtype=si_double_sim_state.dtype, + ), + ) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 999152c8..7004216a 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -51,13 +51,19 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): "n_edges": 2, } +_ADDITIVE_MODEL_KEYS = frozenset({"energy", "forces", "stress"}) + def _accumulate_model_output( combined: dict[str, torch.Tensor], output: dict[str, torch.Tensor] ) -> None: - """Accumulate one model output into a combined output dict.""" + """Accumulate one model output into a combined output dict. + + Canonical mechanical outputs are additive. Other outputs are treated as + full updated values, so later models replace earlier ones. + """ for key, tensor in output.items(): - if key in combined: + if key in combined and key in _ADDITIVE_MODEL_KEYS: combined[key] = combined[key] + tensor else: combined[key] = tensor @@ -199,11 +205,12 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: class SumModel(ModelInterface): """Additive composition of multiple :class:`ModelInterface` models. - Calls each child model's :meth:`forward` and sums the output tensors - key-by-key, so energies, forces, and stresses are combined additively. - This is the standard way to layer a dispersion correction (e.g. DFT-D3), - an Ewald electrostatic term, or a local pair potential on top of a primary - machine-learning potential. + Calls each child model's :meth:`forward`. Canonical mechanical outputs + (energy, forces, stress) are combined additively, while non-canonical + outputs are treated as full updated values and later models replace + earlier ones. This is the standard way to layer a dispersion correction + (e.g. DFT-D3), an Ewald electrostatic term, or a local pair potential on + top of a primary machine-learning potential. Args: models: Two or more :class:`ModelInterface` instances that share the @@ -298,8 +305,9 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """Sum the outputs of all child models. Each child model is called with the same ``state`` and ``**kwargs``. - Output tensors that appear in multiple children are summed element-wise; - keys unique to a single child are passed through unchanged. + Canonical mechanical outputs that appear in multiple children are + summed element-wise. Non-canonical outputs are replaced by later + models so they behave like full state updates rather than deltas. Args: state: Simulation state (see :class:`ModelInterface`). @@ -322,9 +330,9 @@ class SerialSumModel(SumModel): Instead, each child runs after the previous child's non-canonical outputs have been stored into a cloned :class:`~torch_sim.state.SimState` via :meth:`torch_sim.state.SimState.store_model_extras`. This lets earlier models - expose per-atom or per-system features that later models can consume, while - energies, forces, stresses, and any repeated auxiliary outputs are still summed - key-by-key. + expose per-atom or per-system features that later models can consume. + Energies, forces, and stresses remain additive, while repeated auxiliary + outputs are treated as full updated values from the latest stage. Examples: ```py diff --git a/torch_sim/models/polarization.py b/torch_sim/models/polarization.py index caa22392..8b86086b 100644 --- a/torch_sim/models/polarization.py +++ b/torch_sim/models/polarization.py @@ -1,45 +1,25 @@ """Electric-field corrections for polarization-aware models.""" -from typing import Protocol - import torch from torch_sim.models.interface import ModelInterface +from torch_sim.state import SimState from torch_sim.typing import AtomExtras, SystemExtras -RAW_POLARIZATION_KEY = "polarization" - +class UniformPolarizationModel(ModelInterface): + """Calculates the energy and force contributions from the application + of a constant electric field to a polarizable system. -class _PolarizationState(Protocol): - n_systems: int - device: torch.device - dtype: torch.dtype - positions: torch.Tensor - system_idx: torch.Tensor + This model is intended to run after an upstream model inside + :class:`~torch_sim.models.interface.SerialSumModel`. + Required state extras: -class UniformPolarizationModel(ModelInterface): - """Apply constant electric-field corrections from polarization response tensors. - - This model is designed to run after an upstream model inside a - :class:`~torch_sim.models.interface.SerialSumModel`. The upstream model is - expected to populate zero-field response tensors on the state via - ``store_model_extras()``, such as ``polarization``/``total_polarization``, - ``polarizability``, and ``born_effective_charges``. - - The returned canonical outputs are additive corrections: - - * ``energy``: ``-E·P0 - 0.5 E·alpha·E`` - * ``forces``: ``e Z*·E`` - * ``stress``: zero tensor, so upstream stress is preserved unchanged - - The model also emits ``total_polarization`` using TorchSim's blessed naming. - When the upstream model already emitted ``total_polarization``, this model - returns only the field-induced increment so that additive composition yields - the final corrected polarization. When the upstream model only emitted the - legacy ``polarization`` key, this model returns the full corrected - ``total_polarization`` tensor. + * ``external_E_field`` + * ``total_polarization`` + * ``polarizability`` + * ``born_effective_charges`` when ``compute_forces`` is enabled """ def __init__( @@ -80,125 +60,95 @@ def retain_graph(self, value: bool) -> None: """Set whether outputs should remain attached to the autograd graph.""" self._retain_graph = value - @staticmethod - def _polarization_from_state( - state: _PolarizationState, - ) -> tuple[torch.Tensor | None, bool]: - """Read polarization from blessed or legacy state extras.""" - total_polarization = getattr(state, SystemExtras.TOTAL_POLARIZATION.value, None) - if total_polarization is not None: - return total_polarization, True - raw_polarization = getattr(state, RAW_POLARIZATION_KEY, None) - if raw_polarization is not None: - return raw_polarization, False - return None, False - - def _build_output(self, state: _PolarizationState) -> dict[str, torch.Tensor]: - """Allocate additive outputs for energy and optional corrections.""" - output: dict[str, torch.Tensor] = { - "energy": torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + def _finalize_output( + self, output: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Detach outputs unless graph retention is enabled.""" + if self.retain_graph: + return output + return { + key: val.detach() if isinstance(val, torch.Tensor) else val + for key, val in output.items() } - if self.compute_forces: - output["forces"] = torch.zeros_like(state.positions) - if self.compute_stress: - # V1 intentionally applies no field-induced stress correction. - output["stress"] = torch.zeros( - state.n_systems, 3, 3, device=state.device, dtype=state.dtype - ) - return output - def _apply_field_corrections( + def _apply_nonzero_field( self, - state: _PolarizationState, + state: SimState, output: dict[str, torch.Tensor], field: torch.Tensor, - base_polarization: torch.Tensor | None, - *, - using_blessed_name: bool, ) -> None: - """Apply nonzero uniform-field energy, force, and polarization updates.""" - if base_polarization is None: - required_key = ( - SystemExtras.TOTAL_POLARIZATION.value - if using_blessed_name - else RAW_POLARIZATION_KEY - ) - raise ValueError( - f"UniformPolarizationModel requires '{required_key}' on the state " - "when external_E_field is non-zero" - ) - polarizability = getattr(state, SystemExtras.POLARIZABILITY.value, None) - if polarizability is None: + """Apply constant-field linear-response corrections. + + Computes the additive updates + - delta_energy = -E·P0 - 1/2 E·alpha·E + - total_polarization = P0 + alpha·E + - delta_forces = Z*·E + """ + required_keys = [ + SystemExtras.TOTAL_POLARIZATION.value, + SystemExtras.POLARIZABILITY.value, + ] + if self.compute_forces: + required_keys.append(AtomExtras.BORN_EFFECTIVE_CHARGES.value) + + missing_keys = [key for key in required_keys if not state.has_extras(key)] + if missing_keys: + missing = ", ".join(f"'{key}'" for key in missing_keys) raise ValueError( - "UniformPolarizationModel requires 'polarizability' on the state " + f"UniformPolarizationModel requires {missing} on the state " "when external_E_field is non-zero" ) - output["energy"] = output["energy"] - torch.einsum( - "si,si->s", field, base_polarization - ) - output["energy"] = output["energy"] - 0.5 * torch.einsum( - "si,sij,sj->s", field, polarizability, field + + dipole_coupling = torch.einsum("si,si->s", field, state.total_polarization) + polarization_response = torch.einsum( + "si,sij,sj->s", field, state.polarizability, field ) - polarization_delta = torch.einsum("sij,sj->si", polarizability, field) - if using_blessed_name: - output[SystemExtras.TOTAL_POLARIZATION.value] = polarization_delta - else: - output[SystemExtras.TOTAL_POLARIZATION.value] = ( - base_polarization + polarization_delta + output["energy"] = -dipole_coupling - 0.5 * polarization_response + output[SystemExtras.TOTAL_POLARIZATION.value] = ( + torch.einsum( + "sij,sj->si", + state.polarizability, + field, ) - if not self.compute_forces: - return - born_effective_charges = getattr( - state, AtomExtras.BORN_EFFECTIVE_CHARGES.value, None + + state.total_polarization ) - if born_effective_charges is None: - raise ValueError( - "UniformPolarizationModel requires 'born_effective_charges' on " - "the state when external_E_field is non-zero" + if self.compute_forces: + output["forces"] = torch.einsum( + "imn,im->in", + state.born_effective_charges, + field[state.system_idx], ) - output["forces"] = output["forces"] + torch.einsum( - "imn,im->in", born_effective_charges, field[state.system_idx] - ) - def forward( - self, - state: _PolarizationState, - **kwargs, - ) -> dict[str, torch.Tensor]: + def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """Return additive uniform-field corrections for a polarization model.""" del kwargs - output = self._build_output(state) - - field = getattr(state, SystemExtras.EXTERNAL_E_FIELD.value, None) - base_polarization, using_blessed_name = self._polarization_from_state(state) - if field is None: - if base_polarization is not None and not using_blessed_name: - output[SystemExtras.TOTAL_POLARIZATION.value] = base_polarization - else: - if field.shape != (state.n_systems, 3): - raise ValueError( - "UniformPolarizationModel requires external_E_field to have shape " - "(n_systems, 3)" - ) - if base_polarization is not None and not using_blessed_name: - output[SystemExtras.TOTAL_POLARIZATION.value] = base_polarization - if torch.any(field != 0): - self._apply_field_corrections( - state, - output, - field, - base_polarization, - using_blessed_name=using_blessed_name, - ) - - if not self.retain_graph: - output = { - key: val.detach() if isinstance(val, torch.Tensor) else val - for key, val in output.items() - } - return output - - -# TODO: Add a spatially-varying polarization model. + output: dict[str, torch.Tensor] = { + "energy": torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + } + if self.compute_forces: + output["forces"] = torch.zeros_like(state.positions) + if self.compute_stress: + # V1 intentionally applies no field-induced stress correction. + output["stress"] = torch.zeros( + state.n_systems, 3, 3, device=state.device, dtype=state.dtype + ) + + if not state.has_extras(SystemExtras.EXTERNAL_E_FIELD.value): + raise ValueError( + "UniformPolarizationModel requires 'external_E_field' on the state" + ) + + field = getattr(state, SystemExtras.EXTERNAL_E_FIELD.value) + if field.shape != (state.n_systems, 3): + raise ValueError( + "UniformPolarizationModel requires external_E_field to have shape " + "(n_systems, 3)" + ) + if not torch.any(field != 0): + return self._finalize_output(output) + + self._apply_nonzero_field(state, output, field) + return self._finalize_output(output) + __all__ = ["UniformPolarizationModel"] From 8840c5ec3b3bee0e239b4fe39d25af5f0c9b3658 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 8 Apr 2026 21:11:22 -0400 Subject: [PATCH 3/3] fix --- tests/models/test_polarization.py | 6 +++++- torch_sim/models/interface.py | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/models/test_polarization.py b/tests/models/test_polarization.py index ee1e9991..2fc71f96 100644 --- a/tests/models/test_polarization.py +++ b/tests/models/test_polarization.py @@ -142,7 +142,11 @@ def test_polarization_model_returns_additive_total_polarization_delta( torch.testing.assert_close( combined_output["total_polarization"], expected_total_polarization ) - correction_output = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)(state) + serialized_state = state.clone() + serialized_state.store_model_extras(base_output) + correction_output = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)( + serialized_state + ) torch.testing.assert_close( correction_output["total_polarization"], expected_total_polarization, diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 7004216a..6741de34 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -34,6 +34,7 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): import torch import torch_sim as ts +from torch_sim.state import _CANONICAL_MODEL_KEYS if TYPE_CHECKING: @@ -51,8 +52,6 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): "n_edges": 2, } -_ADDITIVE_MODEL_KEYS = frozenset({"energy", "forces", "stress"}) - def _accumulate_model_output( combined: dict[str, torch.Tensor], output: dict[str, torch.Tensor] @@ -63,7 +62,7 @@ def _accumulate_model_output( full updated values, so later models replace earlier ones. """ for key, tensor in output.items(): - if key in combined and key in _ADDITIVE_MODEL_KEYS: + if key in combined and key in _CANONICAL_MODEL_KEYS: combined[key] = combined[key] + tensor else: combined[key] = tensor