Skip to content

Commit 35b0ef4

Browse files
authored
Nits rollup (#501)
Signed-off-by: Rhys Goodall <rhys.goodall@outlook.com>
1 parent 84986c1 commit 35b0ef4

12 files changed

Lines changed: 419 additions & 98 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ docs/reference/torch_sim.*
2828
*.hdf5
2929
*.traj
3030

31+
# ignore torch.save outputs
32+
*.pt
33+
3134
# coverage
3235
coverage.xml
3336
.coverage*

docs/user/reproducibility.md

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,34 @@ Because TorchSim runs batched simulations, all systems in a batch share a single
7979

8080
If strict reproducibility is required, keep your batching setup fixed.
8181

82-
### Serialising the RNG state
82+
### Serialising state for reproducible restarts
8383

84-
If you wish to be able to resume a session and ensure determinism you need to persist and reload the `torch.Generator` state. This can be done using `torch.save()` and `torch.Generator().set_state()`:
84+
To resume a simulation deterministically, you need to persist and reload the complete state, including the `torch.Generator` RNG state. The simplest approach is to save the full state dict with `torch.save()`:
8585

8686
```python
87+
from dataclasses import asdict
88+
from torch_sim.integrators import MDState
89+
8790
# save
91+
torch.save(asdict(state), "checkpoint.pt")
92+
93+
# restore (weights_only=False needed for torch.Generator in PyTorch 2.6+)
94+
restored = MDState(**torch.load("checkpoint.pt", weights_only=False))
95+
```
96+
97+
This captures positions, momenta, forces, energy, cell, and the `torch.Generator` in a single file. Since `torch.save()` uses pickle, the generator is serialised automatically.
98+
99+
> **Pickle caveat**: The `torch.Generator` object in the dict requires `weights_only=False` and may not unpickle across PyTorch versions. Because `weights_only=False` allows arbitrary objects to be unpickled, only load checkpoints from sources you trust. For portable checkpoints, save the tensors normally and extract the RNG state as a plain `uint8` tensor via `get_state()` — this loads with `weights_only=True` and is version-safe:
100+
101+
```python
102+
# save RNG state as a plain uint8 tensor (no pickle needed)
88103
rng_state = state.rng.get_state()
89104
torch.save(rng_state, "rng_state.pt")
90105

91106
# restore
92107
gen = torch.Generator(device=state.device)
93-
gen.set_state(torch.load("rng_state.pt"))
108+
gen.set_state(torch.load("rng_state.pt", weights_only=True))
94109
state.rng = gen
95110
```
111+
112+
See the [reproducible restart tutorial](../../examples/tutorials/reproducible_restart_tutorial.py) for a complete worked example.

examples/tutorials/hybrid_swap_tutorial.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ class HybridSwapMCState(SwapMCState, MDState):
107107
ts.SwapMCState._system_attributes | MDState._system_attributes # noqa: SLF001
108108
)
109109

110+
def __post_init__(self) -> None:
111+
"""Initialize HybridSwapMCState and ensure last_permutation is set."""
112+
super().__post_init__()
113+
110114

111115
# %% [markdown]
112116
"""
@@ -127,16 +131,8 @@ class HybridSwapMCState(SwapMCState, MDState):
127131
state.rng = 42
128132
md_state = ts.nvt_langevin_init(state=state, model=mace_model, kT=kT)
129133

130-
# Initialize swap Monte Carlo state
131-
swap_state = ts.swap_mc_init(state=md_state, model=mace_model)
132-
133-
# Create hybrid state combining both
134-
hybrid_state = HybridSwapMCState(
135-
**md_state.attributes,
136-
last_permutation=torch.arange(
137-
md_state.n_atoms, device=md_state.device, dtype=torch.long
138-
),
139-
)
134+
# Create hybrid state from MD state
135+
hybrid_state = HybridSwapMCState(**md_state.attributes)
140136

141137

142138
# %% [markdown]
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# %%
2+
# /// script
3+
# dependencies = [
4+
# "torch_sim_atomistic[io]"
5+
# ]
6+
# ///
7+
8+
9+
# %% [markdown]
10+
"""
11+
# Reproducible Restarts from Stopped Simulations
12+
13+
This tutorial demonstrates how to save and restore simulation state to enable
14+
reproducible restarts. We run 50 steps of MD, save the state (including RNG state),
15+
resume for another 50 steps, and verify the result is identical to 100 uninterrupted
16+
steps.
17+
18+
For stochastic integrators like Langevin dynamics, you must save the random number
19+
generator (RNG) state alongside positions, momenta, and other state variables.
20+
Without it, the stochastic noise will differ on restart and the trajectory will diverge.
21+
"""
22+
23+
# %% [markdown]
24+
"""
25+
## Setup
26+
"""
27+
28+
# %%
29+
from dataclasses import asdict
30+
from pathlib import Path
31+
32+
import torch
33+
import torch_sim as ts
34+
from ase.build import bulk
35+
from torch_sim.integrators import MDState
36+
from torch_sim.models.lennard_jones import LennardJonesModel
37+
38+
# Scratch directory that collects every file this tutorial writes
restart_dir = Path("restart_files")
restart_dir.mkdir(exist_ok=True)

# One seed drives both the global torch RNG and the SimState RNG below
seed = 42
torch.manual_seed(seed)

# Model and state are built with the same device/dtype pair
cpu_device = torch.device("cpu")
float_dtype = torch.float64

lj_model = LennardJonesModel(
    sigma=2.0, epsilon=0.1, device=cpu_device, dtype=float_dtype
)

si_atoms = bulk("Si", "diamond", a=5.43, cubic=True)

initial_state = ts.initialize_state(si_atoms, device=cpu_device, dtype=float_dtype)
# Seeding the state's own RNG is what makes the Langevin noise repeatable
initial_state.rng = seed

print(f"Initial state: {initial_state.n_atoms} atoms")
60+
61+
# %% [markdown]
62+
"""
63+
## Part 1: Run 50 Steps, Save State, Resume for 50 More
64+
65+
We save the complete state with `asdict()` + `torch.save()`. Since `torch.save()`
66+
uses pickle, the `torch.Generator` (RNG) is included automatically — no need to
67+
save it separately.
68+
69+
**PyTorch 2.6+**: You must pass `weights_only=False` to `torch.load()` when loading
70+
checkpoints that contain `torch.Generator` objects.
71+
"""
72+
73+
# %%
74+
# Run first 50 steps, recording a frame (with velocities) every 10 steps
trajectory_file_restart = str(restart_dir / "restart_trajectory.h5")
reporter_restart = ts.TrajectoryReporter(
    filenames=trajectory_file_restart,
    state_frequency=10,
    state_kwargs={"save_velocities": True},
)

# Integrate a clone so `initial_state` itself stays untouched.
# The reporter is closed even if integration raises, so the trajectory file
# handle is not left open when the cell is re-run.
try:
    state_after_50 = ts.integrate(
        system=initial_state.clone(),
        model=lj_model,
        integrator=ts.Integrator.nvt_langevin,
        n_steps=50,
        temperature=300,
        timestep=0.001,
        trajectory_reporter=reporter_restart,
    )
finally:
    reporter_restart.close()

# Save the complete state (including RNG) in one file
checkpoint_file = str(restart_dir / "checkpoint.pt")
torch.save(asdict(state_after_50), checkpoint_file)
print(f"Saved checkpoint after 50 steps to {checkpoint_file}")
97+
98+
# %% [markdown]
99+
"""
100+
Now restore the state and continue for another 50 steps:
101+
"""
102+
103+
# %%
104+
# Load checkpoint (weights_only=False needed for torch.Generator in PyTorch 2.6+).
# NOTE: weights_only=False unpickles arbitrary objects, so only load checkpoint
# files that you created yourself or otherwise trust.
loaded = torch.load(checkpoint_file, weights_only=False)
restored_state = MDState(**loaded)

# Verify RNG was restored: the reloaded generator must report the same state
# as the generator of the in-memory state that was checkpointed.
assert torch.equal(restored_state.rng.get_state(), state_after_50.rng.get_state())
print(f"Restored state: {restored_state.n_atoms} atoms, RNG matches ✓")

# Continue for another 50 steps (append to existing trajectory)
reporter_restart_continued = ts.TrajectoryReporter(
    filenames=trajectory_file_restart,
    state_frequency=10,
    state_kwargs={"save_velocities": True},
    trajectory_kwargs={"mode": "a"},
)

# Close the reporter even if integration raises so the file handle is released.
try:
    state_after_100_restart = ts.integrate(
        system=restored_state,
        model=lj_model,
        integrator=ts.Integrator.nvt_langevin,
        n_steps=50,
        temperature=300,
        timestep=0.001,
        trajectory_reporter=reporter_restart_continued,
    )
finally:
    reporter_restart_continued.close()
print("Completed restart simulation: 50 + 50 = 100 steps")
131+
132+
# %% [markdown]
133+
"""
134+
## Part 2: Run 100 Steps Continuously for Comparison
135+
"""
136+
137+
# %%
138+
# Reference run: same reporter settings as the restarted run, separate file
trajectory_file_continuous = str(restart_dir / "continuous_trajectory.h5")
reporter_continuous = ts.TrajectoryReporter(
    filenames=trajectory_file_continuous,
    state_frequency=10,
    state_kwargs={"save_velocities": True},
)

# Rebuild the starting state from scratch and seed it identically
initial_state_continuous = ts.initialize_state(
    si_atoms, device=torch.device("cpu"), dtype=torch.float64
)
initial_state_continuous.rng = seed

# Close the reporter even if integration raises so the file handle is released.
try:
    state_after_100_continuous = ts.integrate(
        system=initial_state_continuous,
        model=lj_model,
        integrator=ts.Integrator.nvt_langevin,
        n_steps=100,
        temperature=300,
        timestep=0.001,
        trajectory_reporter=reporter_continuous,
    )
finally:
    reporter_continuous.close()
print("Completed continuous simulation: 100 steps")
161+
162+
# %% [markdown]
163+
"""
164+
## Part 3: Compare Trajectories
165+
166+
Both runs started from the same initial state and seed. The restarted run saved and
167+
restored the RNG state at step 50. If everything is correct, the trajectories should
168+
match exactly:
169+
"""
170+
171+
# %%
172+
# Compare final RNG states
rng_match = torch.equal(
    state_after_100_restart.rng.get_state(),
    state_after_100_continuous.rng.get_state(),
)
print(f"Final RNG states match: {rng_match}")

# Load the recorded frames of both runs
with ts.TorchSimTrajectory(trajectory_file_restart, mode="r") as traj_restart:
    positions_restart = traj_restart.get_array("positions")
    steps_restart = traj_restart.get_steps("positions")
    velocities_restart = traj_restart.get_array("velocities")

with ts.TorchSimTrajectory(trajectory_file_continuous, mode="r") as traj_continuous:
    positions_continuous = traj_continuous.get_array("positions")
    steps_continuous = traj_continuous.get_steps("positions")
    velocities_continuous = traj_continuous.get_array("velocities")


def _first_frame_by_step(steps):
    """Map each recorded step number to the index of the first frame stored for it."""
    lookup = {}
    for frame_idx, step_number in enumerate(steps.tolist()):
        # setdefault keeps the first occurrence, like list.index() would
        lookup.setdefault(step_number, frame_idx)
    return lookup


# Build the step -> frame lookups once instead of rescanning the step list
# for every compared frame.
frame_of_restart = _first_frame_by_step(steps_restart)
frame_of_continuous = _first_frame_by_step(steps_continuous)

matching_steps = sorted(frame_of_restart.keys() & frame_of_continuous.keys())
print(f"Comparing {len(matching_steps)} frames at steps: {matching_steps}")

# Without any common frame the loop below would "pass" without checking anything.
assert matching_steps, "No common steps found between the two trajectories"

tolerance = {"atol": 1e-10, "rtol": 1e-10}
max_pos_diff = 0.0
max_vel_diff = 0.0
all_match = True

for step in matching_steps:
    idx_restart = frame_of_restart[step]
    idx_continuous = frame_of_continuous[step]

    pos_restart = torch.tensor(positions_restart[idx_restart])
    pos_continuous = torch.tensor(positions_continuous[idx_continuous])
    vel_restart = torch.tensor(velocities_restart[idx_restart])
    vel_continuous = torch.tensor(velocities_continuous[idx_continuous])

    # Track the worst-case deviation across all frames for the final report
    pos_diff = torch.max(torch.abs(pos_restart - pos_continuous)).item()
    vel_diff = torch.max(torch.abs(vel_restart - vel_continuous)).item()
    max_pos_diff = max(max_pos_diff, pos_diff)
    max_vel_diff = max(max_vel_diff, vel_diff)

    if not torch.allclose(pos_restart, pos_continuous, **tolerance):
        print(f"  Step {step}: Position mismatch! Max diff: {pos_diff:.2e}")
        all_match = False
    if not torch.allclose(vel_restart, vel_continuous, **tolerance):
        print(f"  Step {step}: Velocity mismatch! Max diff: {vel_diff:.2e}")
        all_match = False

assert all_match, (
    "Restarted and continuous trajectories differ! "
    f"Max position difference: {max_pos_diff:.2e}, max velocity difference: {max_vel_diff:.2e}"
)
print("\n✓ Restarted and continuous trajectories match exactly.")
223+
224+
# %% [markdown]
225+
"""
226+
## Key Takeaways
227+
228+
1. **Save with `asdict()` + `torch.save()`**: This captures everything — positions,
229+
momenta, forces, energy, cell, and the `torch.Generator` RNG state — in a single
230+
checkpoint file.
231+
232+
2. **Restore with `MDState(**torch.load(...))`**: The `torch.Generator` is unpickled
233+
automatically, so the RNG state is restored without any extra steps.
234+
235+
3. **Use append mode** (`trajectory_kwargs={"mode": "a"}`) in `TrajectoryReporter`
236+
to continue an existing trajectory file.
237+
238+
4. **Pickle caveats**: The `torch.Generator` object in the checkpoint requires pickle
239+
(`weights_only=False`) and may not load across PyTorch versions. For portable
240+
checkpoints, save tensors normally and use `state.rng.get_state()` to extract the
241+
RNG state as a plain `uint8` tensor that works with `weights_only=True`.
242+
243+
5. **Verify**: Always compare restarted trajectories to continuous runs.
244+
"""
245+
246+
# %%
247+
# Cleanup
248+
import shutil
249+
250+
shutil.rmtree(restart_dir)
251+
print(f"Cleaned up {restart_dir}/")

tests/test_monte_carlo.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,26 @@ def test_monte_carlo_integration(
191191
assert torch.all(orig_counts == result_counts)
192192

193193

194+
def test_swap_mc_state_default_last_permutation(
    batched_diverse_state: ts.SimState,
) -> None:
    """Test that SwapMCState initializes last_permutation to identity if not provided.

    The state is constructed without a ``last_permutation`` argument; afterwards
    the attribute must be set and equal to ``arange(n_atoms)``.
    """
    from torch_sim.monte_carlo import SwapMCState

    # Allocate helper tensors alongside the fixture's own tensors so the test
    # does not mix devices/dtypes when the suite runs on a non-default device.
    reference = batched_diverse_state.positions

    state = SwapMCState(
        positions=batched_diverse_state.positions,
        masses=batched_diverse_state.masses,
        cell=batched_diverse_state.cell,
        pbc=batched_diverse_state.pbc,
        atomic_numbers=batched_diverse_state.atomic_numbers,
        system_idx=batched_diverse_state.system_idx,
        energy=torch.zeros(
            batched_diverse_state.n_systems,
            device=reference.device,
            dtype=reference.dtype,
        ),
    )

    assert state.last_permutation is not None
    expected_identity = torch.arange(
        batched_diverse_state.n_atoms, device=reference.device
    )
    assert torch.equal(state.last_permutation, expected_identity)
212+
213+
194214
def test_swap_mc_state_attributes():
195215
"""Test SwapMCState class structure and inheritance."""
196216
from torch_sim.state import SimState

tests/test_state.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,35 @@ def test_initialize_state_from_state(ar_supercell_sim_state: SimState) -> None:
348348
assert state.cell.shape == ar_supercell_sim_state.cell.shape
349349

350350

351+
def test_initialize_state_from_list_of_states_with_multiple_systems(
    si_double_sim_state: SimState, fe_supercell_sim_state: SimState
) -> None:
    """Test initialize_state with list of states that have n_systems > 1.

    A two-system state and a single-system state are concatenated; the result
    must expose three systems with contiguous system indices, and indexing it
    must give back the positions of each input system.
    """
    # Input states are allowed to carry more than one system each
    concatenated = ts.initialize_state([si_double_sim_state, fe_supercell_sim_state])

    # Should have 3 systems total (2 from si_double + 1 from fe)
    assert concatenated.n_systems == 3
    assert concatenated.cell.shape[0] == 3

    # The Fe atoms come last and therefore all belong to system index 2
    si_system_idx = si_double_sim_state.system_idx
    fe_system_idx = torch.full(
        (fe_supercell_sim_state.n_atoms,),
        2,
        dtype=si_system_idx.dtype,
        device=fe_supercell_sim_state.device,
    )
    expected_system_indices = torch.cat([si_system_idx, fe_system_idx])
    # torch.equal also checks the shape, so a broadcastable but wrong-length
    # index tensor cannot slip through an elementwise comparison.
    assert torch.equal(concatenated.system_idx, expected_system_indices)

    # Verify we can slice back to original states
    assert torch.allclose(concatenated[0].positions, si_double_sim_state[0].positions)
    assert torch.allclose(concatenated[1].positions, si_double_sim_state[1].positions)
    assert torch.allclose(concatenated[2].positions, fe_supercell_sim_state.positions)
378+
379+
351380
def test_initialize_state_from_atoms(si_atoms: "Atoms") -> None:
352381
"""Test conversion from ASE Atoms to SimState."""
353382
state = ts.initialize_state([si_atoms], DEVICE, torch.float64)

0 commit comments

Comments
 (0)