diff --git a/pyproject.toml b/pyproject.toml index 9ebc3752d..f4b9ab60b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "torch-sim-atomistic" -version = "0.5.2" +version = "0.6.0" description = "A pytorch toolkit for calculating material properties using MLIPs" authors = [ { name = "Abhijeet Gangan", email = "abhijeetgangan@g.ucla.edu" }, diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 4c78118d4..0448b9165 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -1,5 +1,7 @@ """Pytest fixtures and test factories for model testing.""" +from __future__ import annotations + import typing import pytest @@ -10,7 +12,10 @@ if typing.TYPE_CHECKING: + from collections.abc import Callable, Sequence + from torch_sim.models.interface import ModelInterface + from torch_sim.state import SimState def make_model_calculator_consistency_test( @@ -81,22 +86,39 @@ def make_validate_model_outputs_test( dtype: torch.dtype = DTYPE, *, check_detached: bool = True, + state_modifiers: Sequence[Callable[[SimState], SimState]] = (), ): """Factory function to create model output validation tests. + Runs ``validate_model_outputs`` once with no modifier (baseline), then + once more for each entry in *state_modifiers* so that every modifier + gets a full, independent validation pass. + Args: model_fixture_name: Name of the model fixture to validate device: Device to run validation on dtype: Data type to use for validation check_detached: Whether to assert output tensors are detached from the autograd graph (skipped for models with ``retain_graph=True``). + state_modifiers: Each callable receives a ``SimState`` and returns a + (possibly new) ``SimState``. The full validation suite is run + once per modifier so that different input edge-cases are + exercised independently. """ from torch_sim.models.interface import validate_model_outputs def test_model_output_validation(request: pytest.FixtureRequest) -> None: """Test that a model implementation follows the ModelInterface contract.""" model: ModelInterface = request.getfixturevalue(model_fixture_name) - validate_model_outputs(model, device, dtype, check_detached=check_detached) + modifiers = state_modifiers or [None] + for modifier in modifiers: + validate_model_outputs( + model, + device, + dtype, + check_detached=check_detached, + state_modifier=modifier, + ) test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" return test_model_output_validation diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index d259e2ec1..25fa1ba38 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -263,10 +263,12 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None: mol.info["charge"] = charge mol.info["spin"] = spin - # Convert to SimState (should extract charge/spin) - state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE) - - # Verify charge/spin were extracted correctly + state = ts.io.atoms_to_state( + [mol], + device=DEVICE, + dtype=DTYPE, + system_extras={"charge": "charge", "spin": "spin"}, + ) assert state.charge is not None assert state.spin is not None assert state.charge[0].item() == charge diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 6ad6af76f..2ac174410 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -221,7 +221,7 @@ def test_get_elementary_deformations_strain_consistency( n_deform=n_deform, max_strain_normal=max_strain_normal, max_strain_shear=max_strain_shear, - bravais_type=BravaisType.triclinic, # Test all axes + bravais_type=BravaisType.TRICLINIC, # Test all axes ) # Should generate deformations for all 6 axes (triclinic) @@ -271,12 +271,12 @@ def mace_model() -> MaceModel: @pytest.mark.parametrize( ("sim_state_name", "expected_bravais_type", "atol"), [ - ("cu_sim_state", BravaisType.cubic, 2e-1), - ("mg_sim_state", BravaisType.hexagonal, 5e-1), - ("sb_sim_state", BravaisType.trigonal, 5e-1), - ("tio2_sim_state", BravaisType.tetragonal, 5e-1), - ("ga_sim_state", BravaisType.orthorhombic, 5e-1), - ("niti_sim_state", BravaisType.monoclinic, 5e-1), + ("cu_sim_state", BravaisType.CUBIC, 2e-1), + ("mg_sim_state", BravaisType.HEXAGONAL, 5e-1), + ("sb_sim_state", BravaisType.TRIGONAL, 5e-1), + ("tio2_sim_state", BravaisType.TETRAGONAL, 5e-1), + ("ga_sim_state", BravaisType.ORTHORHOMBIC, 5e-1), + ("niti_sim_state", BravaisType.MONOCLINIC, 5e-1), ], ) def test_elastic_tensor_symmetries( @@ -340,7 +340,7 @@ def test_elastic_tensor_symmetries( ) C_triclinic = ( calculate_elastic_tensor( - state=state, model=model, bravais_type=BravaisType.triclinic + state=state, model=model, bravais_type=BravaisType.TRICLINIC ) * UnitConversion.eV_per_Ang3_to_GPa ) diff --git a/tests/test_extras.py b/tests/test_extras.py new file mode 100644 index 000000000..3bc36b3df --- /dev/null +++ b/tests/test_extras.py @@ -0,0 +1,110 @@ +import pytest +import torch + +import torch_sim as ts + + +class TestExtras: + def test_system_extras_construction(self): + """Extras can be passed at construction time.""" + field = torch.randn(1, 3) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + external_E_field=field, + ) + assert torch.equal(state.external_E_field, field) + + def test_atom_extras_construction(self): + """Per-atom extras work at construction time.""" + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + assert torch.equal(state.tags, tags) + + def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState): + with pytest.raises(AttributeError, match="nonexistent_key"): + _ = cu_sim_state.nonexistent_key + + def test_post_init_validation_rejects_bad_shape(self): + with pytest.raises(ValueError, match="leading dim must be n_systems"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"bad": torch.randn(5, 3)}, + ) + + def test_construction_extras_cannot_shadow(self): + with pytest.raises(ValueError, match="shadows an existing attribute"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"cell": torch.zeros(1, 3)}, + ) + + def test_store_model_extras_canonical_keys_not_stored( + self, si_double_sim_state: ts.SimState + ): + """Canonical keys (energy, forces, stress) must not land in extras.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "forces": torch.randn(state.n_atoms, 3), + "stress": torch.randn(state.n_systems, 3, 3), + } + ) + for key in ("energy", "forces", "stress"): + assert key not in state._system_extras # noqa: SLF001 + assert key not in state._atom_extras # noqa: SLF001 + + def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_systems go into system_extras.""" + state = si_double_sim_state.clone() + dipole = torch.randn(state.n_systems, 3) + state.store_model_extras( + {"energy": torch.randn(state.n_systems), "dipole": dipole} + ) + assert torch.equal(state.dipole, dipole) + + def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_atoms go into atom_extras.""" + state = si_double_sim_state.clone() + charges = torch.randn(state.n_atoms) + density = torch.randn(state.n_atoms, 8) + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "charges": charges, + "density_coefficients": density, + } + ) + assert torch.equal(state.charges, charges) + assert state.density_coefficients.shape == (state.n_atoms, 8) + + def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState): + """0-d tensors and non-Tensor values are silently ignored.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "scalar": torch.tensor(3.14), + "string": "not a tensor", + } + ) + assert not state.has_extras("scalar") + assert not state.has_extras("string") diff --git a/tests/test_io.py b/tests/test_io.py index 2bb4f0175..46e300d72 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -2,65 +2,20 @@ import sys from typing import Any +import numpy as np import pytest import torch from ase import Atoms -from phonopy.structure.atoms import PhonopyAtoms -from pymatgen.core import Structure +from ase.build import molecule import torch_sim as ts from tests.conftest import DEVICE, DTYPE from torch_sim.state import SimState -def test_single_structure_to_state(si_structure: Structure) -> None: - """Test conversion from pymatgen Structure to state tensors.""" - state = ts.io.structures_to_state(si_structure, DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) - assert all( - t.device.type == DEVICE.type for t in (state.positions, state.masses, state.cell) - ) - assert all( - t.dtype == torch.float64 for t in (state.positions, state.masses, state.cell) - ) - assert state.atomic_numbers.dtype == torch.int - - # Check shapes and values - assert state.positions.shape == (8, 3) - assert torch.allclose(state.masses, torch.full_like(state.masses, 28.0855)) # Si - assert torch.all(state.atomic_numbers == 14) # Si atomic number - assert torch.allclose( - state.cell, - torch.diag(torch.full((3,), 5.43, device=DEVICE, dtype=torch.float64)), - ) - - -def test_multiple_structures_to_state(si_structure: Structure) -> None: - """Test conversion from list of pymatgen Structure to state tensors.""" - state = ts.io.structures_to_state([si_structure, si_structure], DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) - assert state.positions.shape == (16, 3) - assert state.masses.shape == (16,) - assert state.cell.shape == (2, 3, 3) - assert torch.all(state.pbc) - assert state.atomic_numbers.shape == (16,) - assert state.system_idx is not None - assert state.system_idx.shape == (16,) - assert torch.all( - state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8) - ) - - def test_single_atoms_to_state(si_atoms: Atoms) -> None: - """Test conversion from ASE Atoms to state tensors.""" + """Test basic shape/dtype/device properties of atoms_to_state.""" state = ts.io.atoms_to_state(si_atoms, DEVICE, torch.float64) - - # Check basic properties assert isinstance(state, SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) @@ -68,83 +23,7 @@ def test_single_atoms_to_state(si_atoms: Atoms) -> None: assert torch.all(state.pbc) assert state.atomic_numbers.shape == (8,) assert state.system_idx is not None - assert state.system_idx.shape == (8,) assert torch.all(state.system_idx == 0) - - -def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: - """Test conversion from ASE Atoms to state tensors.""" - state = ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) - assert state.positions.shape == (16, 3) - assert state.masses.shape == (16,) - assert state.cell.shape == (2, 3, 3) - assert torch.all(state.pbc) - assert state.atomic_numbers.shape == (16,) - assert state.system_idx is not None - assert state.system_idx.shape == (16,) - assert torch.all( - state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8), - ) - - -def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: - """Test conversion from state tensors to list of pymatgen Structure.""" - structures = ts.io.state_to_structures(ar_supercell_sim_state) - assert len(structures) == 1 - assert isinstance(structures[0], Structure) - assert len(structures[0]) == 32 - - -def test_state_to_multiple_structures(ar_double_sim_state: SimState) -> None: - """Test conversion from state tensors to list of pymatgen Structure.""" - structures = ts.io.state_to_structures(ar_double_sim_state) - assert len(structures) == 2 - assert isinstance(structures[0], Structure) - assert isinstance(structures[1], Structure) - assert len(structures[0]) == 32 - assert len(structures[1]) == 32 - - -def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: - """Test conversion from state tensors to list of ASE Atoms.""" - atoms = ts.io.state_to_atoms(ar_supercell_sim_state) - assert len(atoms) == 1 - assert isinstance(atoms[0], Atoms) - assert len(atoms[0]) == 32 - - -def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: - """Test conversion from state tensors to list of ASE Atoms.""" - atoms = ts.io.state_to_atoms(ar_double_sim_state) - assert len(atoms) == 2 - assert isinstance(atoms[0], Atoms) - assert isinstance(atoms[1], Atoms) - assert len(atoms[0]) == 32 - assert len(atoms[1]) == 32 - - -def test_to_atoms(ar_supercell_sim_state: SimState) -> None: - """Test conversion from SimState to list of ASE Atoms.""" - atoms = ts.io.state_to_atoms(ar_supercell_sim_state) - assert isinstance(atoms[0], Atoms) - - -def test_to_structures(ar_supercell_sim_state: SimState) -> None: - """Test conversion from SimState to list of Pymatgen Structure.""" - structures = ts.io.state_to_structures(ar_supercell_sim_state) - assert isinstance(structures[0], Structure) - - -def test_single_phonopy_to_state(si_phonopy_atoms: Any) -> None: - """Test conversion from PhonopyAtoms to state tensors.""" - state = ts.io.phonopy_to_state(si_phonopy_atoms, DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) assert all( t.device.type == DEVICE.type for t in (state.positions, state.masses, state.cell) ) @@ -153,53 +32,92 @@ def test_single_phonopy_to_state(si_phonopy_atoms: Any) -> None: ) assert state.atomic_numbers.dtype == torch.int - # Check shapes and values - assert state.positions.shape == (8, 3) - assert torch.allclose(state.masses, torch.full_like(state.masses, 28.0855)) # Si - assert torch.all(state.atomic_numbers == 14) # Si atomic number - assert torch.allclose( - state.cell, - torch.diag(torch.full((3,), 5.43, device=DEVICE, dtype=torch.float64)), - ) - -def test_multiple_phonopy_to_state(si_phonopy_atoms: Any) -> None: - """Test conversion from multiple PhonopyAtoms to state tensors.""" - state = ts.io.phonopy_to_state( - [si_phonopy_atoms, si_phonopy_atoms], DEVICE, torch.float64 +@pytest.mark.parametrize( + ("system_extras", "atom_extras", "expected_sys", "expected_atom"), + [ + pytest.param(None, None, {}, {}, id="no-extras-by-default"), + pytest.param( + {"charge": "charge", "spin": "spin"}, + None, + {"charge": 3.0, "spin": 2.0}, + {}, + id="system-extras-identity-map", + ), + pytest.param( + {"total_charge": "charge"}, + None, + {"total_charge": 3.0}, + {}, + id="system-extras-rename", + ), + pytest.param( + None, + {"site_tags": "my_tags"}, + {}, + {"site_tags": [1.0, 2.0, 3.0]}, + id="atom-extras-rename", + ), + ], +) +def test_extras_map_import( + system_extras: dict[str, str] | None, + atom_extras: dict[str, str] | None, + expected_sys: dict[str, float], + expected_atom: dict[str, list[float]], +) -> None: + """ExtrasMap controls which keys are read and how they are renamed on import.""" + mol = molecule("H2O") + mol.info["charge"] = 3.0 + mol.info["spin"] = 2.0 + mol.arrays["my_tags"] = np.array([1.0, 2.0, 3.0]) + state = ts.io.atoms_to_state( + [mol], DEVICE, DTYPE, system_extras=system_extras, atom_extras=atom_extras ) - - # Check basic properties - assert isinstance(state, SimState) - assert state.positions.shape == (16, 3) - assert state.masses.shape == (16,) - assert state.cell.shape == (2, 3, 3) - assert torch.all(state.pbc) - assert state.atomic_numbers.shape == (16,) - assert state.system_idx is not None - assert state.system_idx.shape == (16,) - assert torch.all( - state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8), + if not expected_sys and not expected_atom: + assert not state.system_extras + assert not state.atom_extras + for key, val in expected_sys.items(): + assert getattr(state, key)[0].item() == val + for key, vals in expected_atom.items(): + assert getattr(state, key).shape == (len(vals),) + + +def test_extras_map_missing_key_skipped() -> None: + """Missing ASE keys are silently skipped rather than defaulting to zero.""" + mol = molecule("H2O") + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE, system_extras={"charge": "charge"}) + assert not state.system_extras + + +def test_extras_map_multi_system() -> None: + """System extras work across multiple systems with correct per-system values.""" + mol1, mol2 = molecule("H2O"), molecule("CH4") + mol1.info["charge"] = 1.0 + mol2.info["charge"] = -1.0 + state = ts.io.atoms_to_state( + [mol1, mol2], DEVICE, DTYPE, system_extras={"charge": "charge"} ) - - -def test_state_to_phonopy(ar_supercell_sim_state: SimState) -> None: - """Test conversion from state tensors to list of PhonopyAtoms.""" - phonopy_atoms = ts.io.state_to_phonopy(ar_supercell_sim_state) - assert len(phonopy_atoms) == 1 - assert isinstance(phonopy_atoms[0], PhonopyAtoms) - assert len(phonopy_atoms[0]) == 32 - - -def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: - """Test conversion from state tensors to list of PhonopyAtoms.""" - phonopy_atoms = ts.io.state_to_phonopy(ar_double_sim_state) - assert len(phonopy_atoms) == 2 - assert isinstance(phonopy_atoms[0], PhonopyAtoms) - assert isinstance(phonopy_atoms[1], PhonopyAtoms) - assert len(phonopy_atoms[0]) == 32 - assert len(phonopy_atoms[1]) == 32 + assert state.charge.shape == (2,) + assert state.charge[0].item() == 1.0 + assert state.charge[1].item() == -1.0 + + +def test_extras_map_export_roundtrip() -> None: + """System and atom extras round-trip through state_to_atoms with rename.""" + mol = molecule("H2O") + mol.info["charge"] = 5.0 + mol.arrays["my_tags"] = np.array([1.0, 2.0, 3.0]) + sys_map = {"total_charge": "charge"} + atom_map = {"site_tags": "my_tags"} + state = ts.io.atoms_to_state( + [mol], DEVICE, DTYPE, system_extras=sys_map, atom_extras=atom_map + ) + atoms = ts.io.state_to_atoms(state, system_extras=sys_map, atom_extras=atom_map) + assert atoms[0].info["charge"] == 5.0 + np.testing.assert_allclose(atoms[0].arrays["my_tags"], [1.0, 2.0, 3.0]) + atoms_no_map = ts.io.state_to_atoms(state) + assert "charge" not in atoms_no_map[0].info @pytest.mark.parametrize( @@ -214,7 +132,6 @@ def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: "cu_sim_state", "ar_double_sim_state", "mixed_double_sim_state", - # TODO: round trip benzene/non-pbc systems ], [ (ts.io.state_to_atoms, ts.io.atoms_to_state), @@ -226,100 +143,74 @@ def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: def test_state_round_trip( sim_state_name: str, conversion_functions: tuple, request: pytest.FixtureRequest ) -> None: - """Test round-trip conversion from SimState through various formats and back. - - Args: - sim_state_name: Name of the sim_state fixture to test - conversion_functions: Tuple of (to_format, from_format) conversion functions - request: Pytest fixture request object to get dynamic fixtures - """ - # Get the sim_state fixture dynamically using the name + """Test round-trip conversion from SimState through various formats and back.""" sim_state: SimState = request.getfixturevalue(sim_state_name) to_format_fn, from_format_fn = conversion_functions assert sim_state.system_idx is not None uniq_systems = torch.unique(sim_state.system_idx) - - # Convert to intermediate format intermediate_format = to_format_fn(sim_state) assert len(intermediate_format) == len(uniq_systems) - - # Convert back to state round_trip_state: SimState = from_format_fn(intermediate_format, DEVICE, DTYPE) - - # Check that all properties match assert round_trip_state.system_idx is not None assert torch.allclose(sim_state.positions, round_trip_state.positions) assert torch.allclose(sim_state.cell, round_trip_state.cell) assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers) assert torch.all(sim_state.system_idx == round_trip_state.system_idx) assert torch.equal(sim_state.pbc, round_trip_state.pbc) - if isinstance(intermediate_format[0], Atoms): - # TODO: masses round trip for pmg and phonopy masses is not exact - # since both use their own isotope masses based on species, - # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses) -def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "ase", None) - monkeypatch.setitem(sys.modules, "ase.data", None) - - with pytest.raises( - ImportError, match="ASE is required for state_to_atoms conversion" - ): - ts.io.state_to_atoms(None) - - -def test_state_to_phonopy_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "phonopy", None) - monkeypatch.setitem(sys.modules, "phonopy.structure", None) - monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None) - - with pytest.raises( - ImportError, match="Phonopy is required for state_to_phonopy conversion" - ): - ts.io.state_to_phonopy(None) - - -def test_state_to_structures_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "pymatgen", None) - monkeypatch.setitem(sys.modules, "pymatgen.core", None) - monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None) - - with pytest.raises( - ImportError, match="Pymatgen is required for state_to_structures conversion" - ): - ts.io.state_to_structures(None) - - -def test_atoms_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "ase", None) - monkeypatch.setitem(sys.modules, "ase.data", None) - - with pytest.raises( - ImportError, match="ASE is required for atoms_to_state conversion" - ): - ts.io.atoms_to_state(None, None, None) - - -def test_phonopy_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "phonopy", None) - monkeypatch.setitem(sys.modules, "phonopy.structure", None) - monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None) - - with pytest.raises( - ImportError, match="Phonopy is required for phonopy_to_state conversion" - ): - ts.io.phonopy_to_state(None, None, None) - - -def test_structures_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "pymatgen", None) - monkeypatch.setitem(sys.modules, "pymatgen.core", None) - monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None) - - with pytest.raises( - ImportError, match="Pymatgen is required for structures_to_state conversion" - ): - ts.io.structures_to_state(None, None, None) +@pytest.mark.parametrize( + ("monkeypatch_modules", "func", "args", "match"), + [ + ( + ["ase", "ase.data"], + ts.io.state_to_atoms, + (None,), + "ASE is required for state_to_atoms", + ), + ( + ["ase", "ase.data"], + ts.io.atoms_to_state, + (None, None, None), + "ASE is required for atoms_to_state", + ), + ( + ["phonopy", "phonopy.structure", "phonopy.structure.atoms"], + ts.io.state_to_phonopy, + (None,), + "Phonopy is required for state_to_phonopy", + ), + ( + ["phonopy", "phonopy.structure", "phonopy.structure.atoms"], + ts.io.phonopy_to_state, + (None, None, None), + "Phonopy is required for phonopy_to_state", + ), + ( + ["pymatgen", "pymatgen.core", "pymatgen.core.structure"], + ts.io.state_to_structures, + (None,), + "Pymatgen is required for state_to_structures", + ), + ( + ["pymatgen", "pymatgen.core", "pymatgen.core.structure"], + ts.io.structures_to_state, + (None, None, None), + "Pymatgen is required for structures_to_state", + ), + ], +) +def test_import_errors( + monkeypatch: pytest.MonkeyPatch, + monkeypatch_modules: list[str], + func: Any, + args: tuple, + match: str, +) -> None: + """All IO functions raise ImportError when their backend is unavailable.""" + for mod in monkeypatch_modules: + monkeypatch.setitem(sys.modules, mod, None) + with pytest.raises(ImportError, match=match): + func(*args) diff --git a/tests/test_nbody.py b/tests/test_nbody.py index e235cd629..5da5a45c5 100644 --- a/tests/test_nbody.py +++ b/tests/test_nbody.py @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None: result = build_triplets(edge_index, n_atoms) - assert result["trip_in"].device == dev - assert result["trip_out"].device == dev - assert result["center_atom"].device == dev + assert result["trip_in"].device.type == dev.type + assert result["trip_out"].device.type == dev.type + assert result["center_atom"].device.type == dev.type @pytest.mark.parametrize( @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None: internal_cell_offsets, ) - assert result["quad_c_to_a_edge"].device == dev - assert result["quad_d_to_b_trip_idx"].device == dev - assert result["d_to_b_edge"].device == dev - assert result["c_to_a_edge"].device == dev + assert result["quad_c_to_a_edge"].device.type == dev.type + assert result["quad_d_to_b_trip_idx"].device.type == dev.type + assert result["d_to_b_edge"].device.type == dev.type + assert result["c_to_a_edge"].device.type == dev.type def test_build_triplets_jit_script() -> None: diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 3ab10533c..35f0b5665 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1464,8 +1464,12 @@ def test_optimizer_preserves_charge_spin( original_spin = torch.tensor( [6.0], device=ar_supercell_sim_state.device, dtype=ar_supercell_sim_state.dtype ) - ar_supercell_sim_state.charge = original_charge.clone() - ar_supercell_sim_state.spin = original_spin.clone() + + # NOTE the convenience method for setting extras is not used here because + # they are only available if the key is already in the extras dict. If not + # we need to set them explicitly. + ar_supercell_sim_state._system_extras["charge"] = original_charge.clone() # noqa: SLF001 + ar_supercell_sim_state._system_extras["spin"] = original_spin.clone() # noqa: SLF001 init_fn, step_fn = ts.OPTIM_REGISTRY[optimizer_fn] opt_state = init_fn( diff --git a/tests/test_state.py b/tests/test_state.py index ed1b3c29d..4cd98693d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -34,7 +34,7 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs) == {"cell", "charge", "spin"} + assert set(per_system_attrs) == {"cell"} global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) assert set(global_attrs) == {"pbc", "_rng"} diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 944cdb813..72341ebb1 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -91,7 +91,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 90) < angle_tol ): - return BravaisType.cubic + return BravaisType.CUBIC # Hexagonal: a = b ≠ c, alpha = beta = 90°, gamma = 120° if ( @@ -100,7 +100,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 120) < angle_tol ): - return BravaisType.hexagonal + return BravaisType.HEXAGONAL # Tetragonal: a = b ≠ c, alpha = beta = gamma = 90° if ( @@ -110,7 +110,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 90) < angle_tol ): - return BravaisType.tetragonal + return BravaisType.TETRAGONAL # Orthorhombic: a ≠ b ≠ c, alpha = beta = gamma = 90° if ( @@ -120,7 +120,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(a - b) > length_tol and (abs(b - c) > length_tol or abs(a - c) > length_tol) ): - return BravaisType.orthorhombic + return BravaisType.ORTHORHOMBIC # Monoclinic: a ≠ b ≠ c, alpha = gamma = 90°, beta ≠ 90° if ( @@ -128,7 +128,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(gamma - 90) < angle_tol and abs(beta - 90) > angle_tol ): - return BravaisType.monoclinic + return BravaisType.MONOCLINIC # Trigonal/Rhombohedral: a = b = c, alpha = beta = gamma ≠ 90° if ( @@ -138,10 +138,10 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - gamma) < angle_tol and abs(alpha - 90) > angle_tol ): - return BravaisType.trigonal + return BravaisType.TRIGONAL # Triclinic: a ≠ b ≠ c, alpha ≠ beta ≠ gamma ≠ 90° - return BravaisType.triclinic + return BravaisType.TRICLINIC def regular_symmetry(strains: torch.Tensor) -> torch.Tensor: @@ -666,21 +666,13 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) -> else: # axis == 5 L[0, 1] += size # xy shear - # Convert positions to fractional coordinates - old_inv = torch.linalg.inv(row_vector_cell) - frac_coords = torch.matmul(positions, old_inv) + frac_coords = torch.matmul(positions, torch.linalg.inv(row_vector_cell)) + new_cell = torch.matmul(row_vector_cell, L) - # Apply transformation to cell and convert positions back to cartesian - row_vector_cell = torch.matmul(row_vector_cell, L) - new_positions = torch.matmul(frac_coords, row_vector_cell) - - return SimState( - positions=new_positions, - cell=row_vector_cell.mT.unsqueeze(0), - masses=state.masses, - pbc=state.pbc, - atomic_numbers=state.atomic_numbers, - ) + new_state = state.clone() + new_state.row_vector_cell = new_cell.unsqueeze(0) + new_state.positions = torch.matmul(frac_coords, new_cell) + return new_state def get_elementary_deformations( @@ -719,20 +711,20 @@ def get_elementary_deformations( # Deformation rules for different Bravais lattices # Each tuple contains (allowed_axes, symmetry_handler_function) deformation_rules: dict[BravaisType, DeformationRule] = { - BravaisType.cubic: DeformationRule([0, 3], regular_symmetry), - BravaisType.hexagonal: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), - BravaisType.trigonal: DeformationRule([0, 1, 2, 3, 4, 5], trigonal_symmetry), - BravaisType.tetragonal: DeformationRule([0, 2, 3, 5], tetragonal_symmetry), - BravaisType.orthorhombic: DeformationRule( + BravaisType.CUBIC: DeformationRule([0, 3], regular_symmetry), + BravaisType.HEXAGONAL: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), + BravaisType.TRIGONAL: DeformationRule([0, 1, 2, 3, 4, 5], trigonal_symmetry), + BravaisType.TETRAGONAL: DeformationRule([0, 2, 3, 5], tetragonal_symmetry), + BravaisType.ORTHORHOMBIC: DeformationRule( [0, 1, 2, 3, 4, 5], orthorhombic_symmetry ), - BravaisType.monoclinic: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), - BravaisType.triclinic: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), + BravaisType.MONOCLINIC: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), + BravaisType.TRICLINIC: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), } # Get deformation rules for this Bravais lattice (default to triclinic if None) if bravais_type is None: - bravais_type = BravaisType.triclinic + bravais_type = BravaisType.TRICLINIC rule = deformation_rules[bravais_type] allowed_axes = rule.axes @@ -898,7 +890,7 @@ def get_elastic_coeffs( deformed_states: list[SimState], stresses: torch.Tensor, base_pressure: torch.Tensor, - bravais_type: BravaisType = BravaisType.triclinic, + bravais_type: BravaisType = BravaisType.TRICLINIC, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]]: """Calculate elastic tensor from stress-strain relationships. @@ -932,15 +924,15 @@ def get_elastic_coeffs( """ # Deformation rules for different Bravais lattices deformation_rules: dict[BravaisType, DeformationRule] = { - BravaisType.cubic: DeformationRule([0, 3], regular_symmetry), - BravaisType.hexagonal: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), - BravaisType.trigonal: DeformationRule([0, 2, 3, 4, 5], trigonal_symmetry), - BravaisType.tetragonal: DeformationRule([0, 2, 3, 4, 5], tetragonal_symmetry), - BravaisType.orthorhombic: DeformationRule( + BravaisType.CUBIC: DeformationRule([0, 3], regular_symmetry), + BravaisType.HEXAGONAL: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), + BravaisType.TRIGONAL: DeformationRule([0, 2, 3, 4, 5], trigonal_symmetry), + BravaisType.TETRAGONAL: DeformationRule([0, 2, 3, 4, 5], tetragonal_symmetry), + BravaisType.ORTHORHOMBIC: DeformationRule( [0, 1, 2, 3, 4, 5], orthorhombic_symmetry ), - BravaisType.monoclinic: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), - BravaisType.triclinic: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), + BravaisType.MONOCLINIC: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), + BravaisType.TRICLINIC: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), } # Get symmetry handler for this Bravais lattice @@ -973,15 +965,15 @@ def get_elastic_coeffs( # Calculate elastic constants with pressure correction p = base_pressure pressure_corrections = { - BravaisType.cubic: torch.tensor([-p, p, -p]), - BravaisType.hexagonal: torch.tensor([-p, -p, p, p, -p]), - BravaisType.trigonal: torch.tensor([-p, -p, p, p, p, p, -p]), - BravaisType.tetragonal: torch.tensor([-p, -p, p, p, -p, -p, -p]), - BravaisType.orthorhombic: torch.tensor([-p, -p, -p, p, p, p, -p, -p, -p]), - BravaisType.monoclinic: torch.tensor( + BravaisType.CUBIC: torch.tensor([-p, p, -p]), + BravaisType.HEXAGONAL: torch.tensor([-p, -p, p, p, -p]), + BravaisType.TRIGONAL: torch.tensor([-p, -p, p, p, p, p, -p]), + BravaisType.TETRAGONAL: torch.tensor([-p, -p, p, p, -p, -p, -p]), + BravaisType.ORTHORHOMBIC: torch.tensor([-p, -p, -p, p, p, p, -p, -p, -p]), + BravaisType.MONOCLINIC: torch.tensor( [-p, -p, -p, p, p, p, -p, -p, -p, p, p, p, p] ), - BravaisType.triclinic: torch.tensor( + BravaisType.TRICLINIC: torch.tensor( [ -p, p, @@ -1044,7 +1036,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 # Initialize full tensor C = torch.zeros((6, 6), dtype=Cij.dtype, device=Cij.device) - if bravais_type == BravaisType.triclinic: + if bravais_type == BravaisType.TRICLINIC: if len(Cij) != 21: raise ValueError( f"Triclinic symmetry requires 21 independent constants, " @@ -1057,19 +1049,19 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[i, j] = C[j, i] = Cij[idx] idx += 1 - elif bravais_type == BravaisType.cubic: + elif bravais_type == BravaisType.CUBIC: C11, C12, C44 = Cij diag = torch.tensor([C11, C11, C11, C44, C44, C44]) C.diagonal().copy_(diag) C[0, 1] = C[1, 0] = C[0, 2] = C[2, 0] = C[1, 2] = C[2, 1] = C12 - elif bravais_type == BravaisType.hexagonal: + elif bravais_type == BravaisType.HEXAGONAL: C11, C12, C13, C33, C44 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, (C11 - C12) / 2])) C[0, 1] = C[1, 0] = C12 C[0, 2] = C[2, 0] = C[1, 2] = C[2, 1] = C13 - elif bravais_type == BravaisType.trigonal: + elif bravais_type == BravaisType.TRIGONAL: C11, C12, C13, C14, C15, C33, C44 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, (C11 - C12) / 2])) C[0, 1] = C[1, 0] = C12 @@ -1081,7 +1073,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[3, 5] = C[5, 3] = -C15 C[4, 5] = C[5, 4] = C14 - elif bravais_type == BravaisType.tetragonal: + elif bravais_type == BravaisType.TETRAGONAL: C11, C12, C13, C16, C33, C44, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, C66])) C[0, 1] = C[1, 0] = C12 @@ -1089,14 +1081,14 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[0, 5] = C[5, 0] = C16 C[1, 5] = C[5, 1] = -C16 - elif bravais_type == BravaisType.orthorhombic: + elif bravais_type == BravaisType.ORTHORHOMBIC: C11, C12, C13, C22, C23, C33, C44, C55, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C22, C33, C44, C55, C66])) C[0, 1] = C[1, 0] = C12 C[0, 2] = C[2, 0] = C13 C[1, 2] = C[2, 1] = C23 - elif bravais_type == BravaisType.monoclinic: + elif bravais_type == BravaisType.MONOCLINIC: C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C22, C33, C44, C55, C66])) C[0, 1] = C[1, 0] = C12 @@ -1114,7 +1106,7 @@ def calculate_elastic_tensor( state: OptimState, model: ModelInterface, *, - bravais_type: BravaisType = BravaisType.triclinic, + bravais_type: BravaisType = BravaisType.TRICLINIC, max_strain_normal: float = 0.01, max_strain_shear: float = 0.06, n_deform: int = 5, diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 4a593c048..a88685746 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -233,6 +233,7 @@ def velocity_verlet_step[T: MDState]( state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_2) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 1b47e16f7..883d992a4 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -633,7 +633,7 @@ def npt_langevin_init( logger.warning(msg) # Create the initial state - return NPTLangevinState.from_state( + npt_state = NPTLangevinState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -647,6 +647,8 @@ def npt_langevin_init( cell_masses=cell_masses, cell_alpha=cell_alpha, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1063/1.4901303") @@ -742,6 +744,7 @@ def npt_langevin_step( state.energy = model_output["energy"] state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Compute updated pressure force F_p_n_new = _compute_cell_force( @@ -1283,6 +1286,7 @@ def _npt_nose_hoover_inner_step( state.forces = model_output["forces"] state.stress = model_output["stress"] state.energy = model_output["energy"] + state.store_model_extras(model_output) state.cell_position = cell_position state.cell_momentum = cell_momentum state.cell_mass = cell_mass @@ -1437,7 +1441,7 @@ def npt_nose_hoover_init( logger.warning(msg) # Create initial state - return NPTNoseHooverState.from_state( + npt_state = NPTNoseHooverState.from_state( state, momenta=momenta, energy=energy, @@ -1453,6 +1457,8 @@ def npt_nose_hoover_init( barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1080/00268979600100761") @@ -2074,6 +2080,7 @@ def npt_crescale_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt_tensor / 2) @@ -2149,6 +2156,7 @@ def npt_crescale_independent_lengths_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2225,6 +2233,7 @@ def npt_crescale_average_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2302,6 +2311,7 @@ def npt_crescale_isotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2375,7 +2385,7 @@ def npt_crescale_init( ) # Create the initial state - return NPTCRescaleState.from_state( + npt_state = NPTCRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -2384,3 +2394,5 @@ def npt_crescale_init( tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, ) + npt_state.store_model_extras(model_output) + return npt_state diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 07f3064bb..316ef78c7 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -57,12 +57,14 @@ def nve_init( state.rng, ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state def nve_step( @@ -100,5 +102,6 @@ def nve_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt / 2) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 1b8017279..e6379b47e 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -126,12 +126,14 @@ def nvt_langevin_init( kT, state.rng, ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state @dcite("10.1098/rspa.2016.0138") @@ -191,6 +193,7 @@ def nvt_langevin_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_tensor / 2) @@ -319,7 +322,7 @@ def nvt_nose_hoover_init( dof_per_system = state.get_number_of_degrees_of_freedom() - 3 # Initialize state - return NVTNoseHooverState.from_state( + nh_state = NVTNoseHooverState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -328,6 +331,8 @@ def nvt_nose_hoover_init( chain=chain_fns.initialize(dof_per_system, KE, kT_tensor), _chain_fns=chain_fns, ) + nh_state.store_model_extras(model_output) + return nh_state @dcite("10.1080/00268979600100761") @@ -606,12 +611,14 @@ def nvt_vrescale_init( state.rng, ) - return NVTVRescaleState.from_state( + vr_state = NVTVRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + vr_state.store_model_extras(model_output) + return vr_state @dcite("10.1063/1.2408420") diff --git a/torch_sim/io.py b/torch_sim/io.py index edce091e1..50d0dc702 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -12,6 +12,8 @@ * Batched conversions for multiple structures """ +from __future__ import annotations + from typing import TYPE_CHECKING import numpy as np @@ -26,28 +28,40 @@ from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure + from torch_sim.typing import ExtrasMap + @dcite( "10.1088/1361-648X/aa680e", description="ASE: Atomic Simulation Environment", path="ase", ) -def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: +def state_to_atoms( + state: ts.SimState, + *, + system_extras: ExtrasMap | None = None, + atom_extras: ExtrasMap | None = None, +) -> list[Atoms]: """Convert a SimState to a list of ASE Atoms objects. Args: - state (SimState): Batched state containing positions, cell, and atomic numbers + state: Batched state containing positions, cell, and atomic numbers. + system_extras: Map of ``{ts_key: ase_key}`` controlling which + ``_system_extras`` entries are written to ``atoms.info``. + ``None`` (default) means no extras are written. + atom_extras: Map of ``{ts_key: ase_key}`` controlling which + ``_atom_extras`` entries are written to ``atoms.arrays``. + ``None`` (default) means no extras are written. Returns: - list[Atoms]: ASE Atoms objects, one per system + list[Atoms]: ASE Atoms objects, one per system. Raises: - ImportError: If ASE is not installed + ImportError: If ASE is not installed. Notes: - Output positions and cell will be in Å - Output masses will be in amu - - Charge and spin are preserved in atoms.info if present in the state """ try: from ase import Atoms @@ -70,10 +84,6 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: else np.array([state.pbc] * 3 if isinstance(state.pbc, bool) else state.pbc) ) - # Extract charge and spin if available (per-system attributes) - charge = state.charge.detach().cpu().numpy() if state.charge is not None else None - spin = state.spin.detach().cpu().numpy() if state.spin is not None else None - atoms_list = [] for sys_idx in np.unique(system_indices): mask = system_indices == sys_idx @@ -91,11 +101,17 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: symbols=symbols, positions=system_positions, cell=system_cell, pbc=pbc_for_sys ) - # Preserve charge and spin in atoms.info (as integers for FairChem compatibility) - if charge is not None: - atoms.info["charge"] = int(charge[sys_idx].item()) - if spin is not None: - atoms.info["spin"] = int(spin[sys_idx].item()) + if system_extras: + for ts_key, ase_key in system_extras.items(): + if ts_key in state.system_extras: + val = state.system_extras[ts_key][sys_idx].detach().cpu().numpy() + atoms.info[ase_key] = val + + if atom_extras: + for ts_key, ase_key in atom_extras.items(): + if ts_key in state.atom_extras: + val = state.atom_extras[ts_key][mask].detach().cpu().numpy() + atoms.arrays[ase_key] = val atoms_list.append(atoms) @@ -107,7 +123,7 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: description="pymatgen: Python Materials Genomics", path="pymatgen", ) -def state_to_structures(state: "ts.SimState") -> list["Structure"]: +def state_to_structures(state: ts.SimState) -> list[Structure]: """Convert a SimState to a list of Pymatgen Structure objects. Args: @@ -183,7 +199,7 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: description="Phonopy: harmonic and quasi-harmonic phonon calculationss", path="phonopy", ) -def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: +def state_to_phonopy(state: ts.SimState) -> list[PhonopyAtoms]: """Convert a SimState to a list of PhonopyAtoms objects. Args: @@ -241,24 +257,33 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: path="ase", ) def atoms_to_state( - atoms: "Atoms | list[Atoms]", + atoms: Atoms | list[Atoms], device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> "ts.SimState": + *, + system_extras: ExtrasMap | None = None, + atom_extras: ExtrasMap | None = None, +) -> ts.SimState: """Convert an ASE Atoms object or list of Atoms objects to a SimState. Args: - atoms (Atoms | list[Atoms]): Single ASE Atoms object or list of Atoms objects - device (torch.device): Device to create tensors on - dtype (torch.dtype): Data type for tensors (typically torch.float32 or - torch.float64) + atoms: Single ASE Atoms object or list of Atoms objects. + device: Device to create tensors on. + dtype: Data type for tensors (typically ``torch.float32`` or + ``torch.float64``). + system_extras: Map of ``{ts_key: ase_key}`` controlling which + ``atoms.info`` entries are read into ``_system_extras``. + ``None`` (default) means no extras are read. + atom_extras: Map of ``{ts_key: ase_key}`` controlling which + ``atoms.arrays`` entries are read into ``_atom_extras``. + ``None`` (default) means no extras are read. Returns: SimState: TorchSim SimState object. Raises: - ImportError: If ASE is not installed - ValueError: If systems have inconsistent periodic boundary conditions + ImportError: If ASE is not installed. + ValueError: If systems have inconsistent periodic boundary conditions. Notes: - Input positions and cell should be in Å @@ -298,12 +323,25 @@ def atoms_to_state( if not all(np.all(np.equal(at.pbc, atoms_list[0].pbc)) for at in atoms_list[1:]): raise ValueError("All systems must have the same periodic boundary conditions") - charge = torch.tensor( - [at.info.get("charge", 0.0) for at in atoms_list], dtype=dtype, device=device - ) - spin = torch.tensor( - [at.info.get("spin", 0.0) for at in atoms_list], dtype=dtype, device=device - ) + _system_extras: dict[str, torch.Tensor] = {} + if system_extras: + for ts_key, ase_key in system_extras.items(): + vals = [at.info.get(ase_key) for at in atoms_list] + non_none = [v for v in vals if v is not None] + if len(non_none) == len(vals): + _system_extras[ts_key] = torch.tensor( + np.array(non_none), dtype=dtype, device=device + ) + + _atom_extras: dict[str, torch.Tensor] = {} + if atom_extras: + for ts_key, ase_key in atom_extras.items(): + arrays = [at.arrays.get(ase_key) for at in atoms_list] + non_none = [a for a in arrays if a is not None] + if len(non_none) == len(arrays): + _atom_extras[ts_key] = torch.tensor( + np.concatenate(non_none), dtype=dtype, device=device + ) return ts.SimState( positions=positions, @@ -312,8 +350,8 @@ def atoms_to_state( pbc=atoms_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, - charge=charge, - spin=spin, + _system_extras=_system_extras, + _atom_extras=_atom_extras, ) @@ -323,10 +361,10 @@ def atoms_to_state( path="pymatgen", ) def structures_to_state( - structure: "Structure | list[Structure]", + structure: Structure | list[Structure], device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> "ts.SimState": +) -> ts.SimState: """Create a SimState from pymatgen Structure(s). Args: @@ -404,10 +442,10 @@ def structures_to_state( path="phonopy", ) def phonopy_to_state( - phonopy_atoms: "PhonopyAtoms | list[PhonopyAtoms]", + phonopy_atoms: PhonopyAtoms | list[PhonopyAtoms], device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> "ts.SimState": +) -> ts.SimState: """Create state tensors from a PhonopyAtoms object or list of PhonopyAtoms objects. Args: diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 1c95b00da..56a6d156f 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -223,8 +223,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] pbc=pbc_np if cell is not None else False, ) - charge = sim_state.charge - spin = sim_state.spin + charge = getattr(sim_state, "charge", None) + spin = getattr(sim_state, "spin", None) atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0 atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0 diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 885627b1c..98d284335 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -37,6 +37,8 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): if TYPE_CHECKING: + from collections.abc import Callable + from torch_sim.state import SimState from torch_sim.typing import MemoryScaling @@ -263,7 +265,7 @@ def compute_forces(self, value: bool) -> None: # noqa: FBT001 @property def retain_graph(self) -> bool: """Whether any child model retains the computation graph.""" - return any(getattr(m, "retain_graph", False) for m in self._children()) + return all(getattr(m, "retain_graph", False) for m in self._children()) @retain_graph.setter def retain_graph(self, value: bool) -> None: @@ -345,6 +347,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 dtype: torch.dtype, *, check_detached: bool = False, + state_modifier: Callable[[SimState], SimState] | None = None, ) -> None: """Validate the outputs of a model implementation against the interface requirements. @@ -360,6 +363,9 @@ def validate_model_outputs( # noqa: C901, PLR0915 detached from the autograd graph, unless the model has a ``retain_graph`` attribute set to ``True``. Defaults to ``False`` so that external callers are not immediately broken. + state_modifier: If provided, applied to every ``SimState`` created + during validation before the model sees it. Must return the + (possibly new) state. Raises: AssertionError: If the model doesn't conform to the required interface, @@ -378,7 +384,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 and primitive BCC iron) for validation. It tests both single and multi-batch processing capabilities. """ - from ase.build import bulk + from ase.build import bulk, molecule + + def _modify(state: SimState) -> SimState: + return state_modifier(state) if state_modifier is not None else state for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): @@ -401,8 +410,9 @@ def validate_model_outputs( # noqa: C901, PLR0915 si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21).repeat([3, 2, 1]) fe_atoms = bulk("Fe", "bcc", a=2.87) - sim_state = ts.io.atoms_to_state([si_atoms, mg_atoms, fe_atoms], device, dtype) - + sim_state = _modify( + ts.io.atoms_to_state([si_atoms, mg_atoms, fe_atoms], device, dtype) + ) og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() system_idx = sim_state.system_idx @@ -446,8 +456,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{model_output['stress'].shape=} != (3, 3, 3)") # Test single Si system output shapes (8 atoms) - si_state = ts.io.atoms_to_state([si_atoms], device, dtype) - + si_state = _modify(ts.io.atoms_to_state([si_atoms], device, dtype)) si_model_output = model.forward(si_state) if not torch.allclose( si_model_output["energy"], model_output["energy"][0], atol=VALIDATE_ATOL @@ -468,7 +477,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)") # Test single Mg system output shapes (12 atoms) - mg_state = ts.io.atoms_to_state([mg_atoms], device, dtype) + mg_state = _modify(ts.io.atoms_to_state([mg_atoms], device, dtype)) mg_model_output = model.forward(mg_state) if not torch.allclose( mg_model_output["energy"], model_output["energy"][1], atol=VALIDATE_ATOL @@ -492,7 +501,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 # Test single Fe system output shapes (1 atom) # This catches that models do not squeeze away singleton dimensions. - fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) + fe_state = _modify(ts.io.atoms_to_state([fe_atoms], device, dtype)) fe_model_output = model.forward(fe_state) if not torch.allclose( fe_model_output["energy"], model_output["energy"][2], atol=VALIDATE_ATOL @@ -542,3 +551,18 @@ def validate_model_outputs( # noqa: C901, PLR0915 "vector: max diff = " f"{(shifted_output['stress'] - si_model_output['stress']).abs().max()}" ) + + # Test a non-periodic molecule (benzene) + benzene_atoms = molecule("C6H6") + benzene_state = _modify(ts.io.atoms_to_state([benzene_atoms], device, dtype)) + benzene_output = model.forward(benzene_state) + if benzene_output["energy"].shape != (1,): + raise ValueError( + f"energy shape incorrect for benzene: " + f"{benzene_output['energy'].shape=} != (1,)" + ) + if force_computed and benzene_output["forces"].shape != (12, 3): + raise ValueError( + f"forces shape incorrect for benzene: " + f"{benzene_output['forces'].shape=} != (12, 3)" + ) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index e13e3ce85..01afdc51f 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -304,8 +304,8 @@ def forward( # noqa: C901 edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, - total_charge=state.charge, - total_spin=state.spin, + total_charge=getattr(state, "charge", None), + total_spin=getattr(state, "spin", None), ) # Get model output @@ -336,6 +336,13 @@ def forward( # noqa: C901 if stress is not None: results["stress"] = stress.detach() + # Propagate additional model outputs (e.g. dipole, charges, etc.) + for key, val in out.items(): + if key not in ("energy", "forces", "stress") and isinstance( + val, torch.Tensor + ): + results[key] = val.detach() + return results diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 80587a372..5c910cf49 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -11,16 +11,38 @@ import warnings from typing import Any +import torch + try: from orb_models.forcefield.inference.orb_torchsim import OrbTorchSimModel + import torch_sim as ts + # Re-export with backward-compatible name class OrbModel(OrbTorchSimModel): """ORB model wrapper for torch-sim.""" + @staticmethod + def _normalize_charge_spin(state: "ts.SimState") -> "ts.SimState": + """Provide ORB's optional charge/spin inputs when they are missing.""" + charge = getattr(state, "charge", None) + spin = getattr(state, "spin", None) + if charge is not None and spin is not None: + return state + zeros = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + return ts.SimState.from_state( + state, + charge=charge if charge is not None else zeros, + spin=spin if spin is not None else zeros, + ) + def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Run forward pass, detaching outputs unless retain_graph is True.""" + if args and isinstance(args[0], ts.SimState): + args = (self._normalize_charge_spin(args[0]), *args[1:]) + elif isinstance(kwargs.get("state"), ts.SimState): + kwargs["state"] = self._normalize_charge_spin(kwargs["state"]) output = super().forward(*args, **kwargs) return { # detach tensors as energy is not detached by default k: v.detach() if hasattr(v, "detach") else v for k, v in output.items() diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 04dfde316..8a4a0d37c 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -223,7 +223,7 @@ def swap_mc_init( """ model_output = model(state) - return SwapMCState( + mc_state = SwapMCState( positions=state.positions, masses=state.masses, cell=state.cell, @@ -233,6 +233,8 @@ def swap_mc_init( energy=model_output["energy"], _constraints=state.constraints, ) + mc_state.store_model_extras(model_output) + return mc_state def swap_mc_step( @@ -292,5 +294,6 @@ def swap_mc_step( state.energy = torch.where(accepted, energies_new, energies_old) state.last_permutation = permutation[reverse_rejected_swaps].clone() + state.store_model_extras(model_output) return state diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index c3a344cf3..c9ada4992 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -25,13 +25,13 @@ _clamp_deform_grad_log, frechet_cell_filter_init, ) +from torch_sim.optimizers.state import BFGSState from torch_sim.state import SimState if TYPE_CHECKING: from torch_sim.models.interface import ModelInterface from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs - from torch_sim.optimizers.state import BFGSState BFGS_EPS = 1e-7 # eps kept same as ASE's BFGS. @@ -115,8 +115,6 @@ def bfgs_init( Returns: BFGSState or CellBFGSState if cell_filter is provided """ - from torch_sim.optimizers import BFGSState, CellBFGSState - device: torch.device = model.device dtype: torch.dtype = model.dtype @@ -143,6 +141,19 @@ def bfgs_init( n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) # [S] + bfgs_attrs = { + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "alpha": alpha_t, # [S] + "max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # [S] + } + if cell_filter is not None: # Extended Hessian: (3*global_max_atoms + 9) x (3*global_max_atoms + 9) # The extra 9 DOFs are for cell parameters (3x3 matrix flattened) @@ -153,59 +164,31 @@ def bfgs_init( cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - # Note (AG): At initialization, deform_grad is identity, so we have: - # fractional = Cartesian / cell and scaled forces = forces @ I = forces - # For ASE compatibility, we need to store prev_positions as fractional coords - # and prev_forces as scaled forces - - # Get initial deform_grad (identity at start since reference_cell = current_cell) + # At initialization, deform_grad is identity, so fractional = Cartesian + # and scaled forces = forces. For ASE compatibility, store prev_positions + # as fractional coords and prev_forces as scaled forces. reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = cell_filters.deform_grad( reference_cell.mT, state.cell.mT ) # [S, 3, 3] - # Initial fractional positions = solve(deform_grad, positions) = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D_ext, D_ext] - # Note (AG): Store fractional positions and scaled forces - # for ASE compatibility - "prev_forces": scaled_forces, # [N, 3] (scaled) - "prev_positions": frac_positions, # [N, 3] (fractional) - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "reference_cell": reference_cell, # [S, 3, 3] - "cell_filter": cell_filter_funcs, - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - cell_state = CellBFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_attrs["hessian"] = hessian # [S, D_ext, D_ext] + bfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) + bfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + bfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + bfgs_attrs["cell_filter"] = cell_filter_funcs + + cell_state = CellBFGSState.from_state(state, **bfgs_attrs) # Initialize cell-specific attributes (cell_positions, cell_forces, etc.) # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -215,6 +198,7 @@ def bfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms @@ -222,31 +206,11 @@ def bfgs_init( hessian = torch.eye(dim, device=device, dtype=dtype).unsqueeze(0).repeat( n_systems, 1, 1 ) * alpha_t.view(n_systems, 1, 1) # [S, D, D] + bfgs_attrs["hessian"] = hessian # [S, D, D] - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D, D] - "prev_forces": forces.clone(), # [N, 3] - "prev_positions": state.positions.clone(), # [N, 3] - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - return BFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_state = BFGSState.from_state(state, **bfgs_attrs) + bfgs_state.store_model_extras(model_output) + return bfgs_state def bfgs_step( # noqa: C901, PLR0915 @@ -550,6 +514,7 @@ def bfgs_step( # noqa: C901, PLR0915 state.energy = model_output["energy"] # [S] if "stress" in model_output: state.stress = model_output["stress"] # [S, 3, 3] + state.store_model_extras(model_output) # Update cell forces for next step # Update cell forces for cell state: [S, 3, 3] diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 8efcb3a7b..e45c7ec8d 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -106,9 +106,12 @@ def fire_init( cell_state.cell_forces.shape, torch.nan, device=device, dtype=dtype ) + cell_state.store_model_extras(model_output) return cell_state # Create regular FireState without cell optimization - return FireState.from_state(state, **fire_attrs) + fire_state = FireState.from_state(state, **fire_attrs) + fire_state.store_model_extras(model_output) + return fire_state def fire_step( @@ -173,7 +176,7 @@ def fire_step( return step_func(state, **step_func_kwargs) # ty: ignore[invalid-argument-type] -def _vv_fire_step[T: "FireState | CellFireState"]( +def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state: T, model: "ModelInterface", *, @@ -215,6 +218,7 @@ def _vv_fire_step[T: "FireState | CellFireState"]( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): @@ -465,6 +469,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 6f940ff0f..7356ffe80 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -53,6 +53,8 @@ def gradient_descent_init( "stress": stress, } + state.store_model_extras(model_output) + if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) optim_attrs["reference_cell"] = state.cell.clone() @@ -112,6 +114,7 @@ def gradient_descent_step( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellOptimState): diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index f8413399c..cb88c6577 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -192,22 +192,10 @@ def lbfgs_init( if step_size_tensor.ndim == 0: step_size_tensor = step_size_tensor.expand(n_systems) - common_args = { - # Copy SimState attributes - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - # Optimization state + lbfgs_attrs = { "forces": forces, # [N, 3] "energy": energy, # [S] "stress": stress, # [S, 3, 3] or None - # L-BFGS specific state "prev_forces": forces.clone(), # [N, 3] "prev_positions": state.positions.clone(), # [N, 3] "s_history": s_history, # [S, 0, M, 3] @@ -227,41 +215,35 @@ def lbfgs_init( reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = deform_grad(reference_cell.mT, state.cell.mT) # [S, 3, 3] - # Initial fractional positions = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] -> [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) # [N, 3] - common_args["reference_cell"] = reference_cell # [S, 3, 3] - common_args["cell_filter"] = cell_filter_funcs - # Store fractional positions and scaled forces for ASE compatibility - common_args["prev_positions"] = frac_positions # [N, 3] - common_args["prev_forces"] = scaled_forces # [N, 3] + lbfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + lbfgs_attrs["cell_filter"] = cell_filter_funcs + lbfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + lbfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) # Extended per-system history includes cell DOFs (3 "virtual atoms" per system) - # History shape: [S, H, M+3, 3] where M = global_max_atoms extended_size_per_system = global_max_atoms + 3 # M_ext = M + 3 - common_args["s_history"] = torch.zeros( + lbfgs_attrs["s_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - common_args["y_history"] = torch.zeros( + lbfgs_attrs["y_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - cell_state = CellLBFGSState(**common_args) # ty: ignore[invalid-argument-type] + cell_state = CellLBFGSState.from_state(state, **lbfgs_attrs) # Initialize cell-specific attributes # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -271,9 +253,12 @@ def lbfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state - return LBFGSState(**common_args) # ty: ignore[invalid-argument-type] + lbfgs_state = LBFGSState.from_state(state, **lbfgs_attrs) + lbfgs_state.store_model_extras(model_output) + return lbfgs_state def lbfgs_step( # noqa: PLR0915, C901 @@ -536,6 +521,7 @@ def lbfgs_step( # noqa: PLR0915, C901 new_forces = model_output["forces"] # [N, 3] new_energy = model_output["energy"] # [S] new_stress = model_output.get("stress") # [S, 3, 3] or None + state.store_model_extras(model_output) # Update cell forces for next step: [S, 3, 3] if isinstance(state, CellLBFGSState): diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 72c25c070..7e66ca26d 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -22,7 +22,7 @@ from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState -from torch_sim.state import SimState +from torch_sim.state import _CANONICAL_MODEL_KEYS, SimState from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike from torch_sim.units import UnitSystem @@ -732,7 +732,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) -def static( +def static( # noqa: C901 system: StateLike, model: ModelInterface, *, @@ -836,8 +836,25 @@ class StaticState(SimState): else torch.full_like(sub_state.cell, fill_value=float("nan")) ), ) + static_state.store_model_extras(model_outputs) props = trajectory_reporter.report(static_state, 0, model=model) + + # Merge extra model outputs into per-system property dicts + # TODO: this should be cleaner? + extra_keys = {k for k in model_outputs if k not in _CANONICAL_MODEL_KEYS} + if extra_keys: + for sys_idx, sys_props in enumerate(props): + for key in extra_keys: + val = model_outputs[key] + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + if val.shape[0] == static_state.n_atoms: + mask = static_state.system_idx == sys_idx + sys_props[key] = val[mask] + elif val.shape[0] == static_state.n_systems: + sys_props[key] = val[sys_idx : sys_idx + 1] + all_props.extend(props) if tqdm_pbar: diff --git a/torch_sim/state.py b/torch_sim/state.py index 68b656e07..fde62b5aa 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -5,18 +5,19 @@ """ import copy +import functools import importlib import typing from collections import defaultdict from collections.abc import Generator, Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self import torch from torch._prims_common import DeviceLikeType import torch_sim as ts -from torch_sim.typing import PRNGLike, StateLike +from torch_sim.typing import ExtrasMap, PRNGLike, StateLike if TYPE_CHECKING: @@ -32,6 +33,10 @@ ) +# Canonical model output keys that are handled explicitly by integrators/runners +_CANONICAL_MODEL_KEYS = frozenset({"energy", "forces", "stress"}) + + def coerce_prng(rng: PRNGLike, device: DeviceLikeType | None) -> torch.Generator: """Coerce an int seed or existing Generator into a ``torch.Generator``. @@ -67,6 +72,27 @@ def require_system_idx(system_idx: torch.Tensor | None) -> torch.Tensor: return system_idx +def _wrap_init_for_extras(cls: type) -> None: + """Wrap a dataclass __init__ to route unknown kwargs into _system_extras.""" + original_init = cls.__init__ + all_fields = {f.name for f in fields(cls)} + + @functools.wraps(original_init) + def _wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: + extras = kwargs.get("_system_extras") + if extras is None: + extras = {} + kwargs["_system_extras"] = extras + unknown = [k for k in kwargs if k not in all_fields] + for key in unknown: + val = kwargs.pop(key) + if val is not None: + extras[key] = val + original_init(self, *args, **kwargs) + + cls.__init__ = _wrapped_init # type: ignore[assignment] + + @dataclass(kw_only=True) class SimState: """State representation for atomistic systems with batched operations support. @@ -129,10 +155,10 @@ class SimState: cell: torch.Tensor pbc: torch.Tensor # coerced from bool/list[bool] by __setattr__ atomic_numbers: torch.Tensor - charge: torch.Tensor | None = field(default=None) - spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ - _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 + _constraints: list["Constraint"] = field(default_factory=list) + _system_extras: dict[str, torch.Tensor] = field(default_factory=dict) + _atom_extras: dict[str, torch.Tensor] = field(default_factory=dict) _rng: PRNGLike = field(default=None, repr=False) if TYPE_CHECKING: @@ -145,11 +171,10 @@ def __init__( # noqa: D107 cell: torch.Tensor, pbc: torch.Tensor | list[bool] | bool, atomic_numbers: torch.Tensor, - charge: torch.Tensor | None = None, - spin: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, _constraints: list[Constraint] | None = None, _rng: PRNGLike = None, + **kwargs: Any, ) -> None: ... _atom_attributes: ClassVar[set[str]] = { @@ -158,7 +183,7 @@ def __init__( # noqa: D107 "atomic_numbers", "system_idx", } - _system_attributes: ClassVar[set[str]] = {"cell", "charge", "spin"} + _system_attributes: ClassVar[set[str]] = {"cell"} _global_attributes: ClassVar[set[str]] = {"pbc", "_rng"} @property @@ -171,8 +196,23 @@ def rng(self) -> torch.Generator: def rng(self, value: PRNGLike) -> None: self._rng = value - def __setattr__(self, name: str, value: object) -> None: - """Coerce pbc and system_idx on every assignment.""" + def __setattr__(self, name: str, value: object) -> None: # noqa: C901 + """Coerce pbc and system_idx on every assignment. + + Routes writes to existing extras keys back into their extras dict. + """ + if not name.startswith("_"): + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + if value is not None: + extras[name] = value + else: + del extras[name] + return if name == "pbc" and not isinstance(value, torch.Tensor): if isinstance(value, bool): value = [value] * 3 @@ -210,15 +250,6 @@ def __post_init__(self) -> None: # noqa: C901 if self.constraints: validate_constraints(self.constraints, state=self) - if self.charge is None: - self.charge = torch.zeros(n_systems, device=self.device, dtype=self.dtype) - elif self.charge.shape[0] != n_systems: - raise ValueError(f"Charge must have shape (n_systems={n_systems},)") - if self.spin is None: - self.spin = torch.zeros(n_systems, device=self.device, dtype=self.dtype) - elif self.spin.shape[0] != n_systems: - raise ValueError(f"Spin must have shape (n_systems={n_systems},)") - if self.cell.ndim != 3: self.cell = self.cell.unsqueeze(0) @@ -246,6 +277,29 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") + # Validate extras shapes and prevent shadowing + all_attrs = self._get_all_attributes() + for key, val in self._system_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"System extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"System extra '{key}' must be a torch.Tensor") + if val.shape[0] != n_systems: + raise ValueError( + f"System extra '{key}' leading dim must be " + f"n_systems={n_systems}, got {val.shape[0]}" + ) + for key, val in self._atom_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"Atom extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"Atom extra '{key}' must be a torch.Tensor") + if val.shape[0] != self.n_atoms: + raise ValueError( + f"Atom extra '{key}' leading dim must be " + f"n_atoms={self.n_atoms}, got {val.shape[0]}" + ) + @classmethod def _get_all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" @@ -253,9 +307,72 @@ def _get_all_attributes(cls) -> set[str]: cls._atom_attributes | cls._system_attributes | cls._global_attributes - | {"_constraints"} + | {"_constraints", "_system_extras", "_atom_extras"} ) + def __getattr__(self, name: str) -> Any: + """Allow attribute-style access to extras dict entries.""" + # Guard: don't look up private attrs in extras (avoids recursion during init) + if name.startswith("_"): + raise AttributeError(name) + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + return extras[name] + + # Raise AttributeError so that Python's getattr(obj, name, default), + # hasattr(obj, name), and other descriptor-protocol machinery work correctly. + raise AttributeError( + f"'{type(self).__name__}' has no attribute or extra '{name}'" + ) + + @property + def system_extras(self) -> dict[str, torch.Tensor]: + """Get the system extras.""" + return self._system_extras + + @property + def atom_extras(self) -> dict[str, torch.Tensor]: + """Get the atom extras.""" + return self._atom_extras + + def has_extras(self, key: str) -> bool: + """Check if an extras key exists.""" + return key in self._system_extras or key in self._atom_extras + + def store_model_extras(self, model_output: dict[str, torch.Tensor]) -> None: + """Store non-canonical model outputs into state extras (in-place). + + Any key in *model_output* that is not in ``{"energy", "forces", "stress"}`` + is classified by its leading dimension: + + * ``n_atoms`` → stored in ``_atom_extras`` + * ``n_systems`` → stored in ``_system_extras`` + * otherwise → skipped (ambiguity or scalar) + + When ``n_atoms == n_systems`` (single-atom system), the tensor is stored as + per-atom by convention. + + Args: + model_output: Full dict returned by ``model.forward()``. + """ + n_atoms = self.n_atoms + n_systems = self.n_systems + + for key, val in model_output.items(): + if key in _CANONICAL_MODEL_KEYS: + continue + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + leading = val.shape[0] + if leading == n_atoms: + self._atom_extras[key] = val + elif leading == n_systems: + self._system_extras[key] = val + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, @@ -486,18 +603,39 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: if attr_name in cls._get_all_attributes(): attrs[attr_name] = cls._clone_attr(attr_value) - # Add/override with additional attributes - attrs.update(additional_attrs) + # Route additional_attrs: known attrs go directly, unknown tensor attrs + # go to _system_extras (backward compat for charge/spin and extensibility) + all_known = cls._get_all_attributes() + for key, val in additional_attrs.items(): + if key in all_known: + attrs[key] = val + elif isinstance(val, torch.Tensor): + if "_system_extras" not in attrs: + attrs["_system_extras"] = {} + attrs["_system_extras"][key] = val + else: + attrs[key] = val return cls(**attrs) - def to_atoms(self) -> list["Atoms"]: + def to_atoms( + self, + *, + system_extras: ExtrasMap | None = None, + atom_extras: ExtrasMap | None = None, + ) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. + Args: + system_extras: Map of ``{ts_key: ase_key}`` for system extras. + atom_extras: Map of ``{ts_key: ase_key}`` for atom extras. + Returns: - list[Atoms]: A list of ASE Atoms objects, one per system + list[Atoms]: A list of ASE Atoms objects, one per system. """ - return ts.io.state_to_atoms(self) + return ts.io.state_to_atoms( + self, system_extras=system_extras, atom_extras=atom_extras + ) def to_structures(self) -> list["Structure"]: """Convert the SimState to a list of pymatgen Structure objects. @@ -606,7 +744,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: # exceptions exist because the type hint doesn't actually reflect the real type # (since we change their type in the post_init) - exceptions = {"system_idx", "charge", "spin"} + exceptions = {"system_idx"} type_hints = typing.get_type_hints(cls) for attr_name, attr_type_hint in type_hints.items(): @@ -684,6 +822,9 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: ) +_wrap_init_for_extras(SimState) + + @dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" @@ -769,7 +910,7 @@ def _normalize_system_indices( raise TypeError(f"Unsupported index type: {type(system_indices)}") -def _state_to_device[T: SimState]( +def _state_to_device[T: SimState]( # noqa: C901 state: T, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> T: """Convert the SimState to a new device and dtype. @@ -803,9 +944,23 @@ def _state_to_device[T: SimState]( elif isinstance(attr_value, torch.Generator): attrs[attr_name] = coerce_prng(attr_value, device) + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(device=device) for k, v in attrs[extras_key].items() + } + if dtype is not None: attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + # Update floating point extras to new dtype + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(dtype=dtype) if v.is_floating_point() else v + for k, v in attrs[extras_key].items() + } + if attrs.get("_constraints"): attrs["_constraints"] = [ c.to(device=device, dtype=dtype) for c in attrs["_constraints"] @@ -839,6 +994,11 @@ def get_attrs_for_scope( for attr_name in attr_names: yield attr_name, getattr(state, attr_name) + if scope == "per-system": + yield from state._system_extras.items() # noqa: SLF001 + elif scope == "per-atom": + yield from state._atom_extras.items() # noqa: SLF001 + def _filter_attrs_by_index( state: SimState, @@ -894,19 +1054,30 @@ def _filter_attrs_by_index( c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment] for name, val in get_attrs_for_scope(state, "per-atom"): + if name in state.atom_extras: + continue filtered_attrs[name] = ( system_remap[val[atom_indices]] if name == "system_idx" else val[atom_indices] ) for name, val in get_attrs_for_scope(state, "per-system"): + if name in state.system_extras: + continue filtered_attrs[name] = ( val[system_indices] if isinstance(val, torch.Tensor) else val ) + filtered_attrs["_system_extras"] = { + key: val[system_indices] for key, val in state.system_extras.items() + } + filtered_attrs["_atom_extras"] = { + key: val[atom_indices] for key, val in state.atom_extras.items() + } + return filtered_attrs -def _split_state[T: SimState](state: T) -> list[T]: +def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901 """Split a SimState into a list of states, each containing a single system. Divides a multi-system state into individual single-system states, preserving @@ -923,11 +1094,14 @@ def _split_state[T: SimState](state: T) -> list[T]: split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): - if attr_name != "system_idx": - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + if attr_name == "system_idx" or attr_name in state.atom_extras: + continue + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + if attr_name in state.system_extras: + continue if isinstance(attr_value, torch.Tensor): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) else: # Non-tensor attributes are replicated for each split @@ -935,6 +1109,14 @@ def _split_state[T: SimState](state: T) -> list[T]: global_attrs = dict(get_attrs_for_scope(state, "global")) + split_system_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._system_extras.items(): # noqa: SLF001 + split_system_extras[key] = list(torch.split(val, 1, dim=0)) + + split_atom_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._atom_extras.items(): # noqa: SLF001 + split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0)) + # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) @@ -961,6 +1143,12 @@ def _split_state[T: SimState](state: T) -> list[T]: **per_system_dict, # Add the global attributes **global_attrs, + "_system_extras": { + key: split_system_extras[key][sys_idx] for key in split_system_extras + }, + "_atom_extras": { + key: split_atom_extras[key][sys_idx] for key in split_atom_extras + }, } start_idx = int(cumsum_atoms[sys_idx].item()) @@ -1107,6 +1295,8 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + system_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) + atom_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) new_system_indices = [] system_offset = 0 num_atoms_per_state = [] @@ -1119,15 +1309,23 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): - if prop == "system_idx": + if prop == "system_idx" or prop in state.atom_extras: # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): + if prop in state.system_extras: + continue per_system_tensors[prop].append(val) + # Collect extras + for key, val in state.system_extras.items(): + system_extras_tensors[key].append(val) + for key, val in state.atom_extras.items(): + atom_extras_tensors[key].append(val) + # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset @@ -1198,6 +1396,14 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Concatenate extras + concatenated["_system_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in system_extras_tensors.items() + } + concatenated["_atom_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in atom_extras_tensors.items() + } + # Merge constraints constraint_lists = [state.constraints for state in states] num_systems_per_state = [state.n_systems for state in states] diff --git a/torch_sim/typing.py b/torch_sim/typing.py index ab4a74145..a2221e9e9 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -14,6 +14,35 @@ from torch_sim.state import SimState +class AtomExtras(StrEnum): + """Blessed names for per-atom :class:`~torch_sim.state.SimState` extras. + + Stored in ``SimState._atom_extras``; leading dimension is ``n_atoms``. + """ + + PARTIAL_CHARGES = "partial_charges" + BORN_EFFECTIVE_CHARGES = "born_effective_charges" + MAGNETIC_MOMENTS = "magnetic_moments" + + +class SystemExtras(StrEnum): + """Blessed names for per-system :class:`~torch_sim.state.SimState` extras. + + Stored in ``SimState._system_extras``; leading dimension is ``n_systems``. + """ + + CHARGE = "charge" # TOTAL_CHARGE preferred for less ambiguity with partial charges + SPIN = "spin" # TOTAL_SPIN preferred + TOTAL_CHARGE = "total_charge" + TOTAL_SPIN = "total_spin" + EXTERNAL_E_FIELD = "external_E_field" + POLARIZABILITY = "polarizability" + TOTAL_POLARIZATION = "total_polarization" + EXTERNAL_H_FIELD = "external_H_field" + MAGNETIC_SUSCEPTIBILITY = "magnetic_susceptibility" + TOTAL_MAGNETIZATION = "total_magnetization" + + class BravaisType(StrEnum): """Enumeration of the seven Bravais lattice types in 3D crystals. @@ -25,14 +54,16 @@ class BravaisType(StrEnum): which determine the number of independent elastic constants. """ - cubic = "cubic" - hexagonal = "hexagonal" - trigonal = "trigonal" - tetragonal = "tetragonal" - orthorhombic = "orthorhombic" - monoclinic = "monoclinic" - triclinic = "triclinic" + CUBIC = "cubic" + HEXAGONAL = "hexagonal" + TRIGONAL = "trigonal" + TETRAGONAL = "tetragonal" + ORTHORHOMBIC = "orthorhombic" + MONOCLINIC = "monoclinic" + TRICLINIC = "triclinic" + +ExtrasMap = dict[str, str] StateLike = Union[ "Atoms", diff --git a/torch_sim/units.py b/torch_sim/units.py index e59dfff61..19cd6c4fc 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -10,7 +10,7 @@ from typing import Self -class BaseConstant: +class BaseConstant(float, Enum): """CODATA Recommended Values of the Fundamental Physical Constants: 2014. References: @@ -18,6 +18,10 @@ class BaseConstant: https://wiki.fysik.dtu.dk/ase/_modules/ase/units.html#create_units """ + def __new__(cls, value: float) -> Self: + """Create new BaseConstant enum value.""" + return float.__new__(cls, value) + c = 299792458.0 # speed of light, m/s mu0 = 4.0e-7 * pi # permeability of vacuum grav = 6.67408e-11 # gravitational constant @@ -33,55 +37,30 @@ class BaseConstant: bc = BaseConstant -class UnitConversion: - """Unit conversion class for different unit systems. - - Distance: - Ang (Angstrom) - met (meter) +class UnitConversion(float, Enum): + """Unit conversion factors between common unit systems.""" - Time: - ps (picosecond) - s (second) - fs (femtosecond) - - Pressure: - atm (atmosphere) - pa (pascal) - bar (bar) - GPa (GigaPascal) - - Energy: - cal (calorie) - kcal (kilocalorie) - eV (electron volt) - """ + def __new__(cls, value: float) -> Self: + """Create new UnitConversion enum value.""" + return float.__new__(cls, value) - # Distance Ang_to_met = 1e-10 - Ang2_to_met2 = Ang_to_met * Ang_to_met - Ang3_to_met3 = Ang_to_met * Ang2_to_met2 - - # Time + Ang2_to_met2 = 1e-10**2 + Ang3_to_met3 = 1e-10**3 ps_to_s = 1e-12 fs_to_s = 1e-15 - - # Pressure bar_to_pa = 1e5 atm_to_pa = 101325 pa_to_GPa = 1e-9 - eV_per_Ang3_to_GPa = (bc.e / Ang3_to_met3) * pa_to_GPa - - # Energy + eV_per_Ang3_to_GPa = bc.e * 1e21 cal_to_J = 4.184 kcal_to_cal = 1e3 eV_to_J = bc.e - - # Atomic-unit conversions (Bohr / Hartree <-> Angstrom / eV) Bohr_to_Ang = 0.529177210903 Ang_to_Bohr = 1.0 / Bohr_to_Ang Hartree_to_eV = 27.211386245988 eV_to_Hartree = 1.0 / Hartree_to_eV + e2_per_Ang_to_eV = 14.399645478425668 uc = UnitConversion