# %%
# /// script
# dependencies = [
#     "torch_sim_atomistic[io]"
# ]
# ///


# %% [markdown]
"""
# Reproducible Restarts from Stopped Simulations

This tutorial demonstrates how to save and restore simulation state to enable
reproducible restarts. We run 50 steps of MD, save the state (including RNG state),
resume for another 50 steps, and verify the result is identical to 100 uninterrupted
steps.

For stochastic integrators like Langevin dynamics, you must save the random number
generator (RNG) state alongside positions, momenta, and other state variables.
Without it, the stochastic noise will differ on restart and the trajectory will diverge.
"""
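
# %% [markdown]
"""
Before the full workflow, here is a minimal sketch, independent of torch-sim, of the
core mechanism: restoring a `torch.Generator`'s saved state reproduces the subsequent
random stream exactly. (`torch` is imported here so the cell is self-contained; the
Setup section imports it again below.)
"""

# %%
import torch

gen = torch.Generator()
gen.manual_seed(0)
saved_rng_state = gen.get_state()  # snapshot of the RNG state: a plain uint8 tensor
first_draw = torch.randn(3, generator=gen)
gen.set_state(saved_rng_state)  # rewind the generator to the snapshot
second_draw = torch.randn(3, generator=gen)
assert torch.equal(first_draw, second_draw)  # identical noise after restoring
print("Generator state round-trip reproduces the random stream ✓")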

# %% [markdown]
"""
## Setup
"""

# %%
from dataclasses import asdict
from pathlib import Path

import torch
import torch_sim as ts
from ase.build import bulk
from torch_sim.integrators import MDState
from torch_sim.models.lennard_jones import LennardJonesModel

# All generated files go in this directory
restart_dir = Path("restart_files")
restart_dir.mkdir(exist_ok=True)

seed = 42
torch.manual_seed(seed)

lj_model = LennardJonesModel(
    sigma=2.0,
    epsilon=0.1,
    device=torch.device("cpu"),
    dtype=torch.float64,
)

si_atoms = bulk("Si", "diamond", a=5.43, cubic=True)

initial_state = ts.initialize_state(
    si_atoms, device=torch.device("cpu"), dtype=torch.float64
)
initial_state.rng = seed  # seed the SimState RNG for reproducibility

print(f"Initial state: {initial_state.n_atoms} atoms")

# %% [markdown]
"""
## Part 1: Run 50 Steps, Save State, Resume for 50 More

We save the complete state with `asdict()` + `torch.save()`. Since `torch.save()`
uses pickle, the `torch.Generator` (RNG) is included automatically — no need to
save it separately.

**PyTorch 2.6+**: You must pass `weights_only=False` to `torch.load()` when loading
checkpoints that contain `torch.Generator` objects; the optional sanity-check cell
after the first checkpoint below demonstrates this.
"""

# %%
# Run first 50 steps
trajectory_file_restart = str(restart_dir / "restart_trajectory.h5")
reporter_restart = ts.TrajectoryReporter(
    filenames=trajectory_file_restart,
    state_frequency=10,
    state_kwargs={"save_velocities": True},
)

state_after_50 = ts.integrate(
    system=initial_state.clone(),
    model=lj_model,
    integrator=ts.Integrator.nvt_langevin,
    n_steps=50,
    temperature=300,
    timestep=0.001,
    trajectory_reporter=reporter_restart,
)
reporter_restart.close()

# Save the complete state (including RNG) in one file
checkpoint_file = str(restart_dir / "checkpoint.pt")
torch.save(asdict(state_after_50), checkpoint_file)
print(f"Saved checkpoint after 50 steps to {checkpoint_file}")

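# %% [markdown]
"""
Two optional sanity checks before restoring (neither is required for the restart
itself): list what went into the checkpoint, and confirm the PyTorch 2.6+
`weights_only` caveat from above. The printed field names are simply whatever
`MDState`'s dataclass fields happen to be in your torch-sim version.
"""

# %%
# Inspect the checkpoint's contents
for key, value in asdict(state_after_50).items():
    print(f"  {key}: {type(value).__name__}")

# On PyTorch >= 2.6 the default weights_only=True refuses to unpickle the
# torch.Generator inside this checkpoint; guarded so the cell runs on any version.
try:
    torch.load(checkpoint_file)
    print("Default torch.load succeeded (PyTorch < 2.6)")
except Exception as exc:
    print(f"Default torch.load raised {type(exc).__name__}, as expected on 2.6+")
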
# %% [markdown]
"""
Now restore the state and continue for another 50 steps:
"""

# %%
# Load checkpoint (weights_only=False needed for torch.Generator in PyTorch 2.6+)
loaded = torch.load(checkpoint_file, weights_only=False)
restored_state = MDState(**loaded)

# Verify RNG was restored
assert torch.equal(restored_state.rng.get_state(), state_after_50.rng.get_state())
print(f"Restored state: {restored_state.n_atoms} atoms, RNG matches ✓")

# Continue for another 50 steps (append to existing trajectory)
reporter_restart_continued = ts.TrajectoryReporter(
    filenames=trajectory_file_restart,
    state_frequency=10,
    state_kwargs={"save_velocities": True},
    trajectory_kwargs={"mode": "a"},
)

state_after_100_restart = ts.integrate(
    system=restored_state,
    model=lj_model,
    integrator=ts.Integrator.nvt_langevin,
    n_steps=50,
    temperature=300,
    timestep=0.001,
    trajectory_reporter=reporter_restart_continued,
)
reporter_restart_continued.close()
print("Completed restart simulation: 50 + 50 = 100 steps")

# %% [markdown]
"""
## Part 2: Run 100 Steps Continuously for Comparison
"""

# %%
trajectory_file_continuous = str(restart_dir / "continuous_trajectory.h5")
reporter_continuous = ts.TrajectoryReporter(
    filenames=trajectory_file_continuous,
    state_frequency=10,
    state_kwargs={"save_velocities": True},
)

initial_state_continuous = ts.initialize_state(
    si_atoms, device=torch.device("cpu"), dtype=torch.float64
)
initial_state_continuous.rng = seed

state_after_100_continuous = ts.integrate(
    system=initial_state_continuous,
    model=lj_model,
    integrator=ts.Integrator.nvt_langevin,
    n_steps=100,
    temperature=300,
    timestep=0.001,
    trajectory_reporter=reporter_continuous,
)
reporter_continuous.close()
print("Completed continuous simulation: 100 steps")

# %% [markdown]
"""
## Part 3: Compare Trajectories

Both runs started from the same initial state and seed. The restarted run saved and
restored the RNG state at step 50. If everything is correct, the trajectories should
match exactly:
"""

# %%
# Compare final RNG states
rng_match = torch.equal(
    state_after_100_restart.rng.get_state(),
    state_after_100_continuous.rng.get_state(),
)
print(f"Final RNG states match: {rng_match}")

# Compare trajectories frame by frame
with ts.TorchSimTrajectory(trajectory_file_restart, mode="r") as traj_restart:
    positions_restart = traj_restart.get_array("positions")
    steps_restart = traj_restart.get_steps("positions")
    velocities_restart = traj_restart.get_array("velocities")

with ts.TorchSimTrajectory(trajectory_file_continuous, mode="r") as traj_continuous:
    positions_continuous = traj_continuous.get_array("positions")
    steps_continuous = traj_continuous.get_steps("positions")
    velocities_continuous = traj_continuous.get_array("velocities")

matching_steps = sorted(set(steps_restart) & set(steps_continuous))
print(f"Comparing {len(matching_steps)} frames at steps: {matching_steps}")

max_pos_diff = 0.0
max_vel_diff = 0.0
all_match = True

for step in matching_steps:
    idx_restart = steps_restart.tolist().index(step)
    idx_continuous = steps_continuous.tolist().index(step)

    pos_restart = torch.tensor(positions_restart[idx_restart])
    pos_continuous = torch.tensor(positions_continuous[idx_continuous])
    vel_restart = torch.tensor(velocities_restart[idx_restart])
    vel_continuous = torch.tensor(velocities_continuous[idx_continuous])

    pos_diff = torch.max(torch.abs(pos_restart - pos_continuous)).item()
    vel_diff = torch.max(torch.abs(vel_restart - vel_continuous)).item()
    max_pos_diff = max(max_pos_diff, pos_diff)
    max_vel_diff = max(max_vel_diff, vel_diff)

    if not torch.allclose(pos_restart, pos_continuous, atol=1e-10, rtol=1e-10):
        print(f"  Step {step}: Position mismatch! Max diff: {pos_diff:.2e}")
        all_match = False
    if not torch.allclose(vel_restart, vel_continuous, atol=1e-10, rtol=1e-10):
        print(f"  Step {step}: Velocity mismatch! Max diff: {vel_diff:.2e}")
        all_match = False

assert all_match, (
    "Restarted and continuous trajectories differ! "
    f"Max position difference: {max_pos_diff:.2e}, "
    f"max velocity difference: {max_vel_diff:.2e}"
)
print("\n✓ Restarted and continuous trajectories match exactly.")

# %% [markdown]
"""
## Key Takeaways

1. **Save with `asdict()` + `torch.save()`**: This captures everything — positions,
   momenta, forces, energy, cell, and the `torch.Generator` RNG state — in a single
   checkpoint file.

2. **Restore with `MDState(**torch.load(...))`**: The `torch.Generator` is unpickled
   automatically, so the RNG state is restored without any extra steps.

3. **Use append mode** (`trajectory_kwargs={"mode": "a"}`) in `TrajectoryReporter`
   to continue an existing trajectory file.

4. **Pickle caveats**: The `torch.Generator` object in the checkpoint requires pickle
   (`weights_only=False`) and may not load across PyTorch versions. For portable
   checkpoints, save tensors normally and use `state.rng.get_state()` to extract the
   RNG state as a plain `uint8` tensor that works with `weights_only=True`; see the
   sketch after this list.

5. **Verify**: Always compare restarted trajectories to continuous runs.
"""
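
# %% [markdown]
"""
A minimal sketch of the portable-checkpoint pattern from takeaway 4, assuming (as in
this tutorial) that every other field in `asdict(state)` is a tensor or a plain
Python value. The helper names `save_portable_checkpoint` and
`load_portable_checkpoint` are illustrative, not part of the torch-sim API.
"""

# %%
def save_portable_checkpoint(state: MDState, path: str) -> None:
    """Save only tensors/primitives: swap the Generator for its uint8 state tensor."""
    payload = asdict(state)
    payload["rng"] = state.rng.get_state()
    torch.save(payload, path)


def load_portable_checkpoint(path: str) -> MDState:
    """Rebuild the MDState, re-wrapping the saved state tensor in a fresh Generator."""
    payload = torch.load(path, weights_only=True)  # no pickled objects needed
    gen = torch.Generator()
    gen.set_state(payload["rng"])
    payload["rng"] = gen
    return MDState(**payload)


portable_file = str(restart_dir / "portable_checkpoint.pt")
save_portable_checkpoint(state_after_100_restart, portable_file)
roundtripped = load_portable_checkpoint(portable_file)
assert torch.equal(
    roundtripped.rng.get_state(), state_after_100_restart.rng.get_state()
)
print("Portable checkpoint round-trip preserves the RNG state ✓")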

# %%
# Cleanup: remove every file generated by this tutorial
import shutil

shutil.rmtree(restart_dir)
print(f"Cleaned up {restart_dir}/")