Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "torch-sim-atomistic"
version = "0.5.2"
version = "0.6.0"
description = "A pytorch toolkit for calculating material properties using MLIPs"
authors = [
{ name = "Abhijeet Gangan", email = "abhijeetgangan@g.ucla.edu" },
Expand Down
24 changes: 23 additions & 1 deletion tests/models/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Pytest fixtures and test factories for model testing."""

from __future__ import annotations

import typing

import pytest
Expand All @@ -10,7 +12,10 @@


if typing.TYPE_CHECKING:
from collections.abc import Callable, Sequence

from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState


def make_model_calculator_consistency_test(
Expand Down Expand Up @@ -81,22 +86,39 @@ def make_validate_model_outputs_test(
dtype: torch.dtype = DTYPE,
*,
check_detached: bool = True,
state_modifiers: Sequence[Callable[[SimState], SimState]] = (),
):
"""Factory function to create model output validation tests.

Runs ``validate_model_outputs`` once with no modifier (baseline), then
once more for each entry in *state_modifiers* so that every modifier
gets a full, independent validation pass.

Args:
model_fixture_name: Name of the model fixture to validate
device: Device to run validation on
dtype: Data type to use for validation
check_detached: Whether to assert output tensors are detached from the
autograd graph (skipped for models with ``retain_graph=True``).
state_modifiers: Each callable receives a ``SimState`` and returns a
(possibly new) ``SimState``. The full validation suite is run
once per modifier so that different input edge-cases are
exercised independently.
"""
from torch_sim.models.interface import validate_model_outputs

def test_model_output_validation(request: pytest.FixtureRequest) -> None:
"""Test that a model implementation follows the ModelInterface contract."""
model: ModelInterface = request.getfixturevalue(model_fixture_name)
validate_model_outputs(model, device, dtype, check_detached=check_detached)
modifiers = state_modifiers or [None]
for modifier in modifiers:
validate_model_outputs(
model,
device,
dtype,
check_detached=check_detached,
state_modifier=modifier,
)

test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation"
return test_model_output_validation
10 changes: 6 additions & 4 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,12 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None:
mol.info["charge"] = charge
mol.info["spin"] = spin

# Convert to SimState (should extract charge/spin)
state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE)

# Verify charge/spin were extracted correctly
state = ts.io.atoms_to_state(
[mol],
device=DEVICE,
dtype=DTYPE,
system_extras={"charge": "charge", "spin": "spin"},
)
assert state.charge is not None
assert state.spin is not None
assert state.charge[0].item() == charge
Expand Down
16 changes: 8 additions & 8 deletions tests/test_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_get_elementary_deformations_strain_consistency(
n_deform=n_deform,
max_strain_normal=max_strain_normal,
max_strain_shear=max_strain_shear,
bravais_type=BravaisType.triclinic, # Test all axes
bravais_type=BravaisType.TRICLINIC, # Test all axes
)

# Should generate deformations for all 6 axes (triclinic)
Expand Down Expand Up @@ -271,12 +271,12 @@ def mace_model() -> MaceModel:
@pytest.mark.parametrize(
("sim_state_name", "expected_bravais_type", "atol"),
[
("cu_sim_state", BravaisType.cubic, 2e-1),
("mg_sim_state", BravaisType.hexagonal, 5e-1),
("sb_sim_state", BravaisType.trigonal, 5e-1),
("tio2_sim_state", BravaisType.tetragonal, 5e-1),
("ga_sim_state", BravaisType.orthorhombic, 5e-1),
("niti_sim_state", BravaisType.monoclinic, 5e-1),
("cu_sim_state", BravaisType.CUBIC, 2e-1),
("mg_sim_state", BravaisType.HEXAGONAL, 5e-1),
("sb_sim_state", BravaisType.TRIGONAL, 5e-1),
("tio2_sim_state", BravaisType.TETRAGONAL, 5e-1),
("ga_sim_state", BravaisType.ORTHORHOMBIC, 5e-1),
("niti_sim_state", BravaisType.MONOCLINIC, 5e-1),
],
)
def test_elastic_tensor_symmetries(
Expand Down Expand Up @@ -340,7 +340,7 @@ def test_elastic_tensor_symmetries(
)
C_triclinic = (
calculate_elastic_tensor(
state=state, model=model, bravais_type=BravaisType.triclinic
state=state, model=model, bravais_type=BravaisType.TRICLINIC
)
* UnitConversion.eV_per_Ang3_to_GPa
)
Expand Down
110 changes: 110 additions & 0 deletions tests/test_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import torch

import torch_sim as ts


class TestExtras:
def test_system_extras_construction(self):
"""Extras can be passed at construction time."""
field = torch.randn(1, 3)
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
external_E_field=field,
)
assert torch.equal(state.external_E_field, field)

def test_atom_extras_construction(self):
"""Per-atom extras work at construction time."""
tags = torch.tensor([1.0, 2.0])
state = ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_atom_extras={"tags": tags},
)
assert torch.equal(state.tags, tags)

def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState):
with pytest.raises(AttributeError, match="nonexistent_key"):
_ = cu_sim_state.nonexistent_key

def test_post_init_validation_rejects_bad_shape(self):
with pytest.raises(ValueError, match="leading dim must be n_systems"):
ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"bad": torch.randn(5, 3)},
)

def test_construction_extras_cannot_shadow(self):
with pytest.raises(ValueError, match="shadows an existing attribute"):
ts.SimState(
positions=torch.zeros(2, 3),
masses=torch.ones(2),
cell=torch.eye(3).unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([1, 1], dtype=torch.int),
_system_extras={"cell": torch.zeros(1, 3)},
)

def test_store_model_extras_canonical_keys_not_stored(
self, si_double_sim_state: ts.SimState
):
"""Canonical keys (energy, forces, stress) must not land in extras."""
state = si_double_sim_state.clone()
state.store_model_extras(
{
"energy": torch.randn(state.n_systems),
"forces": torch.randn(state.n_atoms, 3),
"stress": torch.randn(state.n_systems, 3, 3),
}
)
for key in ("energy", "forces", "stress"):
assert key not in state._system_extras # noqa: SLF001
assert key not in state._atom_extras # noqa: SLF001

def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState):
"""Tensors with leading dim == n_systems go into system_extras."""
state = si_double_sim_state.clone()
dipole = torch.randn(state.n_systems, 3)
state.store_model_extras(
{"energy": torch.randn(state.n_systems), "dipole": dipole}
)
assert torch.equal(state.dipole, dipole)

def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState):
"""Tensors with leading dim == n_atoms go into atom_extras."""
state = si_double_sim_state.clone()
charges = torch.randn(state.n_atoms)
density = torch.randn(state.n_atoms, 8)
state.store_model_extras(
{
"energy": torch.randn(state.n_systems),
"charges": charges,
"density_coefficients": density,
}
)
assert torch.equal(state.charges, charges)
assert state.density_coefficients.shape == (state.n_atoms, 8)

def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState):
"""0-d tensors and non-Tensor values are silently ignored."""
state = si_double_sim_state.clone()
state.store_model_extras(
{
"scalar": torch.tensor(3.14),
"string": "not a tensor",
}
)
assert not state.has_extras("scalar")
assert not state.has_extras("string")
Loading
Loading