Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies = [
[project.optional-dependencies]
test = [
"torch-sim-atomistic[io,symmetry,vesin]",
"physical-validation>=1.0.5",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This package seems unsupported

"platformdirs>=4.0.0",
"psutil>=7.0.0",
"pymatgen>=2025.6.14",
Expand Down Expand Up @@ -139,8 +140,11 @@ check-filenames = true
ignore-words-list = ["convertor"] # codespell:ignore convertor

[tool.pytest.ini_options]
addopts = ["-p no:warnings"]
addopts = ["-p no:warnings", "-m not physical_validation"]
testpaths = ["tests"]
markers = [
"physical_validation: long-running physical validation tests (run with: pytest -m physical_validation)",
]

[tool.uv]
# make these dependencies mutually exclusive since they use incompatible e3nn versions
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@

torch.set_num_threads(4)


def pytest_addoption(parser):
parser.addoption(
"--validation-plots",
action="store_true",
default=False,
help="Save physical validation plots to tests/physical_validation_data/plots/",
)
parser.addoption(
"--clean-validation-data",
action="store_true",
default=False,
help="Delete saved physical validation data before running tests",
)

DEVICE = torch.device("cpu")
DTYPE = torch.float64

Expand Down
78 changes: 69 additions & 9 deletions tests/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,66 @@ def test_npt_langevin(
assert pos_diff > 0.0001 # Systems should remain separated


def test_npt_langevin_strain(
ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel
) -> None:
n_steps = 200
dt = torch.tensor(0.001, dtype=DTYPE) * MetalUnits.time
kT = torch.tensor(300.0, dtype=DTYPE) * MetalUnits.temperature
external_pressure = torch.tensor(10.0, dtype=DTYPE) * MetalUnits.pressure
alpha = 1 * dt
cell_alpha = 10 * dt
b_tau = 30 * dt

ar_double_sim_state.rng = 42
state = ts.npt_langevin_strain_init(
state=ar_double_sim_state,
model=lj_model,
dt=dt,
kT=kT,
alpha=alpha,
cell_alpha=cell_alpha,
b_tau=b_tau,
)

# Check strain state shape
assert state.cell_positions.shape == (2,) # scalar strain per system

energies = []
temperatures = []
for _step in range(n_steps):
state = ts.npt_langevin_strain_step(
state=state,
model=lj_model,
dt=dt,
kT=kT,
external_pressure=external_pressure,
)

temp = ts.calc_kT(
masses=state.masses, momenta=state.momenta, system_idx=state.system_idx
)
energies.append(state.energy)
temperatures.append(temp / MetalUnits.temperature)

temperatures_tensor = torch.stack(temperatures)
energies_tensor = torch.stack(energies)
energies_list = [t.tolist() for t in energies_tensor.T]

assert len(energies_list[0]) == n_steps

mean_temps = torch.mean(temperatures_tensor, dim=0)
for mean_temp in mean_temps:
assert abs(mean_temp - kT.item() / MetalUnits.temperature) < 150.0

for traj in energies_list:
energy_std = torch.tensor(traj).std()
assert energy_std < 1.0

# Cell reconstruction is consistent
assert torch.allclose(state.cell, state.current_cell)


def test_npt_langevin_multi_kt(
ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel
):
Expand Down Expand Up @@ -339,7 +399,7 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
assert torch.allclose(
temperatures_tensor[-1],
torch.tensor([290.3553, 289.9699], dtype=dtype),
torch.tensor([305.6400, 305.4556], dtype=dtype),
)

energies_tensor = torch.stack(energies)
Expand Down Expand Up @@ -728,7 +788,7 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
assert torch.allclose(
temperatures_tensor[-1],
torch.tensor([287.5729, 287.1330], dtype=dtype),
torch.tensor([283.1162, 313.1624], dtype=dtype),
)

energies_tensor = torch.stack(energies)
Expand Down Expand Up @@ -1023,19 +1083,19 @@ def test_compute_cell_force_atoms_per_system():
atomic_numbers=torch.ones(72, dtype=torch.long),
stress=torch.zeros((2, 3, 3)),
reference_cell=torch.eye(3).repeat(2, 1, 1),
cell_positions=torch.ones((2, 3, 3)),
cell_velocities=torch.zeros((2, 3, 3)),
cell_positions=torch.zeros(2, 3),
cell_velocities=torch.zeros(2, 3),
cell_masses=torch.ones(2),
alpha=torch.ones(2),
cell_alpha=torch.ones(2),
b_tau=torch.ones(2),
)

# Get forces and compare ratio
cell_force = _compute_cell_force(state, torch.tensor(0.0), torch.tensor([1.0, 1.0]))
force_ratio = (
torch.diagonal(cell_force[1]).mean() / torch.diagonal(cell_force[0]).mean()
)
# Get forces and compare ratio (per-dimension force)
P_ext = torch.zeros(2, 3)
cell_force = _compute_cell_force(state, P_ext, torch.tensor([1.0, 1.0]))
# Check the first dimension's force ratio
force_ratio = cell_force[1, 0] / cell_force[0, 0]

# Force ratio should match atom ratio (8:1) with the fix
assert abs(force_ratio - 8.0) / 8.0 < 0.1
Expand Down
14 changes: 7 additions & 7 deletions tests/test_nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None:

result = build_triplets(edge_index, n_atoms)

assert result["trip_in"].device == dev
assert result["trip_out"].device == dev
assert result["center_atom"].device == dev
assert result["trip_in"].device.type == dev.type
assert result["trip_out"].device.type == dev.type
assert result["center_atom"].device.type == dev.type


@pytest.mark.parametrize(
Expand All @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None:
internal_cell_offsets,
)

assert result["quad_c_to_a_edge"].device == dev
assert result["quad_d_to_b_trip_idx"].device == dev
assert result["d_to_b_edge"].device == dev
assert result["c_to_a_edge"].device == dev
assert result["quad_c_to_a_edge"].device.type == dev.type
assert result["quad_d_to_b_trip_idx"].device.type == dev.type
assert result["d_to_b_edge"].device.type == dev.type
assert result["c_to_a_edge"].device.type == dev.type


def test_build_triplets_jit_script() -> None:
Expand Down
Loading
Loading