Skip to content

Commit bcc58c2

Browse files
committed
fea: add polarization model
1 parent 560b27f commit bcc58c2

4 files changed

Lines changed: 522 additions & 6 deletions

File tree

tests/models/test_polarization.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""Tests for the polarization electric-field correction model."""
2+
3+
import pytest
4+
import torch
5+
6+
import torch_sim as ts
7+
from tests.conftest import DEVICE, DTYPE
8+
from torch_sim.models.interface import ModelInterface, SerialSumModel
9+
from torch_sim.models.polarization import UniformPolarizationModel
10+
11+
12+
class DummyPolarResponseModel(ModelInterface):
13+
def __init__(
14+
self,
15+
*,
16+
polarization_key: str = "polarization",
17+
include_born_effective_charges: bool = True,
18+
include_polarizability: bool = True,
19+
device: torch.device = DEVICE,
20+
dtype: torch.dtype = DTYPE,
21+
) -> None:
22+
super().__init__()
23+
self.polarization_key = polarization_key
24+
self.include_born_effective_charges = include_born_effective_charges
25+
self.include_polarizability = include_polarizability
26+
self._device = device
27+
self._dtype = dtype
28+
self._compute_forces = True
29+
self._compute_stress = True
30+
31+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
32+
del kwargs
33+
energy = torch.arange(
34+
1, state.n_systems + 1, device=state.device, dtype=state.dtype
35+
)
36+
forces = (
37+
torch.arange(state.n_atoms * 3, device=state.device, dtype=state.dtype)
38+
.reshape(state.n_atoms, 3)
39+
.div(10.0)
40+
)
41+
stress = (
42+
torch.arange(state.n_systems * 9, device=state.device, dtype=state.dtype)
43+
.reshape(state.n_systems, 3, 3)
44+
.div(100.0)
45+
)
46+
polarization = (
47+
torch.arange(state.n_systems * 3, device=state.device, dtype=state.dtype)
48+
.reshape(state.n_systems, 3)
49+
.add(0.5)
50+
)
51+
output: dict[str, torch.Tensor] = {
52+
"energy": energy,
53+
"forces": forces,
54+
"stress": stress,
55+
self.polarization_key: polarization,
56+
}
57+
if self.include_polarizability:
58+
diag = torch.tensor([1.0, 2.0, 3.0], device=state.device, dtype=state.dtype)
59+
output["polarizability"] = torch.diag_embed(diag.repeat(state.n_systems, 1))
60+
if self.include_born_effective_charges:
61+
born_effective_charges = torch.zeros(
62+
state.n_atoms, 3, 3, device=state.device, dtype=state.dtype
63+
)
64+
born_effective_charges[:, 0, 0] = 1.0
65+
born_effective_charges[:, 1, 1] = 2.0
66+
born_effective_charges[:, 2, 2] = 3.0
67+
output["born_effective_charges"] = born_effective_charges
68+
return output
69+
70+
71+
def test_polarization_model_normalizes_raw_key_without_field(
72+
si_double_sim_state: ts.SimState,
73+
) -> None:
74+
base_model = DummyPolarResponseModel()
75+
combined_model = SerialSumModel(
76+
base_model,
77+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
78+
)
79+
80+
base_output = base_model(si_double_sim_state)
81+
combined_output = combined_model(si_double_sim_state)
82+
83+
torch.testing.assert_close(combined_output["energy"], base_output["energy"])
84+
torch.testing.assert_close(combined_output["forces"], base_output["forces"])
85+
torch.testing.assert_close(combined_output["stress"], base_output["stress"])
86+
torch.testing.assert_close(
87+
combined_output["total_polarization"], base_output["polarization"]
88+
)
89+
torch.testing.assert_close(
90+
combined_output["polarization"], base_output["polarization"]
91+
)
92+
93+
94+
def test_polarization_model_applies_linear_response_corrections(
95+
si_double_sim_state: ts.SimState,
96+
) -> None:
97+
base_model = DummyPolarResponseModel()
98+
combined_model = SerialSumModel(
99+
base_model,
100+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
101+
)
102+
field = torch.tensor(
103+
[[0.2, -0.1, 0.05], [-0.3, 0.4, 0.1]],
104+
device=DEVICE,
105+
dtype=DTYPE,
106+
)
107+
state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)
108+
109+
base_output = base_model(state)
110+
combined_output = combined_model(state)
111+
expected_polarization = base_output["polarization"] + torch.einsum(
112+
"sij,sj->si", base_output["polarizability"], field
113+
)
114+
expected_energy = base_output["energy"] - torch.einsum(
115+
"si,si->s", field, base_output["polarization"]
116+
)
117+
expected_energy = expected_energy - 0.5 * torch.einsum(
118+
"si,sij,sj->s", field, base_output["polarizability"], field
119+
)
120+
expected_forces = base_output["forces"] + torch.einsum(
121+
"imn,im->in",
122+
base_output["born_effective_charges"],
123+
field[state.system_idx],
124+
)
125+
126+
torch.testing.assert_close(combined_output["energy"], expected_energy)
127+
torch.testing.assert_close(combined_output["forces"], expected_forces)
128+
torch.testing.assert_close(
129+
combined_output["total_polarization"], expected_polarization
130+
)
131+
torch.testing.assert_close(
132+
combined_output["polarization"], base_output["polarization"]
133+
)
134+
torch.testing.assert_close(combined_output["stress"], base_output["stress"])
135+
136+
137+
def test_polarization_model_adds_only_delta_for_blessed_name(
138+
si_double_sim_state: ts.SimState,
139+
) -> None:
140+
base_model = DummyPolarResponseModel(polarization_key="total_polarization")
141+
combined_model = SerialSumModel(
142+
base_model,
143+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
144+
)
145+
field = torch.tensor([[0.1, 0.0, 0.0], [0.0, -0.2, 0.3]], device=DEVICE, dtype=DTYPE)
146+
state = ts.SimState.from_state(si_double_sim_state, external_E_field=field)
147+
148+
base_output = base_model(state)
149+
combined_output = combined_model(state)
150+
expected_total_polarization = base_output["total_polarization"] + torch.einsum(
151+
"sij,sj->si", base_output["polarizability"], field
152+
)
153+
154+
torch.testing.assert_close(
155+
combined_output["total_polarization"], expected_total_polarization
156+
)
157+
assert "polarization" not in combined_output
158+
159+
160+
def test_polarization_model_requires_born_effective_charges_for_force_correction(
161+
si_double_sim_state: ts.SimState,
162+
) -> None:
163+
base_model = DummyPolarResponseModel(include_born_effective_charges=False)
164+
combined_model = SerialSumModel(
165+
base_model,
166+
UniformPolarizationModel(device=DEVICE, dtype=DTYPE),
167+
)
168+
state = ts.SimState.from_state(
169+
si_double_sim_state,
170+
external_E_field=torch.ones(
171+
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
172+
),
173+
)
174+
175+
with pytest.raises(ValueError, match="born_effective_charges"):
176+
combined_model(state)
177+
178+
179+
def test_polarization_model_rejects_non_uniform_field_shape(
180+
si_double_sim_state: ts.SimState,
181+
) -> None:
182+
state = ts.SimState.from_state(
183+
si_double_sim_state,
184+
external_E_field=torch.zeros(
185+
si_double_sim_state.n_systems, 3, device=DEVICE, dtype=DTYPE
186+
),
187+
)
188+
state._system_extras["external_E_field"] = torch.zeros( # noqa: SLF001
189+
state.n_atoms, 3, device=DEVICE, dtype=DTYPE
190+
)
191+
model = UniformPolarizationModel(device=DEVICE, dtype=DTYPE)
192+
193+
with pytest.raises(ValueError, match="shape \\(n_systems, 3\\)"):
194+
model(state)

tests/models/test_sum_model.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,47 @@
66
import torch_sim as ts
77
from tests.conftest import DEVICE, DTYPE
88
from tests.models.conftest import make_validate_model_outputs_test
9-
from torch_sim.models.interface import SumModel
9+
from torch_sim.models.interface import ModelInterface, SerialSumModel, SumModel
1010
from torch_sim.models.lennard_jones import LennardJonesModel
1111
from torch_sim.models.morse import MorseModel
1212

1313

14+
class ExtraProducerModel(ModelInterface):
15+
def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
16+
super().__init__()
17+
self._device = device
18+
self._dtype = dtype
19+
self._compute_stress = False
20+
self._compute_forces = False
21+
22+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
23+
del kwargs
24+
latent = state.positions[:, 0] + 2.0
25+
return {
26+
"energy": torch.ones(state.n_systems, device=state.device, dtype=state.dtype),
27+
"latent": latent,
28+
}
29+
30+
31+
class ExtraConsumerModel(ModelInterface):
32+
seen_latent: torch.Tensor | None
33+
34+
def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
35+
super().__init__()
36+
self._device = device
37+
self._dtype = dtype
38+
self._compute_stress = False
39+
self._compute_forces = False
40+
self.seen_latent = None
41+
42+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
43+
del kwargs
44+
self.seen_latent = state.latent.clone()
45+
energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
46+
energy.scatter_add_(0, state.system_idx, state.latent)
47+
return {"energy": energy}
48+
49+
1450
@pytest.fixture
1551
def lj_model_a() -> LennardJonesModel:
1652
return LennardJonesModel(
@@ -43,9 +79,19 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode
4379
return SumModel(lj_model_a, morse_model)
4480

4581

82+
@pytest.fixture
83+
def serial_sum_model(
84+
lj_model_a: LennardJonesModel, morse_model: MorseModel
85+
) -> SerialSumModel:
86+
return SerialSumModel(lj_model_a, morse_model)
87+
88+
4689
test_sum_model_outputs = make_validate_model_outputs_test(
4790
model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE
4891
)
92+
test_serial_sum_model_outputs = make_validate_model_outputs_test(
93+
model_fixture_name="serial_sum_model", device=DEVICE, dtype=DTYPE
94+
)
4995

5096

5197
def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None:
@@ -102,3 +148,39 @@ def test_sum_model_retain_graph(
102148
assert lj_model_a.retain_graph is True
103149
assert morse_model.retain_graph is True
104150
assert sm.retain_graph is True
151+
152+
153+
def test_serial_sum_model_matches_parallel_sum_for_independent_models(
154+
lj_model_a: LennardJonesModel,
155+
morse_model: MorseModel,
156+
si_sim_state: ts.SimState,
157+
) -> None:
158+
sum_out = SumModel(lj_model_a, morse_model)(si_sim_state)
159+
serial_out = SerialSumModel(lj_model_a, morse_model)(si_sim_state)
160+
torch.testing.assert_close(serial_out["energy"], sum_out["energy"])
161+
torch.testing.assert_close(serial_out["forces"], sum_out["forces"])
162+
torch.testing.assert_close(serial_out["stress"], sum_out["stress"])
163+
164+
165+
def test_serial_sum_model_exposes_extras_to_later_models(
166+
si_double_sim_state: ts.SimState,
167+
) -> None:
168+
producer = ExtraProducerModel()
169+
consumer = ExtraConsumerModel()
170+
serial_model = SerialSumModel(producer, consumer)
171+
state = si_double_sim_state.clone()
172+
expected_latent = state.positions[:, 0] + 2.0
173+
expected_energy = torch.ones(state.n_systems, device=state.device, dtype=state.dtype)
174+
expected_energy = expected_energy.scatter_add(
175+
0,
176+
state.system_idx,
177+
expected_latent,
178+
)
179+
180+
output = serial_model(state)
181+
182+
assert consumer.seen_latent is not None
183+
torch.testing.assert_close(consumer.seen_latent, expected_latent)
184+
torch.testing.assert_close(output["latent"], expected_latent)
185+
torch.testing.assert_close(output["energy"], expected_energy)
186+
assert not state.has_extras("latent")

torch_sim/models/interface.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs):
5252
}
5353

5454

55+
def _accumulate_model_output(
56+
combined: dict[str, torch.Tensor], output: dict[str, torch.Tensor]
57+
) -> None:
58+
"""Accumulate one model output into a combined output dict."""
59+
for key, tensor in output.items():
60+
if key in combined:
61+
combined[key] = combined[key] + tensor
62+
else:
63+
combined[key] = tensor
64+
65+
5566
class ModelInterface(torch.nn.Module, ABC):
5667
"""Abstract base class for all simulation models in TorchSim.
5768
@@ -300,11 +311,36 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
300311
combined: dict[str, torch.Tensor] = {}
301312
for model in self._children():
302313
output = model(state, **kwargs)
303-
for key, tensor in output.items():
304-
if key in combined:
305-
combined[key] = combined[key] + tensor
306-
else:
307-
combined[key] = tensor
314+
_accumulate_model_output(combined, output)
315+
return combined
316+
317+
318+
class SerialSumModel(SumModel):
319+
"""Serial additive composition of multiple :class:`ModelInterface` models.
320+
321+
Unlike :class:`SumModel`, child models do not all see the same input state.
322+
Instead, each child runs after the previous child's non-canonical outputs have
323+
been stored into a cloned :class:`~torch_sim.state.SimState` via
324+
:meth:`torch_sim.state.SimState.store_model_extras`. This lets earlier models
325+
expose per-atom or per-system features that later models can consume, while
326+
energies, forces, stresses, and any repeated auxiliary outputs are still summed
327+
key-by-key.
328+
329+
Examples:
330+
```py
331+
serial_model = SerialSumModel(polarization_model, dispersion_model)
332+
output = serial_model(sim_state)
333+
```
334+
"""
335+
336+
def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
337+
"""Run child models serially, exposing extras from earlier models."""
338+
combined: dict[str, torch.Tensor] = {}
339+
serial_state = state.clone()
340+
for model in self._children():
341+
output = model(serial_state, **kwargs)
342+
_accumulate_model_output(combined, output)
343+
serial_state.store_model_extras(output)
308344
return combined
309345

310346

0 commit comments

Comments
 (0)