|
6 | 6 | import torch_sim as ts |
7 | 7 | from tests.conftest import DEVICE, DTYPE |
8 | 8 | from tests.models.conftest import make_validate_model_outputs_test |
9 | | -from torch_sim.models.interface import SumModel |
| 9 | +from torch_sim.models.interface import ModelInterface, SerialSumModel, SumModel |
10 | 10 | from torch_sim.models.lennard_jones import LennardJonesModel |
11 | 11 | from torch_sim.models.morse import MorseModel |
12 | 12 |
|
13 | 13 |
|
| 14 | +class ExtraProducerModel(ModelInterface): |
| 15 | + def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None: |
| 16 | + super().__init__() |
| 17 | + self._device = device |
| 18 | + self._dtype = dtype |
| 19 | + self._compute_stress = False |
| 20 | + self._compute_forces = False |
| 21 | + |
| 22 | + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: |
| 23 | + del kwargs |
| 24 | + latent = state.positions[:, 0] + 2.0 |
| 25 | + return { |
| 26 | + "energy": torch.ones(state.n_systems, device=state.device, dtype=state.dtype), |
| 27 | + "latent": latent, |
| 28 | + } |
| 29 | + |
| 30 | + |
| 31 | +class ExtraConsumerModel(ModelInterface): |
| 32 | + seen_latent: torch.Tensor | None |
| 33 | + |
| 34 | + def __init__(self, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE) -> None: |
| 35 | + super().__init__() |
| 36 | + self._device = device |
| 37 | + self._dtype = dtype |
| 38 | + self._compute_stress = False |
| 39 | + self._compute_forces = False |
| 40 | + self.seen_latent = None |
| 41 | + |
| 42 | + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: |
| 43 | + del kwargs |
| 44 | + self.seen_latent = state.latent.clone() |
| 45 | + energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) |
| 46 | + energy.scatter_add_(0, state.system_idx, state.latent) |
| 47 | + return {"energy": energy} |
| 48 | + |
| 49 | + |
14 | 50 | @pytest.fixture |
15 | 51 | def lj_model_a() -> LennardJonesModel: |
16 | 52 | return LennardJonesModel( |
@@ -43,9 +79,19 @@ def sum_model(lj_model_a: LennardJonesModel, morse_model: MorseModel) -> SumMode |
43 | 79 | return SumModel(lj_model_a, morse_model) |
44 | 80 |
|
45 | 81 |
|
| 82 | +@pytest.fixture |
| 83 | +def serial_sum_model( |
| 84 | + lj_model_a: LennardJonesModel, morse_model: MorseModel |
| 85 | +) -> SerialSumModel: |
| 86 | + return SerialSumModel(lj_model_a, morse_model) |
| 87 | + |
| 88 | + |
46 | 89 | test_sum_model_outputs = make_validate_model_outputs_test( |
47 | 90 | model_fixture_name="sum_model", device=DEVICE, dtype=DTYPE |
48 | 91 | ) |
| 92 | +test_serial_sum_model_outputs = make_validate_model_outputs_test( |
| 93 | + model_fixture_name="serial_sum_model", device=DEVICE, dtype=DTYPE |
| 94 | +) |
49 | 95 |
|
50 | 96 |
|
51 | 97 | def test_sum_model_requires_two_models(lj_model_a: LennardJonesModel) -> None: |
@@ -102,3 +148,39 @@ def test_sum_model_retain_graph( |
102 | 148 | assert lj_model_a.retain_graph is True |
103 | 149 | assert morse_model.retain_graph is True |
104 | 150 | assert sm.retain_graph is True |
| 151 | + |
| 152 | + |
| 153 | +def test_serial_sum_model_matches_parallel_sum_for_independent_models( |
| 154 | + lj_model_a: LennardJonesModel, |
| 155 | + morse_model: MorseModel, |
| 156 | + si_sim_state: ts.SimState, |
| 157 | +) -> None: |
| 158 | + sum_out = SumModel(lj_model_a, morse_model)(si_sim_state) |
| 159 | + serial_out = SerialSumModel(lj_model_a, morse_model)(si_sim_state) |
| 160 | + torch.testing.assert_close(serial_out["energy"], sum_out["energy"]) |
| 161 | + torch.testing.assert_close(serial_out["forces"], sum_out["forces"]) |
| 162 | + torch.testing.assert_close(serial_out["stress"], sum_out["stress"]) |
| 163 | + |
| 164 | + |
| 165 | +def test_serial_sum_model_exposes_extras_to_later_models( |
| 166 | + si_double_sim_state: ts.SimState, |
| 167 | +) -> None: |
| 168 | + producer = ExtraProducerModel() |
| 169 | + consumer = ExtraConsumerModel() |
| 170 | + serial_model = SerialSumModel(producer, consumer) |
| 171 | + state = si_double_sim_state.clone() |
| 172 | + expected_latent = state.positions[:, 0] + 2.0 |
| 173 | + expected_energy = torch.ones(state.n_systems, device=state.device, dtype=state.dtype) |
| 174 | + expected_energy = expected_energy.scatter_add( |
| 175 | + 0, |
| 176 | + state.system_idx, |
| 177 | + expected_latent, |
| 178 | + ) |
| 179 | + |
| 180 | + output = serial_model(state) |
| 181 | + |
| 182 | + assert consumer.seen_latent is not None |
| 183 | + torch.testing.assert_close(consumer.seen_latent, expected_latent) |
| 184 | + torch.testing.assert_close(output["latent"], expected_latent) |
| 185 | + torch.testing.assert_close(output["energy"], expected_energy) |
| 186 | + assert not state.has_extras("latent") |
0 commit comments