Skip to content

Commit 82c6af5

Browse files
committed
fea: bring in serial sum model
1 parent cea8137 commit 82c6af5

2 files changed

Lines changed: 124 additions & 6 deletions

File tree

tests/models/test_sum_model.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,47 @@
66
import torch_sim as ts
77
from tests.conftest import DEVICE, DTYPE
88
from tests.models.conftest import make_validate_model_outputs_test
9-
from torch_sim.models.interface import SumModel
9+
from torch_sim.models.interface import ModelInterface, SerialSumModel, SumModel
1010
from torch_sim.models.lennard_jones import LennardJonesModel
1111
from torch_sim.models.morse import MorseModel
1212

1313

14+
class ExtraProducerModel(ModelInterface):
15+
def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
16+
super().__init__()
17+
self._device = device
18+
self._dtype = dtype
19+
self._compute_stress = False
20+
self._compute_forces = False
21+
22+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
23+
del kwargs
24+
latent = state.positions[:, 0] + 2.0
25+
return {
26+
"energy": torch.ones(state.n_systems, device=state.device, dtype=state.dtype),
27+
"latent": latent,
28+
}
29+
30+
31+
class ExtraConsumerModel(ModelInterface):
32+
seen_latent: torch.Tensor | None
33+
34+
def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None:
35+
super().__init__()
36+
self._device = device
37+
self._dtype = dtype
38+
self._compute_stress = False
39+
self._compute_forces = False
40+
self.seen_latent = None
41+
42+
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
43+
del kwargs
44+
self.seen_latent = state.latent.clone()
45+
energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
46+
energy.scatter_add_(0, state.system_idx, state.latent)
47+
return {"energy": energy}
48+
49+
1450
@pytest.fixture
1551
def lj_model_a() -> LennardJonesModel:
1652
return LennardJonesModel(
@@ -43,9 +79,19 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode
4379
return SumModel(lj_model_a, morse_model)
4480

4581

82+
@pytest.fixture
83+
def serial_sum_model(
84+
lj_model_a: LennardJonesModel, morse_model: MorseModel
85+
) -> SerialSumModel:
86+
return SerialSumModel(lj_model_a, morse_model)
87+
88+
4689
test_sum_model_outputs = make_validate_model_outputs_test(
4790
model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE
4891
)
92+
test_serial_sum_model_outputs = make_validate_model_outputs_test(
93+
model_fixture_name="serial_sum_model", device=DEVICE, dtype=DTYPE
94+
)
4995

5096

5197
def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None:
@@ -102,3 +148,39 @@ def test_sum_model_retain_graph(
102148
assert lj_model_a.retain_graph is True
103149
assert morse_model.retain_graph is True
104150
assert sm.retain_graph is True
151+
152+
153+
def test_serial_sum_model_matches_parallel_sum_for_independent_models(
154+
lj_model_a: LennardJonesModel,
155+
morse_model: MorseModel,
156+
si_sim_state: ts.SimState,
157+
) -> None:
158+
sum_out = SumModel(lj_model_a, morse_model)(si_sim_state)
159+
serial_out = SerialSumModel(lj_model_a, morse_model)(si_sim_state)
160+
torch.testing.assert_close(serial_out["energy"], sum_out["energy"])
161+
torch.testing.assert_close(serial_out["forces"], sum_out["forces"])
162+
torch.testing.assert_close(serial_out["stress"], sum_out["stress"])
163+
164+
165+
def test_serial_sum_model_exposes_extras_to_later_models(
166+
si_double_sim_state: ts.SimState,
167+
) -> None:
168+
producer = ExtraProducerModel()
169+
consumer = ExtraConsumerModel()
170+
serial_model = SerialSumModel(producer, consumer)
171+
state = si_double_sim_state.clone()
172+
expected_latent = state.positions[:, 0] + 2.0
173+
expected_energy = torch.ones(state.n_systems, device=state.device, dtype=state.dtype)
174+
expected_energy = expected_energy.scatter_add(
175+
0,
176+
state.system_idx,
177+
expected_latent,
178+
)
179+
180+
output = serial_model(state)
181+
182+
assert consumer.seen_latent is not None
183+
torch.testing.assert_close(consumer.seen_latent, expected_latent)
184+
torch.testing.assert_close(output["latent"], expected_latent)
185+
torch.testing.assert_close(output["energy"], expected_energy)
186+
assert not state.has_extras("latent")

torch_sim/models/interface.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs):
5252
}
5353

5454

55+
def _accumulate_model_output(
56+
combined: dict[str, torch.Tensor], output: dict[str, torch.Tensor]
57+
) -> None:
58+
"""Accumulate one model output into a combined output dict."""
59+
for key, tensor in output.items():
60+
if key in combined:
61+
combined[key] = combined[key] + tensor
62+
else:
63+
combined[key] = tensor
64+
65+
5566
class ModelInterface(torch.nn.Module, ABC):
5667
"""Abstract base class for all simulation models in TorchSim.
5768
@@ -300,11 +311,36 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
300311
combined: dict[str, torch.Tensor] = {}
301312
for model in self._children():
302313
output = model(state, **kwargs)
303-
for key, tensor in output.items():
304-
if key in combined:
305-
combined[key] = combined[key] + tensor
306-
else:
307-
combined[key] = tensor
314+
_accumulate_model_output(combined, output)
315+
return combined
316+
317+
318+
class SerialSumModel(SumModel):
319+
"""Serial additive composition of multiple :class:`ModelInterface` models.
320+
321+
Unlike :class:`SumModel`, child models do not all see the same input state.
322+
Instead, each child runs after the previous child's non-canonical outputs have
323+
been stored into a cloned :class:`~torch_sim.state.SimState` via
324+
:meth:`torch_sim.state.SimState.store_model_extras`. This lets earlier models
325+
expose per-atom or per-system features that later models can consume, while
326+
energies, forces, stresses, and any repeated auxiliary outputs are still summed
327+
key-by-key.
328+
329+
Examples:
330+
```py
331+
serial_model = SerialSumModel(polarization_model, dispersion_model)
332+
output = serial_model(sim_state)
333+
```
334+
"""
335+
336+
def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
337+
"""Run child models serially, exposing extras from earlier models."""
338+
combined: dict[str, torch.Tensor] = {}
339+
serial_state = state.clone()
340+
for model in self._children():
341+
output = model(serial_state, **kwargs)
342+
_accumulate_model_output(combined, output)
343+
serial_state.store_model_extras(output)
308344
return combined
309345

310346

0 commit comments

Comments
 (0)