Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions tests/models/test_polarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""Tests for the polarization electric-field correction model."""

import pytest
import torch

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from torch_sim.models.interface import ModelInterface, SerialSumModel
from torch_sim.models.polarization import UniformPolarizationModel


class DummyPolarResponseModel(ModelInterface):
def __init__(
self,
*,
include_born_effective_charges: bool = True,
include_polarizability: bool = True,
include_total_polarization: bool = True,
device: torch.device = DEVICE,
dtype: torch.dtype = DTYPE,
) -> None:
super().__init__()
self.include_born_effective_charges = include_born_effective_charges
self.include_polarizability = include_polarizability
self.include_total_polarization = include_total_polarization
self._device = device
self._dtype = dtype
self._compute_forces = True
self._compute_stress = True

def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
del kwargs
energy = torch.arange(
1, state.n_systems + 1, device=state.device, dtype=state.dtype
)
forces = (
torch.arange(state.n_atoms * 3, device=state.device, dtype=state.dtype)
.reshape(state.n_atoms, 3)
.div(10.0)
)
stress = (
torch.arange(state.n_systems * 9, device=state.device, dtype=state.dtype)
.reshape(state.n_systems, 3, 3)
.div(100.0)
)
polarization = (
torch.arange(state.n_systems * 3, device=state.device, dtype=state.dtype)
.reshape(state.n_systems, 3)
.add(0.5)
)
output: dict[str, torch.Tensor] = {
"energy": energy,
"forces": forces,
"stress": stress,
}
if self.include_total_polarization:
output["total_polarization"] = polarization
if self.include_polarizability:
diag = torch.tensor([1.0, 2.0, 3.0], device=state.device, dtype=state.dtype)
output["polarizability"] = torch.diag_embed(diag.repeat(state.n_systems, 1))
if self.include_born_effective_charges:
born_effective_charges = torch.zeros(
state.n_atoms, 3, 3, device=state.device, dtype=state.dtype
)
born_effective_charges[:, 0, 0] = 1.0
born_effective_charges[:, 1, 1] = 2.0
born_effective_charges[:, 2, 2] = 3.0
output["born_effective_charges"] = born_effective_charges
return output


def test_polarization_model_requires_external_e_field(
si_double_sim_state: ts.SimState,
) -> None:
base_model = DummyPolarResponseModel()
combined_model = SerialSumModel(
base_model,
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
)

with pytest.raises(ValueError, match="external_E_field"):
combined_model(si_double_sim_state)


def test_polarization_model_applies_linear_response_corrections(
si_double_sim_state: ts.SimState,
) -> None:
base_model = DummyPolarResponseModel()
combined_model = SerialSumModel(
base_model,
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
)
field = torch.tensor(
[[0.2, -0.1, 0.05], [-0.3, 0.4, 0.1]],
device=DEVICE,
dtype=DTYPE,
)
state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)

base_output = base_model(state)
combined_output = combined_model(state)
expected_polarization = base_output["total_polarization"] + torch.einsum(
"sij,sj->si", base_output["polarizability"], field
)
expected_energy = base_output["energy"] - torch.einsum(
"si,si->s", field, base_output["total_polarization"]
)
expected_energy = expected_energy - 0.5 * torch.einsum(
"si,sij,sj->s", field, base_output["polarizability"], field
)
expected_forces = base_output["forces"] + torch.einsum(
"imn,im->in",
base_output["born_effective_charges"],
field[state.system_idx],
)

torch.testing.assert_close(combined_output["energy"], expected_energy)
torch.testing.assert_close(combined_output["forces"], expected_forces)
torch.testing.assert_close(
combined_output["total_polarization"], expected_polarization
)
torch.testing.assert_close(combined_output["stress"], base_output["stress"])


def test_polarization_model_returns_additive_total_polarization_delta(
si_double_sim_state: ts.SimState,
) -> None:
base_model = DummyPolarResponseModel()
combined_model = SerialSumModel(
base_model,
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
)
field = torch.tensor([[0.1, 0.0, 0.0], [0.0, -0.2, 0.3]], device=DEVICE, dtype=DTYPE)
state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)

base_output = base_model(state)
combined_output = combined_model(state)
expected_total_polarization = base_output["total_polarization"] + torch.einsum(
"sij,sj->si", base_output["polarizability"], field
)

torch.testing.assert_close(
combined_output["total_polarization"], expected_total_polarization
)
serialized_state = state.clone()
serialized_state.store_model_extras(base_output)
correction_output = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)(
serialized_state
)
torch.testing.assert_close(
correction_output["total_polarization"],
expected_total_polarization,
)


def test_polarization_model_requires_born_effective_charges_for_force_correction(
si_double_sim_state: ts.SimState,
) -> None:
base_model = DummyPolarResponseModel(include_born_effective_charges=False)
combined_model = SerialSumModel(
base_model,
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
)
state = ts.SimState.from_state(
si_double_sim_state,
external_E_field=torch.ones(
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
),
)

with pytest.raises(ValueError, match="born_effective_charges"):
combined_model(state)


def test_polarization_model_requires_total_polarization(
si_double_sim_state: ts.SimState,
) -> None:
base_model = DummyPolarResponseModel(include_total_polarization=False)
combined_model = SerialSumModel(
base_model,
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
)
state = ts.SimState.from_state(
si_double_sim_state,
external_E_field=torch.ones(
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
),
)

with pytest.raises(ValueError, match="total_polarization"):
combined_model(state)


def test_polarization_model_rejects_non_uniform_field_shape(
si_double_sim_state: ts.SimState,
) -> None:
state = ts.SimState.from_state(
si_double_sim_state,
external_E_field=torch.zeros(
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
),
)
state._system_extras["external_E_field"] = torch.zeros( # noqa: SLF001
state.n_atoms, 3, device=DEVICE, dtype=DTYPE
)
model = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)

with pytest.raises(ValueError, match="shape \\(n_systems, 3\\)"):
model(state)
143 changes: 142 additions & 1 deletion tests/models/test_sum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,79 @@
import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test
from torch_sim.models.interface import SumModel
from torch_sim.models.interface import ModelInterface, SerialSumModel, SumModel
from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.models.morse import MorseModel


class ExtraProducerModel(ModelInterface):
def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
super().__init__()
self._device = device
self._dtype = dtype
self._compute_stress = False
self._compute_forces = False

def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
del kwargs
latent = state.positions[:, 0] + 2.0
return {
"energy": torch.ones(state.n_systems, device=state.device, dtype=state.dtype),
"latent": latent,
}


class ExtraConsumerModel(ModelInterface):
seen_latent: torch.Tensor | None

def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
super().__init__()
self._device = device
self._dtype = dtype
self._compute_stress = False
self._compute_forces = False
self.seen_latent = None

def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
del kwargs
self.seen_latent = state.latent.clone()
energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
energy.scatter_add_(0, state.system_idx, state.latent)
return {"energy": energy}


class OverwriteExtrasModel(ModelInterface):
def __init__(
self,
value: float,
device: torch.device = DEVICE,
dtype: torch.dtype = DTYPE,
) -> None:
super().__init__()
self.value = value
self._device = device
self._dtype = dtype
self._compute_stress = False
self._compute_forces = False

def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
del kwargs
return {
"energy": torch.full(
(state.n_systems,),
self.value,
device=state.device,
dtype=state.dtype,
),
"label": torch.full(
(state.n_systems, 3),
self.value,
device=state.device,
dtype=state.dtype,
),
}


@pytest.fixture
def lj_model_a() -> LennardJonesModel:
return LennardJonesModel(
Expand Down Expand Up @@ -43,9 +111,19 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode
return SumModel(lj_model_a, morse_model)


@pytest.fixture
def serial_sum_model(
lj_model_a: LennardJonesModel, morse_model: MorseModel
) -> SerialSumModel:
return SerialSumModel(lj_model_a, morse_model)


test_sum_model_outputs = make_validate_model_outputs_test(
model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE
)
test_serial_sum_model_outputs = make_validate_model_outputs_test(
model_fixture_name="serial_sum_model", device=DEVICE, dtype=DTYPE
)


def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None:
Expand Down Expand Up @@ -102,3 +180,66 @@ def test_sum_model_retain_graph(
assert lj_model_a.retain_graph is True
assert morse_model.retain_graph is True
assert sm.retain_graph is True


def test_serial_sum_model_matches_parallel_sum_for_independent_models(
lj_model_a: LennardJonesModel,
morse_model: MorseModel,
si_sim_state: ts.SimState,
) -> None:
sum_out = SumModel(lj_model_a, morse_model)(si_sim_state)
serial_out = SerialSumModel(lj_model_a, morse_model)(si_sim_state)
torch.testing.assert_close(serial_out["energy"], sum_out["energy"])
torch.testing.assert_close(serial_out["forces"], sum_out["forces"])
torch.testing.assert_close(serial_out["stress"], sum_out["stress"])


def test_serial_sum_model_exposes_extras_to_later_models(
si_double_sim_state: ts.SimState,
) -> None:
producer = ExtraProducerModel()
consumer = ExtraConsumerModel()
serial_model = SerialSumModel(producer, consumer)
state = si_double_sim_state.clone()
expected_latent = state.positions[:, 0] + 2.0
expected_energy = torch.ones(state.n_systems, device=state.device, dtype=state.dtype)
expected_energy = expected_energy.scatter_add(
0,
state.system_idx,
expected_latent,
)

output = serial_model(state)

assert consumer.seen_latent is not None
torch.testing.assert_close(consumer.seen_latent, expected_latent)
torch.testing.assert_close(output["latent"], expected_latent)
torch.testing.assert_close(output["energy"], expected_energy)
assert not state.has_extras("latent")


def test_serial_sum_model_overwrites_noncanonical_outputs(
si_double_sim_state: ts.SimState,
) -> None:
model = SerialSumModel(OverwriteExtrasModel(1.0), OverwriteExtrasModel(2.0))

output = model(si_double_sim_state)

torch.testing.assert_close(
output["energy"],
torch.full(
(si_double_sim_state.n_systems,),
3.0,
device=si_double_sim_state.device,
dtype=si_double_sim_state.dtype,
),
)
torch.testing.assert_close(
output["label"],
torch.full(
(si_double_sim_state.n_systems, 3),
2.0,
device=si_double_sim_state.device,
dtype=si_double_sim_state.dtype,
),
)
Loading
Loading