From df440c5c4fc4375d57f88c9cd120760f2483e2b2 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 19:46:29 -0400 Subject: [PATCH 01/12] fix: ty now doesn't complain but a bunch of tests fail. --- tests/test_extras.py | 152 +++++++++++++++ tests/test_io.py | 84 ++++++++ torch_sim/integrators/md.py | 1 + torch_sim/integrators/npt.py | 19 +- torch_sim/integrators/nve.py | 5 +- torch_sim/integrators/nvt.py | 13 +- torch_sim/io.py | 64 ++++-- torch_sim/models/interface.py | 48 ++++- torch_sim/models/mace.py | 7 + torch_sim/monte_carlo.py | 5 +- torch_sim/neighbors/vesin.py | 4 +- torch_sim/optimizers/bfgs.py | 95 +++------ torch_sim/optimizers/fire.py | 9 +- torch_sim/optimizers/gradient_descent.py | 3 + torch_sim/optimizers/lbfgs.py | 40 ++-- torch_sim/runners.py | 21 +- torch_sim/state.py | 238 ++++++++++++++++++++--- 17 files changed, 661 insertions(+), 147 deletions(-) create mode 100644 tests/test_extras.py diff --git a/tests/test_extras.py b/tests/test_extras.py new file mode 100644 index 000000000..75947706a --- /dev/null +++ b/tests/test_extras.py @@ -0,0 +1,152 @@ +import pytest +import torch + +import torch_sim as ts + + +DEVICE = torch.device("cpu") +DTYPE = torch.float64 + + +class TestExtras: + def test_system_extras_construction(self): + """Extras can be passed at construction time.""" + field = torch.randn(1, 3) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + external_E_field=field, + ) + assert torch.equal(state.external_E_field, field) + + def test_atom_extras_construction(self): + """Per-atom extras work at construction time.""" + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + assert torch.equal(state.tags, tags) + + def test_getattr_missing_raises_attribute_error(self, cu_sim_state: ts.SimState): + with pytest.raises(AttributeError, match="nonexistent_key"): + _ = cu_sim_state.nonexistent_key + + def test_post_init_validation_rejects_bad_shape(self): + with pytest.raises(ValueError, match="leading dim must be n_systems"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"bad": torch.randn(5, 3)}, + ) + + def test_construction_extras_cannot_shadow(self): + # Post-init validation should also catch shadowing during construction + with pytest.raises(ValueError, match="shadows an existing attribute"): + ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"cell": torch.zeros(1, 3)}, + ) + + # store_model_extras + def test_store_model_extras_canonical_keys_not_stored( + self, si_double_sim_state: ts.SimState + ): + """Canonical keys (energy, forces, stress) must not land in extras.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "forces": torch.randn(state.n_atoms, 3), + "stress": torch.randn(state.n_systems, 3, 3), + } + ) + assert not state._system_extras # noqa: SLF001 + assert not state._atom_extras # noqa: SLF001 + + def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_systems go into system_extras.""" + state = si_double_sim_state.clone() + dipole = torch.randn(state.n_systems, 3) + state.store_model_extras( + {"energy": torch.randn(state.n_systems), "dipole": dipole} + ) + assert torch.equal(state.dipole, dipole) + + def test_store_model_extras_per_atom(self, si_double_sim_state: ts.SimState): + """Tensors with leading dim == n_atoms go into atom_extras.""" + state = si_double_sim_state.clone() + charges = torch.randn(state.n_atoms) + density = torch.randn(state.n_atoms, 8) + state.store_model_extras( + { + "energy": torch.randn(state.n_systems), + "charges": charges, + "density_coefficients": density, + } + ) + assert torch.equal(state.charges, charges) + assert state.density_coefficients.shape == (state.n_atoms, 8) + + def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState): + """0-d tensors and non-Tensor values are silently ignored.""" + state = si_double_sim_state.clone() + state.store_model_extras( + { + "scalar": torch.tensor(3.14), + "string": "not a tensor", + } + ) + assert not state.has_extras("scalar") + assert not state.has_extras("string") + + +def test_system_extras_atoms_roundtrip(): + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _system_extras={"external_E_field": torch.tensor([[1.0, 0.0, 0.0]])}, + ) + atoms_list = state.to_atoms() + assert "external_E_field" in atoms_list[0].info + restored = ts.io.atoms_to_state( + atoms_list, + system_extras_keys=["external_E_field"], + ) + assert torch.allclose(restored.external_E_field, state.external_E_field) + + +def test_atom_extras_atoms_roundtrip(): + tags = torch.tensor([1.0, 2.0]) + state = ts.SimState( + positions=torch.zeros(2, 3), + masses=torch.ones(2), + cell=torch.eye(3).unsqueeze(0), + pbc=True, + atomic_numbers=torch.tensor([1, 1], dtype=torch.int), + _atom_extras={"tags": tags}, + ) + atoms_list = state.to_atoms() + assert "tags" in atoms_list[0].arrays + restored = ts.io.atoms_to_state( + atoms_list, + atom_extras_keys=["tags"], + ) + assert torch.allclose(restored.tags, state.tags) diff --git a/tests/test_io.py b/tests/test_io.py index 2bb4f0175..8e1de1ac2 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -5,6 +5,7 @@ import pytest import torch from ase import Atoms +from ase.build import molecule from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure @@ -91,6 +92,69 @@ def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: ) +@pytest.mark.parametrize( + ("charge", "spin", "expected_charge", "expected_spin"), + [ + (1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin + (0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin + (None, None, 0.0, 0.0), # No charge/spin set, defaults to zero + ], +) +def test_atoms_to_state_with_charge_spin( + charge: float | None, + spin: float | None, + expected_charge: float, + expected_spin: float, +) -> None: + """Test conversion from ASE Atoms with charge and spin to state tensors.""" + mol = molecule("H2O") + if charge is not None: + mol.info["charge"] = charge + if spin is not None: + mol.info["spin"] = spin + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (1,) + assert state.spin.shape == (1,) + assert state.charge[0].item() == expected_charge + assert state.spin[0].item() == expected_spin + + +def test_multiple_atoms_to_state_with_charge_spin() -> None: + """Test conversion from multiple ASE Atoms with different charge/spin values.""" + mol1 = molecule("H2O") + mol1.info["charge"] = 1.0 + mol1.info["spin"] = 1.0 + + mol2 = molecule("CH4") + mol2.info["charge"] = -1.0 + mol2.info["spin"] = 0.0 + + mol3 = molecule("NH3") + mol3.info["charge"] = 0.0 + mol3.info["spin"] = 2.0 + + state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE) + + # Check basic properties + assert isinstance(state, SimState) + assert state.charge is not None + assert state.spin is not None + assert state.charge.shape == (3,) + assert state.spin.shape == (3,) + assert state.charge[0].item() == 1.0 + assert state.charge[1].item() == -1.0 + assert state.charge[2].item() == 0.0 + assert state.spin[0].item() == 1.0 + assert state.spin[1].item() == 0.0 + assert state.spin[2].item() == 2.0 + + def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: """Test conversion from state tensors to list of pymatgen Structure.""" structures = ts.io.state_to_structures(ar_supercell_sim_state) @@ -117,6 +181,23 @@ def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: assert len(atoms[0]) == 32 +def test_state_to_atoms_with_charge_spin() -> None: + """Test conversion from state with charge/spin to ASE Atoms preserves charge/spin.""" + mol = molecule("H2O") + mol.info["charge"] = 1.0 + mol.info["spin"] = 1.0 + + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) + atoms = ts.io.state_to_atoms(state) + + assert len(atoms) == 1 + assert isinstance(atoms[0], Atoms) + assert "charge" in atoms[0].info + assert "spin" in atoms[0].info + assert atoms[0].info["charge"] == 1 + assert atoms[0].info["spin"] == 1 + + def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: """Test conversion from state tensors to list of ASE Atoms.""" atoms = ts.io.state_to_atoms(ar_double_sim_state) @@ -259,6 +340,9 @@ def test_state_round_trip( # since both use their own isotope masses based on species, # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses) + # Check charge/spin round trip + assert torch.allclose(sim_state.charge, round_trip_state.charge) + assert torch.allclose(sim_state.spin, round_trip_state.spin) def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 76ff69c1f..49d9b62a8 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -233,6 +233,7 @@ def velocity_verlet_step[T: MDState]( state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_2) diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 4a86084c1..eaa5b8128 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -633,7 +633,7 @@ def npt_langevin_init( logger.warning(msg) # Create the initial state - return NPTLangevinState.from_state( + npt_state = NPTLangevinState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -647,6 +647,8 @@ def npt_langevin_init( cell_masses=cell_masses, cell_alpha=cell_alpha, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1063/1.4901303") @@ -708,6 +710,7 @@ def npt_langevin_step( model_output = model(state) state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Store initial values for integration forces = state.forces @@ -747,6 +750,7 @@ def npt_langevin_step( state.energy = model_output["energy"] state.forces = model_output["forces"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Compute updated pressure force F_p_n_new = _compute_cell_force( @@ -1291,6 +1295,7 @@ def _npt_nose_hoover_inner_step( state.forces = model_output["forces"] state.stress = model_output["stress"] state.energy = model_output["energy"] + state.store_model_extras(model_output) state.cell_position = cell_position state.cell_momentum = cell_momentum state.cell_mass = cell_mass @@ -1444,7 +1449,7 @@ def npt_nose_hoover_init( logger.warning(msg) # Create initial state - return NPTNoseHooverState.from_state( + npt_state = NPTNoseHooverState.from_state( state, momenta=momenta, energy=energy, @@ -1460,6 +1465,8 @@ def npt_nose_hoover_init( barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) + npt_state.store_model_extras(model_output) + return npt_state @dcite("10.1080/00268979600100761") @@ -2082,6 +2089,7 @@ def npt_crescale_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt_tensor / 2) @@ -2157,6 +2165,7 @@ def npt_crescale_independent_lengths_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2233,6 +2242,7 @@ def npt_crescale_average_anisotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2310,6 +2320,7 @@ def npt_crescale_isotropic_step( state.forces = model_output["forces"] state.energy = model_output["energy"] state.stress = model_output["stress"] + state.store_model_extras(model_output) # Final momentum step state = momentum_step(state, dt / 2) @@ -2383,7 +2394,7 @@ def npt_crescale_init( ) # Create the initial state - return NPTCRescaleState.from_state( + npt_state = NPTCRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -2392,3 +2403,5 @@ def npt_crescale_init( tau_p=tau_p, isothermal_compressibility=isothermal_compressibility, ) + npt_state.store_model_extras(model_output) + return npt_state diff --git a/torch_sim/integrators/nve.py b/torch_sim/integrators/nve.py index 07f3064bb..316ef78c7 100644 --- a/torch_sim/integrators/nve.py +++ b/torch_sim/integrators/nve.py @@ -57,12 +57,14 @@ def nve_init( state.rng, ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state def nve_step( @@ -100,5 +102,6 @@ def nve_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt / 2) diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 8e74bf855..841399b4a 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -126,12 +126,14 @@ def nvt_langevin_init( kT, state.rng, ) - return MDState.from_state( + md_state = MDState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + md_state.store_model_extras(model_output) + return md_state @dcite("10.1098/rspa.2016.0138") @@ -191,6 +193,7 @@ def nvt_langevin_step( model_output = model(state) state.energy = model_output["energy"] state.forces = model_output["forces"] + state.store_model_extras(model_output) return momentum_step(state, dt_tensor / 2) @@ -321,7 +324,7 @@ def nvt_nose_hoover_init( ) # n_atoms * n_dimensions # Initialize state - return NVTNoseHooverState.from_state( + nh_state = NVTNoseHooverState.from_state( state, momenta=momenta, energy=model_output["energy"], @@ -330,6 +333,8 @@ def nvt_nose_hoover_init( chain=chain_fns.initialize(dof_per_system, KE, kT_tensor), _chain_fns=chain_fns, ) + nh_state.store_model_extras(model_output) + return nh_state @dcite("10.1080/00268979600100761") @@ -609,12 +614,14 @@ def nvt_vrescale_init( state.rng, ) - return NVTVRescaleState.from_state( + vr_state = NVTVRescaleState.from_state( state, momenta=momenta, energy=model_output["energy"], forces=model_output["forces"], ) + vr_state.store_model_extras(model_output) + return vr_state @dcite("10.1063/1.2408420") diff --git a/torch_sim/io.py b/torch_sim/io.py index edce091e1..6044c8136 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -32,11 +32,17 @@ description="ASE: Atomic Simulation Environment", path="ase", ) -def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: +def state_to_atoms( + state: "ts.SimState", + system_extras_keys: list[str] | None = None, + atom_extras_keys: list[str] | None = None, +) -> list["Atoms"]: """Convert a SimState to a list of ASE Atoms objects. Args: state (SimState): Batched state containing positions, cell, and atomic numbers + system_extras_keys: Keys for per-system extras to include in atoms.info + atom_extras_keys: Keys for per-atom extras to include in atoms.arrays Returns: list[Atoms]: ASE Atoms objects, one per system @@ -70,10 +76,6 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: else np.array([state.pbc] * 3 if isinstance(state.pbc, bool) else state.pbc) ) - # Extract charge and spin if available (per-system attributes) - charge = state.charge.detach().cpu().numpy() if state.charge is not None else None - spin = state.spin.detach().cpu().numpy() if state.spin is not None else None - atoms_list = [] for sys_idx in np.unique(system_indices): mask = system_indices == sys_idx @@ -91,11 +93,18 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: symbols=symbols, positions=system_positions, cell=system_cell, pbc=pbc_for_sys ) - # Preserve charge and spin in atoms.info (as integers for FairChem compatibility) - if charge is not None: - atoms.info["charge"] = int(charge[sys_idx].item()) - if spin is not None: - atoms.info["spin"] = int(spin[sys_idx].item()) + # Write system extras to atoms.info + # charge/spin stored as int scalars for FairChem compatibility + if system_extras_keys is not None: + for key in system_extras_keys: + val = state.system_extras[key][sys_idx].detach().cpu().numpy() + atoms.info[key] = val + + # Write atom extras to atoms.arrays + if atom_extras_keys is not None: + for key in atom_extras_keys: + val = state.atom_extras[key][mask].detach().cpu().numpy() + atoms.arrays[key] = val atoms_list.append(atoms) @@ -244,6 +253,8 @@ def atoms_to_state( atoms: "Atoms | list[Atoms]", device: torch.device | None = None, dtype: torch.dtype | None = None, + system_extras_keys: list[str] | None = None, + atom_extras_keys: list[str] | None = None, ) -> "ts.SimState": """Convert an ASE Atoms object or list of Atoms objects to a SimState. @@ -252,6 +263,10 @@ def atoms_to_state( device (torch.device): Device to create tensors on dtype (torch.dtype): Data type for tensors (typically torch.float32 or torch.float64) + system_extras_keys (list[str]): Optional list of keys to read from atoms.info + into _system_extras + atom_extras_keys (list[str]): Optional list of keys to read from atoms.arrays + into _atom_extras Returns: SimState: TorchSim SimState object. @@ -298,12 +313,25 @@ def atoms_to_state( if not all(np.all(np.equal(at.pbc, atoms_list[0].pbc)) for at in atoms_list[1:]): raise ValueError("All systems must have the same periodic boundary conditions") - charge = torch.tensor( - [at.info.get("charge", 0.0) for at in atoms_list], dtype=dtype, device=device - ) - spin = torch.tensor( - [at.info.get("spin", 0.0) for at in atoms_list], dtype=dtype, device=device - ) + _system_extras: dict[str, torch.Tensor] = {} + if system_extras_keys: + for key in system_extras_keys: + vals = [at.info.get(key) for at in atoms_list] + non_none_vals = [v for v in vals if v is not None] + if len(non_none_vals) == len(vals): + _system_extras[key] = torch.tensor( + np.stack(non_none_vals), dtype=dtype, device=device + ) + + _atom_extras: dict[str, torch.Tensor] = {} + if atom_extras_keys: + for key in atom_extras_keys: + arrays = [at.arrays.get(key) for at in atoms_list] + non_none_arrays = [a for a in arrays if a is not None] + if len(non_none_arrays) == len(arrays): + _atom_extras[key] = torch.tensor( + np.concatenate(non_none_arrays), dtype=dtype, device=device + ) return ts.SimState( positions=positions, @@ -312,8 +340,8 @@ def atoms_to_state( pbc=atoms_list[0].pbc, atomic_numbers=atomic_numbers, system_idx=system_idx, - charge=charge, - spin=spin, + _system_extras=_system_extras, + _atom_extras=_atom_extras, ) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index 8aa6bb5e6..d09878dec 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -243,7 +243,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 and primitive BCC iron) for validation. It tests both single and multi-batch processing capabilities. """ - from ase.build import bulk + from ase.build import bulk, molecule for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): @@ -273,6 +273,8 @@ def validate_model_outputs( # noqa: C901, PLR0915 system_idx = sim_state.system_idx og_system_idx = system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() + og_charge = sim_state.charge.clone() + og_spin = sim_state.spin.clone() if check_detached and hasattr(model, "retain_graph"): model.__dict__["retain_graph"] = True @@ -293,6 +295,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") + if not torch.allclose(og_charge, sim_state.charge): + raise ValueError(f"{og_charge=} != {sim_state.charge=}") + if not torch.allclose(og_spin, sim_state.spin): + raise ValueError(f"{og_spin=} != {sim_state.spin=}") # assert model output has the correct keys if "energy" not in model_output: @@ -407,3 +413,43 @@ def validate_model_outputs( # noqa: C901, PLR0915 "vector: max diff = " f"{(shifted_output['stress'] - si_model_output['stress']).abs().max()}" ) + + # Test that models can handle non-zero charge and spin + benzene_atoms = molecule("C6H6") + benzene_atoms.info["charge"] = 1.0 + benzene_atoms.info["spin"] = 1.0 + charged_state = ts.io.atoms_to_state([benzene_atoms], device, dtype) + + # Ensure state has charge/spin before testing model + if charged_state.charge is None or charged_state.spin is None: + raise ValueError( + "atoms_to_state did not extract charge/spin. " + "Cannot test model charge/spin handling." + ) + + # Test that model can handle charge/spin without crashing + og_charged_charge = charged_state.charge.clone() + og_charged_spin = charged_state.spin.clone() + try: + charged_output = model.forward(charged_state) + except Exception as e: + raise ValueError( + "Model failed to handle non-zero charge/spin. " + "Models must be able to process states with charge and spin values. " + ) from e + + # Verify model didn't mutate charge/spin + if not torch.allclose(og_charged_charge, charged_state.charge): + raise ValueError( + f"Model mutated charge: {og_charged_charge=} != {charged_state.charge=}" + ) + if not torch.allclose(og_charged_spin, charged_state.spin): + raise ValueError( + f"Model mutated spin: {og_charged_spin=} != {charged_state.spin=}" + ) + # Verify output shape is still correct + if charged_output["energy"].shape != (1,): + raise ValueError( + f"energy shape incorrect with charge/spin: " + f"{charged_output['energy'].shape=} != (1,)" + ) diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index e13e3ce85..75078670a 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -336,6 +336,13 @@ def forward( # noqa: C901 if stress is not None: results["stress"] = stress.detach() + # Propagate additional model outputs (e.g. dipole, charges, etc.) + for key, val in out.items(): + if key not in ("energy", "forces", "stress") and isinstance( + val, torch.Tensor + ): + results[key] = val.detach() + return results diff --git a/torch_sim/monte_carlo.py b/torch_sim/monte_carlo.py index 04dfde316..8a4a0d37c 100644 --- a/torch_sim/monte_carlo.py +++ b/torch_sim/monte_carlo.py @@ -223,7 +223,7 @@ def swap_mc_init( """ model_output = model(state) - return SwapMCState( + mc_state = SwapMCState( positions=state.positions, masses=state.masses, cell=state.cell, @@ -233,6 +233,8 @@ def swap_mc_init( energy=model_output["energy"], _constraints=state.constraints, ) + mc_state.store_model_extras(model_output) + return mc_state def swap_mc_step( @@ -292,5 +294,6 @@ def swap_mc_step( state.energy = torch.where(accepted, energies_new, energies_old) state.last_permutation = permutation[reverse_rejected_swaps].clone() + state.store_model_extras(model_output) return state diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index 009fe9bbd..16648950a 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -12,13 +12,13 @@ try: from vesin import NeighborList as VesinNeighborList except ImportError: - VesinNeighborList = None # type: ignore[assignment] + VesinNeighborList = None try: from vesin.torch import NeighborList as VesinNeighborListTorch except ImportError: - VesinNeighborListTorch = None # ty:ignore[invalid-assignment] + VesinNeighborListTorch = None VESIN_AVAILABLE = VesinNeighborList is not None VESIN_TORCHSCRIPT_AVAILABLE = VesinNeighborListTorch is not None diff --git a/torch_sim/optimizers/bfgs.py b/torch_sim/optimizers/bfgs.py index c3a344cf3..c9ada4992 100644 --- a/torch_sim/optimizers/bfgs.py +++ b/torch_sim/optimizers/bfgs.py @@ -25,13 +25,13 @@ _clamp_deform_grad_log, frechet_cell_filter_init, ) +from torch_sim.optimizers.state import BFGSState from torch_sim.state import SimState if TYPE_CHECKING: from torch_sim.models.interface import ModelInterface from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs - from torch_sim.optimizers.state import BFGSState BFGS_EPS = 1e-7 # eps kept same as ASE's BFGS. @@ -115,8 +115,6 @@ def bfgs_init( Returns: BFGSState or CellBFGSState if cell_filter is provided """ - from torch_sim.optimizers import BFGSState, CellBFGSState - device: torch.device = model.device dtype: torch.dtype = model.dtype @@ -143,6 +141,19 @@ def bfgs_init( n_iter = torch.zeros((n_systems,), device=model.device, dtype=torch.int32) # [S] + bfgs_attrs = { + "forces": forces, # [N, 3] + "energy": energy, # [S] + "stress": stress, # [S, 3, 3] or None + "prev_forces": forces.clone(), # [N, 3] + "prev_positions": state.positions.clone(), # [N, 3] + "alpha": alpha_t, # [S] + "max_step": max_step_t, # [S] + "n_iter": n_iter, # [S] + "atom_idx_in_system": atom_idx, # [N] + "max_atoms": max_atoms, # [S] + } + if cell_filter is not None: # Extended Hessian: (3*global_max_atoms + 9) x (3*global_max_atoms + 9) # The extra 9 DOFs are for cell parameters (3x3 matrix flattened) @@ -153,59 +164,31 @@ def bfgs_init( cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) - # Note (AG): At initialization, deform_grad is identity, so we have: - # fractional = Cartesian / cell and scaled forces = forces @ I = forces - # For ASE compatibility, we need to store prev_positions as fractional coords - # and prev_forces as scaled forces - - # Get initial deform_grad (identity at start since reference_cell = current_cell) + # At initialization, deform_grad is identity, so fractional = Cartesian + # and scaled forces = forces. For ASE compatibility, store prev_positions + # as fractional coords and prev_forces as scaled forces. reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = cell_filters.deform_grad( reference_cell.mT, state.cell.mT ) # [S, 3, 3] - # Initial fractional positions = solve(deform_grad, positions) = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D_ext, D_ext] - # Note (AG): Store fractional positions and scaled forces - # for ASE compatibility - "prev_forces": scaled_forces, # [N, 3] (scaled) - "prev_positions": frac_positions, # [N, 3] (fractional) - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "reference_cell": reference_cell, # [S, 3, 3] - "cell_filter": cell_filter_funcs, - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - cell_state = CellBFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_attrs["hessian"] = hessian # [S, D_ext, D_ext] + bfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) + bfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + bfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + bfgs_attrs["cell_filter"] = cell_filter_funcs + + cell_state = CellBFGSState.from_state(state, **bfgs_attrs) # Initialize cell-specific attributes (cell_positions, cell_forces, etc.) # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -215,6 +198,7 @@ def bfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state # Position-only Hessian: 3*global_max_atoms x 3*global_max_atoms @@ -222,31 +206,11 @@ def bfgs_init( hessian = torch.eye(dim, device=device, dtype=dtype).unsqueeze(0).repeat( n_systems, 1, 1 ) * alpha_t.view(n_systems, 1, 1) # [S, D, D] + bfgs_attrs["hessian"] = hessian # [S, D, D] - common_args = { - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "forces": forces, # [N, 3] - "energy": energy, # [S] - "stress": stress, # [S, 3, 3] or None - "hessian": hessian, # [S, D, D] - "prev_forces": forces.clone(), # [N, 3] - "prev_positions": state.positions.clone(), # [N, 3] - "alpha": alpha_t, # [S] - "max_step": max_step_t, # [S] - "n_iter": n_iter, # [S] - "atom_idx_in_system": atom_idx, # [N] - "max_atoms": max_atoms, # scalar M - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - } - - return BFGSState(**common_args) # ty: ignore[invalid-argument-type] + bfgs_state = BFGSState.from_state(state, **bfgs_attrs) + bfgs_state.store_model_extras(model_output) + return bfgs_state def bfgs_step( # noqa: C901, PLR0915 @@ -550,6 +514,7 @@ def bfgs_step( # noqa: C901, PLR0915 state.energy = model_output["energy"] # [S] if "stress" in model_output: state.stress = model_output["stress"] # [S, 3, 3] + state.store_model_extras(model_output) # Update cell forces for next step # Update cell forces for cell state: [S, 3, 3] diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 8efcb3a7b..e45c7ec8d 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -106,9 +106,12 @@ def fire_init( cell_state.cell_forces.shape, torch.nan, device=device, dtype=dtype ) + cell_state.store_model_extras(model_output) return cell_state # Create regular FireState without cell optimization - return FireState.from_state(state, **fire_attrs) + fire_state = FireState.from_state(state, **fire_attrs) + fire_state.store_model_extras(model_output) + return fire_state def fire_step( @@ -173,7 +176,7 @@ def fire_step( return step_func(state, **step_func_kwargs) # ty: ignore[invalid-argument-type] -def _vv_fire_step[T: "FireState | CellFireState"]( +def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state: T, model: "ModelInterface", *, @@ -215,6 +218,7 @@ def _vv_fire_step[T: "FireState | CellFireState"]( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): @@ -465,6 +469,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellFireState): diff --git a/torch_sim/optimizers/gradient_descent.py b/torch_sim/optimizers/gradient_descent.py index 6f940ff0f..7356ffe80 100644 --- a/torch_sim/optimizers/gradient_descent.py +++ b/torch_sim/optimizers/gradient_descent.py @@ -53,6 +53,8 @@ def gradient_descent_init( "stress": stress, } + state.store_model_extras(model_output) + if cell_filter is not None: # Create cell optimization state cell_filter_funcs = init_fn, _step_fn = ts.get_cell_filter(cell_filter) optim_attrs["reference_cell"] = state.cell.clone() @@ -112,6 +114,7 @@ def gradient_descent_step( state.energy = model_output["energy"] if "stress" in model_output: state.stress = model_output["stress"] + state.store_model_extras(model_output) # Update cell forces if isinstance(state, CellOptimState): diff --git a/torch_sim/optimizers/lbfgs.py b/torch_sim/optimizers/lbfgs.py index f8413399c..cb88c6577 100644 --- a/torch_sim/optimizers/lbfgs.py +++ b/torch_sim/optimizers/lbfgs.py @@ -192,22 +192,10 @@ def lbfgs_init( if step_size_tensor.ndim == 0: step_size_tensor = step_size_tensor.expand(n_systems) - common_args = { - # Copy SimState attributes - "positions": state.positions.clone(), # [N, 3] - "masses": state.masses.clone(), # [N] - "cell": state.cell.clone(), # [S, 3, 3] - "atomic_numbers": state.atomic_numbers.clone(), # [N] - "system_idx": state.system_idx.clone(), # [N] - "pbc": state.pbc, # [S, 3] - "charge": state.charge, # preserve charge - "spin": state.spin, # preserve spin - "_constraints": state.constraints, # preserve constraints - # Optimization state + lbfgs_attrs = { "forces": forces, # [N, 3] "energy": energy, # [S] "stress": stress, # [S, 3, 3] or None - # L-BFGS specific state "prev_forces": forces.clone(), # [N, 3] "prev_positions": state.positions.clone(), # [N, 3] "s_history": s_history, # [S, 0, M, 3] @@ -227,41 +215,35 @@ def lbfgs_init( reference_cell = state.cell.clone() # [S, 3, 3] cur_deform_grad = deform_grad(reference_cell.mT, state.cell.mT) # [S, 3, 3] - # Initial fractional positions = positions - # cur_deform_grad[system_idx]: [N, 3, 3], positions: [N, 3] -> [N, 3] frac_positions = torch.linalg.solve( cur_deform_grad[state.system_idx], # [N, 3, 3] state.positions.unsqueeze(-1), # [N, 3, 1] ).squeeze(-1) # [N, 3] - # Initial scaled forces = forces @ deform_grad = forces - # forces: [N, 3], cur_deform_grad[system_idx]: [N, 3, 3] -> [N, 3] scaled_forces = torch.bmm( forces.unsqueeze(1), # [N, 1, 3] cur_deform_grad[state.system_idx], # [N, 3, 3] ).squeeze(1) # [N, 3] - common_args["reference_cell"] = reference_cell # [S, 3, 3] - common_args["cell_filter"] = cell_filter_funcs - # Store fractional positions and scaled forces for ASE compatibility - common_args["prev_positions"] = frac_positions # [N, 3] - common_args["prev_forces"] = scaled_forces # [N, 3] + lbfgs_attrs["reference_cell"] = reference_cell # [S, 3, 3] + lbfgs_attrs["cell_filter"] = cell_filter_funcs + lbfgs_attrs["prev_positions"] = frac_positions # [N, 3] (fractional) + lbfgs_attrs["prev_forces"] = scaled_forces # [N, 3] (scaled) # Extended per-system history includes cell DOFs (3 "virtual atoms" per system) - # History shape: [S, H, M+3, 3] where M = global_max_atoms extended_size_per_system = global_max_atoms + 3 # M_ext = M + 3 - common_args["s_history"] = torch.zeros( + lbfgs_attrs["s_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - common_args["y_history"] = torch.zeros( + lbfgs_attrs["y_history"] = torch.zeros( (n_systems, 0, extended_size_per_system, 3), device=device, dtype=dtype, ) # [S, 0, M_ext, 3] - cell_state = CellLBFGSState(**common_args) # ty: ignore[invalid-argument-type] + cell_state = CellLBFGSState.from_state(state, **lbfgs_attrs) # Initialize cell-specific attributes # After init: cell_positions [S, 3, 3], cell_forces [S, 3, 3], cell_factor [S] @@ -271,9 +253,12 @@ def lbfgs_init( cell_state.prev_cell_positions = cell_state.cell_positions.clone() # [S, 3, 3] cell_state.prev_cell_forces = cell_state.cell_forces.clone() # [S, 3, 3] + cell_state.store_model_extras(model_output) return cell_state - return LBFGSState(**common_args) # ty: ignore[invalid-argument-type] + lbfgs_state = LBFGSState.from_state(state, **lbfgs_attrs) + lbfgs_state.store_model_extras(model_output) + return lbfgs_state def lbfgs_step( # noqa: PLR0915, C901 @@ -536,6 +521,7 @@ def lbfgs_step( # noqa: PLR0915, C901 new_forces = model_output["forces"] # [N, 3] new_energy = model_output["energy"] # [S] new_stress = model_output.get("stress") # [S, 3, 3] or None + state.store_model_extras(model_output) # Update cell forces for next step: [S, 3, 3] if isinstance(state, CellLBFGSState): diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 72c25c070..7e66ca26d 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -22,7 +22,7 @@ from torch_sim.integrators.md import MDState from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import OPTIM_REGISTRY, FireState, Optimizer, OptimState -from torch_sim.state import SimState +from torch_sim.state import _CANONICAL_MODEL_KEYS, SimState from torch_sim.trajectory import TrajectoryReporter from torch_sim.typing import StateLike from torch_sim.units import UnitSystem @@ -732,7 +732,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915 ) -def static( +def static( # noqa: C901 system: StateLike, model: ModelInterface, *, @@ -836,8 +836,25 @@ class StaticState(SimState): else torch.full_like(sub_state.cell, fill_value=float("nan")) ), ) + static_state.store_model_extras(model_outputs) props = trajectory_reporter.report(static_state, 0, model=model) + + # Merge extra model outputs into per-system property dicts + # TODO: this should be cleaner? + extra_keys = {k for k in model_outputs if k not in _CANONICAL_MODEL_KEYS} + if extra_keys: + for sys_idx, sys_props in enumerate(props): + for key in extra_keys: + val = model_outputs[key] + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + if val.shape[0] == static_state.n_atoms: + mask = static_state.system_idx == sys_idx + sys_props[key] = val[mask] + elif val.shape[0] == static_state.n_systems: + sys_props[key] = val[sys_idx : sys_idx + 1] + all_props.extend(props) if tqdm_pbar: diff --git a/torch_sim/state.py b/torch_sim/state.py index 390b55fd3..e9079d91a 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -5,11 +5,12 @@ """ import copy +import functools import importlib import typing from collections import defaultdict from collections.abc import Generator, Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self import torch @@ -32,6 +33,10 @@ ) +# Canonical model output keys that are handled explicitly by integrators/runners +_CANONICAL_MODEL_KEYS = frozenset({"energy", "forces", "stress"}) + + def coerce_prng(rng: PRNGLike, device: DeviceLikeType | None) -> torch.Generator: """Coerce an int seed or existing Generator into a ``torch.Generator``. @@ -67,6 +72,30 @@ def require_system_idx(system_idx: torch.Tensor | None) -> torch.Tensor: return system_idx +_EXTRAS_COMPAT_KEYS = frozenset({"charge", "spin"}) + + +def _wrap_init_for_extras(cls: type) -> None: + """Wrap a dataclass __init__ to route unknown kwargs into _system_extras.""" + original_init = cls.__init__ + all_fields = {f.name for f in fields(cls)} + + @functools.wraps(original_init) + def _wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: + extras = kwargs.get("_system_extras") + if extras is None: + extras = {} + kwargs["_system_extras"] = extras + unknown = [k for k in kwargs if k not in all_fields] + for key in unknown: + val = kwargs.pop(key) + if val is not None: + extras[key] = val + original_init(self, *args, **kwargs) + + cls.__init__ = _wrapped_init # type: ignore[assignment] + + @dataclass(kw_only=True) class SimState: """State representation for atomistic systems with batched operations support. @@ -129,10 +158,10 @@ class SimState: cell: torch.Tensor pbc: torch.Tensor # coerced from bool/list[bool] by __setattr__ atomic_numbers: torch.Tensor - charge: torch.Tensor | None = field(default=None) - spin: torch.Tensor | None = field(default=None) system_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ - _constraints: list["Constraint"] = field(default_factory=lambda: []) # noqa: PIE807 + _constraints: list["Constraint"] = field(default_factory=list) + _system_extras: dict[str, torch.Tensor] = field(default_factory=dict) + _atom_extras: dict[str, torch.Tensor] = field(default_factory=dict) _rng: PRNGLike = field(default=None, repr=False) if TYPE_CHECKING: @@ -145,11 +174,10 @@ def __init__( # noqa: D107 cell: torch.Tensor, pbc: torch.Tensor | list[bool] | bool, atomic_numbers: torch.Tensor, - charge: torch.Tensor | None = None, - spin: torch.Tensor | None = None, system_idx: torch.Tensor | None = None, _constraints: list[Constraint] | None = None, _rng: PRNGLike = None, + **kwargs: Any, ) -> None: ... _atom_attributes: ClassVar[set[str]] = { @@ -158,7 +186,7 @@ def __init__( # noqa: D107 "atomic_numbers", "system_idx", } - _system_attributes: ClassVar[set[str]] = {"cell", "charge", "spin"} + _system_attributes: ClassVar[set[str]] = {"cell"} _global_attributes: ClassVar[set[str]] = {"pbc", "_rng"} @property @@ -171,8 +199,20 @@ def rng(self) -> torch.Generator: def rng(self, value: PRNGLike) -> None: self._rng = value - def __setattr__(self, name: str, value: object) -> None: - """Coerce pbc and system_idx on every assignment.""" + def __setattr__(self, name: str, value: object) -> None: # noqa: C901 + """Coerce pbc and system_idx on every assignment. + + Routes charge/spin writes to _system_extras for backward compatibility. + """ + if name in _EXTRAS_COMPAT_KEYS: + try: + extras = object.__getattribute__(self, "_system_extras") + except AttributeError: + extras = {} + super().__setattr__("_system_extras", extras) + if value is not None: + extras[name] = value + return if name == "pbc" and not isinstance(value, torch.Tensor): if isinstance(value, bool): value = [value] * 3 @@ -210,14 +250,14 @@ def __post_init__(self) -> None: # noqa: C901 if self.constraints: validate_constraints(self.constraints, state=self) - if self.charge is None: - self.charge = torch.zeros(n_systems, device=self.device, dtype=self.dtype) - elif self.charge.shape[0] != n_systems: - raise ValueError(f"Charge must have shape (n_systems={n_systems},)") - if self.spin is None: - self.spin = torch.zeros(n_systems, device=self.device, dtype=self.dtype) - elif self.spin.shape[0] != n_systems: - raise ValueError(f"Spin must have shape (n_systems={n_systems},)") + if "charge" not in self._system_extras: + self._system_extras["charge"] = torch.zeros( + n_systems, device=self.device, dtype=self.dtype + ) + if "spin" not in self._system_extras: + self._system_extras["spin"] = torch.zeros( + n_systems, device=self.device, dtype=self.dtype + ) if self.cell.ndim != 3: self.cell = self.cell.unsqueeze(0) @@ -246,6 +286,29 @@ def __post_init__(self) -> None: # noqa: C901 if len(set(devices.values())) > 1: raise ValueError("All tensors must be on the same device") + # Validate extras shapes and prevent shadowing + all_attrs = self._get_all_attributes() + for key, val in self._system_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"System extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"System extra '{key}' must be a torch.Tensor") + if val.shape[0] != n_systems: + raise ValueError( + f"System extra '{key}' leading dim must be " + f"n_systems={n_systems}, got {val.shape[0]}" + ) + for key, val in self._atom_extras.items(): + if key in all_attrs or hasattr(type(self), key): + raise ValueError(f"Atom extra '{key}' shadows an existing attribute") + if not isinstance(val, torch.Tensor): + raise TypeError(f"Atom extra '{key}' must be a torch.Tensor") + if val.shape[0] != self.n_atoms: + raise ValueError( + f"Atom extra '{key}' leading dim must be " + f"n_atoms={self.n_atoms}, got {val.shape[0]}" + ) + @classmethod def _get_all_attributes(cls) -> set[str]: """Get all attributes of the SimState.""" @@ -253,9 +316,72 @@ def _get_all_attributes(cls) -> set[str]: cls._atom_attributes | cls._system_attributes | cls._global_attributes - | {"_constraints"} + | {"_constraints", "_system_extras", "_atom_extras"} + ) + + def __getattr__(self, name: str) -> Any: + """Allow attribute-style access to extras dict entries.""" + # Guard: don't look up private attrs in extras (avoids recursion during init) + if name.startswith("_"): + raise AttributeError(name) + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + return extras[name] + + # Raise AttributeError so that Python's getattr(obj, name, default), + # hasattr(obj, name), and other descriptor-protocol machinery work correctly. + raise AttributeError( + f"'{type(self).__name__}' has no attribute or extra '{name}'" ) + @property + def system_extras(self) -> dict[str, torch.Tensor]: + """Get the system extras.""" + return self._system_extras + + @property + def atom_extras(self) -> dict[str, torch.Tensor]: + """Get the atom extras.""" + return self._atom_extras + + def has_extras(self, key: str) -> bool: + """Check if an extras key exists.""" + return key in self._system_extras or key in self._atom_extras + + def store_model_extras(self, model_output: dict[str, torch.Tensor]) -> None: + """Store non-canonical model outputs into state extras (in-place). + + Any key in *model_output* that is not in ``{"energy", "forces", "stress"}`` + is classified by its leading dimension: + + * ``n_atoms`` → stored in ``_atom_extras`` + * ``n_systems`` → stored in ``_system_extras`` + * otherwise → skipped (ambiguity or scalar) + + When ``n_atoms == n_systems`` (single-atom system), the tensor is stored as + per-atom by convention. + + Args: + model_output: Full dict returned by ``model.forward()``. + """ + n_atoms = self.n_atoms + n_systems = self.n_systems + + for key, val in model_output.items(): + if key in _CANONICAL_MODEL_KEYS: + continue + if not isinstance(val, torch.Tensor) or val.ndim == 0: + continue + leading = val.shape[0] + if leading == n_atoms: + self._atom_extras[key] = val + elif leading == n_systems: + self._system_extras[key] = val + @property def wrap_positions(self) -> torch.Tensor: """Atomic positions wrapped according to periodic boundary conditions if pbc=True, @@ -486,8 +612,18 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: if attr_name in cls._get_all_attributes(): attrs[attr_name] = cls._clone_attr(attr_value) - # Add/override with additional attributes - attrs.update(additional_attrs) + # Route additional_attrs: known attrs go directly, unknown tensor attrs + # go to _system_extras (backward compat for charge/spin and extensibility) + all_known = cls._get_all_attributes() + for key, val in additional_attrs.items(): + if key in all_known: + attrs[key] = val + elif isinstance(val, torch.Tensor): + if "_system_extras" not in attrs: + attrs["_system_extras"] = {} + attrs["_system_extras"][key] = val + else: + attrs[key] = val return cls(**attrs) @@ -595,9 +731,13 @@ def __init_subclass__(cls, **kwargs) -> None: Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. + + Also wraps __init__ to pop deprecated charge/spin kwargs and route them + to _system_extras for backward compatibility. """ cls._assert_no_tensor_attributes_can_be_none() cls._assert_all_attributes_have_defined_scope() + _wrap_init_for_extras(cls) super().__init_subclass__(**kwargs) @classmethod @@ -606,7 +746,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: # exceptions exist because the type hint doesn't actually reflect the real type # (since we change their type in the post_init) - exceptions = {"system_idx", "charge", "spin"} + exceptions = {"system_idx"} type_hints = typing.get_type_hints(cls) for attr_name, attr_type_hint in type_hints.items(): @@ -684,6 +824,9 @@ def _assert_all_attributes_have_defined_scope(cls) -> None: ) +_wrap_init_for_extras(SimState) + + @dataclass(kw_only=True) class DeformGradMixin: """Mixin for states that support deformation gradients.""" @@ -769,7 +912,7 @@ def _normalize_system_indices( raise TypeError(f"Unsupported index type: {type(system_indices)}") -def _state_to_device[T: SimState]( +def _state_to_device[T: SimState]( # noqa: C901 state: T, device: torch.device | None = None, dtype: torch.dtype | None = None ) -> T: """Convert the SimState to a new device and dtype. @@ -797,11 +940,25 @@ def _state_to_device[T: SimState]( elif isinstance(attr_value, torch.Generator): attrs[attr_name] = coerce_prng(attr_value, device) + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(device=device) for k, v in attrs[extras_key].items() + } + if dtype is not None: attrs["positions"] = attrs["positions"].to(dtype=dtype) attrs["masses"] = attrs["masses"].to(dtype=dtype) attrs["cell"] = attrs["cell"].to(dtype=dtype) attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + + # Update floating point extras to new dtype + for extras_key in ("_system_extras", "_atom_extras"): + if extras_key in attrs and isinstance(attrs[extras_key], dict): + attrs[extras_key] = { + k: v.to(dtype=dtype) if v.is_floating_point() else v + for k, v in attrs[extras_key].items() + } return type(state)(**attrs) @@ -894,6 +1051,13 @@ def _filter_attrs_by_index( val[system_indices] if isinstance(val, torch.Tensor) else val ) + filtered_attrs["_system_extras"] = { + key: val[system_indices] for key, val in state.system_extras.items() + } + filtered_attrs["_atom_extras"] = { + key: val[atom_indices] for key, val in state.atom_extras.items() + } + return filtered_attrs @@ -926,6 +1090,14 @@ def _split_state[T: SimState](state: T) -> list[T]: global_attrs = dict(get_attrs_for_scope(state, "global")) + split_system_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._system_extras.items(): # noqa: SLF001 + split_system_extras[key] = list(torch.split(val, 1, dim=0)) + + split_atom_extras: dict[str, list[torch.Tensor]] = {} + for key, val in state._atom_extras.items(): # noqa: SLF001 + split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0)) + # Create a state for each system states: list[T] = [] n_systems = len(system_sizes) @@ -952,6 +1124,12 @@ def _split_state[T: SimState](state: T) -> list[T]: **per_system_dict, # Add the global attributes **global_attrs, + "_system_extras": { + key: split_system_extras[key][sys_idx] for key in split_system_extras + }, + "_atom_extras": { + key: split_atom_extras[key][sys_idx] for key in split_atom_extras + }, } start_idx = int(cumsum_atoms[sys_idx].item()) @@ -1098,6 +1276,8 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + system_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) + atom_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) new_system_indices = [] system_offset = 0 num_atoms_per_state = [] @@ -1119,6 +1299,12 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 for prop, val in get_attrs_for_scope(state, "per-system"): per_system_tensors[prop].append(val) + # Collect extras + for key, val in state.system_extras.items(): + system_extras_tensors[key].append(val) + for key, val in state.atom_extras.items(): + atom_extras_tensors[key].append(val) + # Update system indices num_systems = state.n_systems new_indices = state.system_idx + system_offset @@ -1189,6 +1375,14 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + # Concatenate extras + concatenated["_system_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in system_extras_tensors.items() + } + concatenated["_atom_extras"] = { + key: torch.cat(tensors, dim=0) for key, tensors in atom_extras_tensors.items() + } + # Merge constraints constraint_lists = [state.constraints for state in states] num_systems_per_state = [state.n_systems for state in states] From 5e7af8eb6bf9d71932c3918940eb38c8ffba872e Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 19:50:58 -0400 Subject: [PATCH 02/12] remove privileged role of spin and charge --- torch_sim/state.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index e9079d91a..e16b219ea 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -72,9 +72,6 @@ def require_system_idx(system_idx: torch.Tensor | None) -> torch.Tensor: return system_idx -_EXTRAS_COMPAT_KEYS = frozenset({"charge", "spin"}) - - def _wrap_init_for_extras(cls: type) -> None: """Wrap a dataclass __init__ to route unknown kwargs into _system_extras.""" original_init = cls.__init__ @@ -202,17 +199,20 @@ def rng(self, value: PRNGLike) -> None: def __setattr__(self, name: str, value: object) -> None: # noqa: C901 """Coerce pbc and system_idx on every assignment. - Routes charge/spin writes to _system_extras for backward compatibility. + Routes writes to existing extras keys back into their extras dict. """ - if name in _EXTRAS_COMPAT_KEYS: - try: - extras = object.__getattribute__(self, "_system_extras") - except AttributeError: - extras = {} - super().__setattr__("_system_extras", extras) - if value is not None: - extras[name] = value - return + if not name.startswith("_"): + for extras_attr in ("_system_extras", "_atom_extras"): + try: + extras = object.__getattribute__(self, extras_attr) + except AttributeError: + continue + if name in extras: + if value is not None: + extras[name] = value + else: + del extras[name] + return if name == "pbc" and not isinstance(value, torch.Tensor): if isinstance(value, bool): value = [value] * 3 @@ -250,15 +250,6 @@ def __post_init__(self) -> None: # noqa: C901 if self.constraints: validate_constraints(self.constraints, state=self) - if "charge" not in self._system_extras: - self._system_extras["charge"] = torch.zeros( - n_systems, device=self.device, dtype=self.dtype - ) - if "spin" not in self._system_extras: - self._system_extras["spin"] = torch.zeros( - n_systems, device=self.device, dtype=self.dtype - ) - if self.cell.ndim != 3: self.cell = self.cell.unsqueeze(0) From 22476c5cba61e79f99bd513ad30d2cf267db679c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 25 Mar 2026 20:02:34 -0400 Subject: [PATCH 03/12] fix: down to 28 test failures --- torch_sim/state.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_sim/state.py b/torch_sim/state.py index e16b219ea..5eaba48ed 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -722,13 +722,9 @@ def __init_subclass__(cls, **kwargs) -> None: Also enforce all of child classes's attributes are specified in _atom_attributes, _system_attributes, or _global_attributes. - - Also wraps __init__ to pop deprecated charge/spin kwargs and route them - to _system_extras for backward compatibility. """ cls._assert_no_tensor_attributes_can_be_none() cls._assert_all_attributes_have_defined_scope() - _wrap_init_for_extras(cls) super().__init_subclass__(**kwargs) @classmethod From 465c74f30d618770df95d0f03a269ff8fc2852ab Mon Sep 17 00:00:00 2001 From: Stefano Falletta <49149059+falletta@users.noreply.github.com> Date: Thu, 26 Mar 2026 13:11:34 -0400 Subject: [PATCH 04/12] Fixes to Extensible Extras PR (#526) --- tests/test_extras.py | 5 +++-- tests/test_nbody.py | 14 +++++++------- torch_sim/elastic.py | 2 ++ torch_sim/io.py | 30 ++++++++++++++++++++++-------- torch_sim/models/fairchem.py | 4 ++-- torch_sim/models/mace.py | 4 ++-- torch_sim/state.py | 20 +++++++++++++++++--- 7 files changed, 55 insertions(+), 24 deletions(-) diff --git a/tests/test_extras.py b/tests/test_extras.py index 75947706a..8c1eef80c 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -75,8 +75,9 @@ def test_store_model_extras_canonical_keys_not_stored( "stress": torch.randn(state.n_systems, 3, 3), } ) - assert not state._system_extras # noqa: SLF001 - assert not state._atom_extras # noqa: SLF001 + for key in ("energy", "forces", "stress"): + assert key not in state._system_extras # noqa: SLF001 + assert key not in state._atom_extras # noqa: SLF001 def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState): """Tensors with leading dim == n_systems go into system_extras.""" diff --git a/tests/test_nbody.py b/tests/test_nbody.py index e235cd629..5da5a45c5 100644 --- a/tests/test_nbody.py +++ b/tests/test_nbody.py @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None: result = build_triplets(edge_index, n_atoms) - assert result["trip_in"].device == dev - assert result["trip_out"].device == dev - assert result["center_atom"].device == dev + assert result["trip_in"].device.type == dev.type + assert result["trip_out"].device.type == dev.type + assert result["center_atom"].device.type == dev.type @pytest.mark.parametrize( @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None: internal_cell_offsets, ) - assert result["quad_c_to_a_edge"].device == dev - assert result["quad_d_to_b_trip_idx"].device == dev - assert result["d_to_b_edge"].device == dev - assert result["c_to_a_edge"].device == dev + assert result["quad_c_to_a_edge"].device.type == dev.type + assert result["quad_d_to_b_trip_idx"].device.type == dev.type + assert result["d_to_b_edge"].device.type == dev.type + assert result["c_to_a_edge"].device.type == dev.type def test_build_triplets_jit_script() -> None: diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 944cdb813..3efffe7ef 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -680,6 +680,8 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) -> masses=state.masses, pbc=state.pbc, atomic_numbers=state.atomic_numbers, + _system_extras=state._system_extras, + _atom_extras=state._atom_extras, ) diff --git a/torch_sim/io.py b/torch_sim/io.py index 6044c8136..5ffe0fcd6 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -95,16 +95,22 @@ def state_to_atoms( # Write system extras to atoms.info # charge/spin stored as int scalars for FairChem compatibility - if system_extras_keys is not None: - for key in system_extras_keys: - val = state.system_extras[key][sys_idx].detach().cpu().numpy() - atoms.info[key] = val + _sys_keys = ( + system_extras_keys + if system_extras_keys is not None + else list(state.system_extras) + ) + for key in _sys_keys: + val = state.system_extras[key][sys_idx].detach().cpu().numpy() + atoms.info[key] = val # Write atom extras to atoms.arrays - if atom_extras_keys is not None: - for key in atom_extras_keys: - val = state.atom_extras[key][mask].detach().cpu().numpy() - atoms.arrays[key] = val + _atom_keys = ( + atom_extras_keys if atom_extras_keys is not None else list(state.atom_extras) + ) + for key in _atom_keys: + val = state.atom_extras[key][mask].detach().cpu().numpy() + atoms.arrays[key] = val atoms_list.append(atoms) @@ -314,8 +320,16 @@ def atoms_to_state( raise ValueError("All systems must have the same periodic boundary conditions") _system_extras: dict[str, torch.Tensor] = {} + + # charge and spin always default to 0 for backward compatibility + for key in ("charge", "spin"): + vals = np.array([float(at.info.get(key, 0.0)) for at in atoms_list]) + _system_extras[key] = torch.tensor(vals, dtype=dtype, device=device) + if system_extras_keys: for key in system_extras_keys: + if key in _system_extras: + continue vals = [at.info.get(key) for at in atoms_list] non_none_vals = [v for v in vals if v is not None] if len(non_none_vals) == len(vals): diff --git a/torch_sim/models/fairchem.py b/torch_sim/models/fairchem.py index 1c95b00da..56a6d156f 100644 --- a/torch_sim/models/fairchem.py +++ b/torch_sim/models/fairchem.py @@ -223,8 +223,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] pbc=pbc_np if cell is not None else False, ) - charge = sim_state.charge - spin = sim_state.spin + charge = getattr(sim_state, "charge", None) + spin = getattr(sim_state, "spin", None) atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0 atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0 diff --git a/torch_sim/models/mace.py b/torch_sim/models/mace.py index 75078670a..01afdc51f 100644 --- a/torch_sim/models/mace.py +++ b/torch_sim/models/mace.py @@ -304,8 +304,8 @@ def forward( # noqa: C901 edge_index=edge_index, unit_shifts=unit_shifts, shifts=shifts, - total_charge=state.charge, - total_spin=state.spin, + total_charge=getattr(state, "charge", None), + total_spin=getattr(state, "spin", None), ) # Get model output diff --git a/torch_sim/state.py b/torch_sim/state.py index 5eaba48ed..cb36e34f4 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -974,6 +974,11 @@ def get_attrs_for_scope( for attr_name in attr_names: yield attr_name, getattr(state, attr_name) + if scope == "per-system": + yield from state._system_extras.items() # noqa: SLF001 + elif scope == "per-atom": + yield from state._atom_extras.items() # noqa: SLF001 + def _filter_attrs_by_index( state: SimState, @@ -1029,11 +1034,15 @@ def _filter_attrs_by_index( c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment] for name, val in get_attrs_for_scope(state, "per-atom"): + if name in state.atom_extras: + continue filtered_attrs[name] = ( system_remap[val[atom_indices]] if name == "system_idx" else val[atom_indices] ) for name, val in get_attrs_for_scope(state, "per-system"): + if name in state.system_extras: + continue filtered_attrs[name] = ( val[system_indices] if isinstance(val, torch.Tensor) else val ) @@ -1065,11 +1074,14 @@ def _split_state[T: SimState](state: T) -> list[T]: split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): - if attr_name != "system_idx": - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + if attr_name == "system_idx" or attr_name in state.atom_extras: + continue + split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): + if attr_name in state.system_extras: + continue if isinstance(attr_value, torch.Tensor): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) else: # Non-tensor attributes are replicated for each split @@ -1277,13 +1289,15 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-atom properties for prop, val in get_attrs_for_scope(state, "per-atom"): - if prop == "system_idx": + if prop == "system_idx" or prop in state.atom_extras: # skip system_idx, it will be handled below continue per_atom_tensors[prop].append(val) # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): + if prop in state.system_extras: + continue per_system_tensors[prop].append(val) # Collect extras From 1445ed02fe708d9895ff78862e1def43798835ab Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 2 Apr 2026 11:53:26 -0400 Subject: [PATCH 05/12] lint: avoid SLF001 for extras in elastic code. --- torch_sim/elastic.py | 24 +++++++----------------- torch_sim/io.py | 2 +- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 3efffe7ef..7be97efbb 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -666,23 +666,13 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) -> else: # axis == 5 L[0, 1] += size # xy shear - # Convert positions to fractional coordinates - old_inv = torch.linalg.inv(row_vector_cell) - frac_coords = torch.matmul(positions, old_inv) - - # Apply transformation to cell and convert positions back to cartesian - row_vector_cell = torch.matmul(row_vector_cell, L) - new_positions = torch.matmul(frac_coords, row_vector_cell) - - return SimState( - positions=new_positions, - cell=row_vector_cell.mT.unsqueeze(0), - masses=state.masses, - pbc=state.pbc, - atomic_numbers=state.atomic_numbers, - _system_extras=state._system_extras, - _atom_extras=state._atom_extras, - ) + frac_coords = torch.matmul(positions, torch.linalg.inv(row_vector_cell)) + new_cell = torch.matmul(row_vector_cell, L) + + new_state = state.clone() + new_state.row_vector_cell = new_cell.unsqueeze(0) + new_state.positions = torch.matmul(frac_coords, new_cell) + return new_state def get_elementary_deformations( diff --git a/torch_sim/io.py b/torch_sim/io.py index 5ffe0fcd6..029fd5a09 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -255,7 +255,7 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: description="ASE: Atomic Simulation Environment", path="ase", ) -def atoms_to_state( +def atoms_to_state( # noqa: C901 atoms: "Atoms | list[Atoms]", device: torch.device | None = None, dtype: torch.dtype | None = None, From 2e9424794932f7d718467122166c05a8d8a012af Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 16:58:23 -0400 Subject: [PATCH 06/12] fea: add state modifier hook to interface test --- tests/models/conftest.py | 24 ++++++++++- torch_sim/models/interface.py | 80 ++++++++++++++--------------------- torch_sim/typing.py | 19 +++++++++ torch_sim/units.py | 51 ++++++++-------------- 4 files changed, 92 insertions(+), 82 deletions(-) diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 4c78118d4..0448b9165 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -1,5 +1,7 @@ """Pytest fixtures and test factories for model testing.""" +from __future__ import annotations + import typing import pytest @@ -10,7 +12,10 @@ if typing.TYPE_CHECKING: + from collections.abc import Callable, Sequence + from torch_sim.models.interface import ModelInterface + from torch_sim.state import SimState def make_model_calculator_consistency_test( @@ -81,22 +86,39 @@ def make_validate_model_outputs_test( dtype: torch.dtype = DTYPE, *, check_detached: bool = True, + state_modifiers: Sequence[Callable[[SimState], SimState]] = (), ): """Factory function to create model output validation tests. + Runs ``validate_model_outputs`` once with no modifier (baseline), then + once more for each entry in *state_modifiers* so that every modifier + gets a full, independent validation pass. + Args: model_fixture_name: Name of the model fixture to validate device: Device to run validation on dtype: Data type to use for validation check_detached: Whether to assert output tensors are detached from the autograd graph (skipped for models with ``retain_graph=True``). + state_modifiers: Each callable receives a ``SimState`` and returns a + (possibly new) ``SimState``. The full validation suite is run + once per modifier so that different input edge-cases are + exercised independently. """ from torch_sim.models.interface import validate_model_outputs def test_model_output_validation(request: pytest.FixtureRequest) -> None: """Test that a model implementation follows the ModelInterface contract.""" model: ModelInterface = request.getfixturevalue(model_fixture_name) - validate_model_outputs(model, device, dtype, check_detached=check_detached) + modifiers = state_modifiers or [None] + for modifier in modifiers: + validate_model_outputs( + model, + device, + dtype, + check_detached=check_detached, + state_modifier=modifier, + ) test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" return test_model_output_validation diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index d09878dec..59a99f5c6 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -26,13 +26,21 @@ def forward(self, positions, cell, batch, atomic_numbers=None, **kwargs): compute_stress property, as some integrators require stress calculations. """ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import torch import torch_sim as ts -from torch_sim.state import SimState -from torch_sim.typing import MemoryScaling + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch_sim.state import SimState + from torch_sim.typing import MemoryScaling VALIDATE_ATOL = 1e-4 @@ -210,6 +218,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 dtype: torch.dtype, *, check_detached: bool = False, + state_modifier: Callable[[SimState], SimState] | None = None, ) -> None: """Validate the outputs of a model implementation against the interface requirements. @@ -225,6 +234,9 @@ def validate_model_outputs( # noqa: C901, PLR0915 detached from the autograd graph, unless the model has a ``retain_graph`` attribute set to ``True``. Defaults to ``False`` so that external callers are not immediately broken. + state_modifier: If provided, applied to every ``SimState`` created + during validation before the model sees it. Must return the + (possibly new) state. Raises: AssertionError: If the model doesn't conform to the required interface, @@ -245,6 +257,9 @@ def validate_model_outputs( # noqa: C901, PLR0915 """ from ase.build import bulk, molecule + def _modify(state: SimState) -> SimState: + return state_modifier(state) if state_modifier is not None else state + for attr in ("dtype", "device", "compute_stress", "compute_forces"): if not hasattr(model, attr): raise ValueError(f"model.{attr} is not set") @@ -266,15 +281,14 @@ def validate_model_outputs( # noqa: C901, PLR0915 si_atoms = bulk("Si", "diamond", a=5.43, cubic=True) mg_atoms = bulk("Mg", "hcp", a=3.21, c=5.21).repeat([3, 2, 1]) fe_atoms = bulk("Fe", "bcc", a=2.87) - sim_state = ts.io.atoms_to_state([si_atoms, mg_atoms, fe_atoms], device, dtype) - + sim_state = _modify( + ts.io.atoms_to_state([si_atoms, mg_atoms, fe_atoms], device, dtype) + ) og_positions = sim_state.positions.clone() og_cell = sim_state.cell.clone() system_idx = sim_state.system_idx og_system_idx = system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() - og_charge = sim_state.charge.clone() - og_spin = sim_state.spin.clone() if check_detached and hasattr(model, "retain_graph"): model.__dict__["retain_graph"] = True @@ -295,10 +309,6 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{og_system_idx=} != {sim_state.system_idx=}") if not torch.allclose(og_atomic_nums, sim_state.atomic_numbers): raise ValueError(f"{og_atomic_nums=} != {sim_state.atomic_numbers=}") - if not torch.allclose(og_charge, sim_state.charge): - raise ValueError(f"{og_charge=} != {sim_state.charge=}") - if not torch.allclose(og_spin, sim_state.spin): - raise ValueError(f"{og_spin=} != {sim_state.spin=}") # assert model output has the correct keys if "energy" not in model_output: @@ -317,8 +327,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{model_output['stress'].shape=} != (3, 3, 3)") # Test single Si system output shapes (8 atoms) - si_state = ts.io.atoms_to_state([si_atoms], device, dtype) - + si_state = _modify(ts.io.atoms_to_state([si_atoms], device, dtype)) si_model_output = model.forward(si_state) if not torch.allclose( si_model_output["energy"], model_output["energy"][0], atol=VALIDATE_ATOL @@ -339,7 +348,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 raise ValueError(f"{si_model_output['stress'].shape=} != (1, 3, 3)") # Test single Mg system output shapes (12 atoms) - mg_state = ts.io.atoms_to_state([mg_atoms], device, dtype) + mg_state = _modify(ts.io.atoms_to_state([mg_atoms], device, dtype)) mg_model_output = model.forward(mg_state) if not torch.allclose( mg_model_output["energy"], model_output["energy"][1], atol=VALIDATE_ATOL @@ -363,7 +372,7 @@ def validate_model_outputs( # noqa: C901, PLR0915 # Test single Fe system output shapes (1 atom) # This catches that models do not squeeze away singleton dimensions. - fe_state = ts.io.atoms_to_state([fe_atoms], device, dtype) + fe_state = _modify(ts.io.atoms_to_state([fe_atoms], device, dtype)) fe_model_output = model.forward(fe_state) if not torch.allclose( fe_model_output["energy"], model_output["energy"][2], atol=VALIDATE_ATOL @@ -414,42 +423,17 @@ def validate_model_outputs( # noqa: C901, PLR0915 f"{(shifted_output['stress'] - si_model_output['stress']).abs().max()}" ) - # Test that models can handle non-zero charge and spin + # Test a non-periodic molecule (benzene) benzene_atoms = molecule("C6H6") - benzene_atoms.info["charge"] = 1.0 - benzene_atoms.info["spin"] = 1.0 - charged_state = ts.io.atoms_to_state([benzene_atoms], device, dtype) - - # Ensure state has charge/spin before testing model - if charged_state.charge is None or charged_state.spin is None: - raise ValueError( - "atoms_to_state did not extract charge/spin. " - "Cannot test model charge/spin handling." - ) - - # Test that model can handle charge/spin without crashing - og_charged_charge = charged_state.charge.clone() - og_charged_spin = charged_state.spin.clone() - try: - charged_output = model.forward(charged_state) - except Exception as e: - raise ValueError( - "Model failed to handle non-zero charge/spin. " - "Models must be able to process states with charge and spin values. " - ) from e - - # Verify model didn't mutate charge/spin - if not torch.allclose(og_charged_charge, charged_state.charge): - raise ValueError( - f"Model mutated charge: {og_charged_charge=} != {charged_state.charge=}" - ) - if not torch.allclose(og_charged_spin, charged_state.spin): + benzene_state = _modify(ts.io.atoms_to_state([benzene_atoms], device, dtype)) + benzene_output = model.forward(benzene_state) + if benzene_output["energy"].shape != (1,): raise ValueError( - f"Model mutated spin: {og_charged_spin=} != {charged_state.spin=}" + f"energy shape incorrect for benzene: " + f"{benzene_output['energy'].shape=} != (1,)" ) - # Verify output shape is still correct - if charged_output["energy"].shape != (1,): + if force_computed and benzene_output["forces"].shape != (12, 3): raise ValueError( - f"energy shape incorrect with charge/spin: " - f"{charged_output['energy'].shape=} != (1,)" + f"forces shape incorrect for benzene: " + f"{benzene_output['forces'].shape=} != (12, 3)" ) diff --git a/torch_sim/typing.py b/torch_sim/typing.py index ab4a74145..4dd07ca0f 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -14,6 +14,25 @@ from torch_sim.state import SimState +class AtomExtras(StrEnum): + """Blessed names for per-atom :class:`~torch_sim.state.SimState` extras. + + Stored in ``SimState._atom_extras``; leading dimension is ``n_atoms``. + """ + + SITE_CHARGES = "site_charges" + + +class SystemExtras(StrEnum): + """Blessed names for per-system :class:`~torch_sim.state.SimState` extras. + + Stored in ``SimState._system_extras``; leading dimension is ``n_systems``. + """ + + CHARGE = "charge" + SPIN = "spin" + + class BravaisType(StrEnum): """Enumeration of the seven Bravais lattice types in 3D crystals. diff --git a/torch_sim/units.py b/torch_sim/units.py index 74cffba4e..6f0d1f23a 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -10,7 +10,7 @@ from typing import Self -class BaseConstant: +class BaseConstant(float, Enum): """CODATA Recommended Values of the Fundamental Physical Constants: 2014. References: @@ -18,6 +18,10 @@ class BaseConstant: https://wiki.fysik.dtu.dk/ase/_modules/ase/units.html#create_units """ + def __new__(cls, value: float) -> Self: + """Create new BaseConstant enum value.""" + return float.__new__(cls, value) + c = 299792458.0 # speed of light, m/s mu0 = 4.0e-7 * pi # permeability of vacuum grav = 6.67408e-11 # gravitational constant @@ -33,49 +37,30 @@ class BaseConstant: bc = BaseConstant -class UnitConversion: - """Unit conversion class for different unit systems. - - Distance: - Ang (Angstrom) - met (meter) - - Time: - ps (picosecond) - s (second) - fs (femtosecond) - - Pressure: - atm (atmosphere) - pa (pascal) - bar (bar) - GPa (GigaPascal) +class UnitConversion(float, Enum): + """Unit conversion factors between common unit systems.""" - Energy: - cal (calorie) - kcal (kilocalorie) - eV (electron volt) - """ + def __new__(cls, value: float) -> Self: + """Create new UnitConversion enum value.""" + return float.__new__(cls, value) - # Distance Ang_to_met = 1e-10 - Ang2_to_met2 = Ang_to_met * Ang_to_met - Ang3_to_met3 = Ang_to_met * Ang2_to_met2 - - # Time + Ang2_to_met2 = 1e-10**2 + Ang3_to_met3 = 1e-10**3 ps_to_s = 1e-12 fs_to_s = 1e-15 - - # Pressure bar_to_pa = 1e5 atm_to_pa = 101325 pa_to_GPa = 1e-9 - eV_per_Ang3_to_GPa = (bc.e / Ang3_to_met3) * pa_to_GPa - - # Energy + eV_per_Ang3_to_GPa = bc.e * 1e21 cal_to_J = 4.184 kcal_to_cal = 1e3 eV_to_J = bc.e + Bohr_to_Ang = 0.529177210903 + Ang_to_Bohr = 1.0 / 0.529177210903 + Hartree_to_eV = 27.211386245988 + eV_to_Hartree = 1.0 / 27.211386245988 + e2_per_Ang_to_eV = 14.399645478425668 uc = UnitConversion From 006a7ce7282474849da47fa546d1774428335a42 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 17:17:56 -0400 Subject: [PATCH 07/12] fix: units should be relative --- torch_sim/units.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_sim/units.py b/torch_sim/units.py index 6f0d1f23a..19cd6c4fc 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -57,9 +57,9 @@ def __new__(cls, value: float) -> Self: kcal_to_cal = 1e3 eV_to_J = bc.e Bohr_to_Ang = 0.529177210903 - Ang_to_Bohr = 1.0 / 0.529177210903 + Ang_to_Bohr = 1.0 / Bohr_to_Ang Hartree_to_eV = 27.211386245988 - eV_to_Hartree = 1.0 / 27.211386245988 + eV_to_Hartree = 1.0 / Hartree_to_eV e2_per_Ang_to_eV = 14.399645478425668 From ebd4d22a1b1ec53e976849292c26989a9b289bcb Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 18:16:35 -0400 Subject: [PATCH 08/12] fea: add some more blessed extras keys to enums --- tests/test_elastic.py | 16 ++++----- torch_sim/elastic.py | 76 +++++++++++++++++++++---------------------- torch_sim/typing.py | 30 +++++++++++------ 3 files changed, 66 insertions(+), 56 deletions(-) diff --git a/tests/test_elastic.py b/tests/test_elastic.py index 6ad6af76f..2ac174410 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -221,7 +221,7 @@ def test_get_elementary_deformations_strain_consistency( n_deform=n_deform, max_strain_normal=max_strain_normal, max_strain_shear=max_strain_shear, - bravais_type=BravaisType.triclinic, # Test all axes + bravais_type=BravaisType.TRICLINIC, # Test all axes ) # Should generate deformations for all 6 axes (triclinic) @@ -271,12 +271,12 @@ def mace_model() -> MaceModel: @pytest.mark.parametrize( ("sim_state_name", "expected_bravais_type", "atol"), [ - ("cu_sim_state", BravaisType.cubic, 2e-1), - ("mg_sim_state", BravaisType.hexagonal, 5e-1), - ("sb_sim_state", BravaisType.trigonal, 5e-1), - ("tio2_sim_state", BravaisType.tetragonal, 5e-1), - ("ga_sim_state", BravaisType.orthorhombic, 5e-1), - ("niti_sim_state", BravaisType.monoclinic, 5e-1), + ("cu_sim_state", BravaisType.CUBIC, 2e-1), + ("mg_sim_state", BravaisType.HEXAGONAL, 5e-1), + ("sb_sim_state", BravaisType.TRIGONAL, 5e-1), + ("tio2_sim_state", BravaisType.TETRAGONAL, 5e-1), + ("ga_sim_state", BravaisType.ORTHORHOMBIC, 5e-1), + ("niti_sim_state", BravaisType.MONOCLINIC, 5e-1), ], ) def test_elastic_tensor_symmetries( @@ -340,7 +340,7 @@ def test_elastic_tensor_symmetries( ) C_triclinic = ( calculate_elastic_tensor( - state=state, model=model, bravais_type=BravaisType.triclinic + state=state, model=model, bravais_type=BravaisType.TRICLINIC ) * UnitConversion.eV_per_Ang3_to_GPa ) diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 7be97efbb..72341ebb1 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -91,7 +91,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 90) < angle_tol ): - return BravaisType.cubic + return BravaisType.CUBIC # Hexagonal: a = b ≠ c, alpha = beta = 90°, gamma = 120° if ( @@ -100,7 +100,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 120) < angle_tol ): - return BravaisType.hexagonal + return BravaisType.HEXAGONAL # Tetragonal: a = b ≠ c, alpha = beta = gamma = 90° if ( @@ -110,7 +110,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - 90) < angle_tol and abs(gamma - 90) < angle_tol ): - return BravaisType.tetragonal + return BravaisType.TETRAGONAL # Orthorhombic: a ≠ b ≠ c, alpha = beta = gamma = 90° if ( @@ -120,7 +120,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(a - b) > length_tol and (abs(b - c) > length_tol or abs(a - c) > length_tol) ): - return BravaisType.orthorhombic + return BravaisType.ORTHORHOMBIC # Monoclinic: a ≠ b ≠ c, alpha = gamma = 90°, beta ≠ 90° if ( @@ -128,7 +128,7 @@ def get_bravais_type( # noqa: PLR0911 and abs(gamma - 90) < angle_tol and abs(beta - 90) > angle_tol ): - return BravaisType.monoclinic + return BravaisType.MONOCLINIC # Trigonal/Rhombohedral: a = b = c, alpha = beta = gamma ≠ 90° if ( @@ -138,10 +138,10 @@ def get_bravais_type( # noqa: PLR0911 and abs(beta - gamma) < angle_tol and abs(alpha - 90) > angle_tol ): - return BravaisType.trigonal + return BravaisType.TRIGONAL # Triclinic: a ≠ b ≠ c, alpha ≠ beta ≠ gamma ≠ 90° - return BravaisType.triclinic + return BravaisType.TRICLINIC def regular_symmetry(strains: torch.Tensor) -> torch.Tensor: @@ -711,20 +711,20 @@ def get_elementary_deformations( # Deformation rules for different Bravais lattices # Each tuple contains (allowed_axes, symmetry_handler_function) deformation_rules: dict[BravaisType, DeformationRule] = { - BravaisType.cubic: DeformationRule([0, 3], regular_symmetry), - BravaisType.hexagonal: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), - BravaisType.trigonal: DeformationRule([0, 1, 2, 3, 4, 5], trigonal_symmetry), - BravaisType.tetragonal: DeformationRule([0, 2, 3, 5], tetragonal_symmetry), - BravaisType.orthorhombic: DeformationRule( + BravaisType.CUBIC: DeformationRule([0, 3], regular_symmetry), + BravaisType.HEXAGONAL: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), + BravaisType.TRIGONAL: DeformationRule([0, 1, 2, 3, 4, 5], trigonal_symmetry), + BravaisType.TETRAGONAL: DeformationRule([0, 2, 3, 5], tetragonal_symmetry), + BravaisType.ORTHORHOMBIC: DeformationRule( [0, 1, 2, 3, 4, 5], orthorhombic_symmetry ), - BravaisType.monoclinic: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), - BravaisType.triclinic: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), + BravaisType.MONOCLINIC: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), + BravaisType.TRICLINIC: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), } # Get deformation rules for this Bravais lattice (default to triclinic if None) if bravais_type is None: - bravais_type = BravaisType.triclinic + bravais_type = BravaisType.TRICLINIC rule = deformation_rules[bravais_type] allowed_axes = rule.axes @@ -890,7 +890,7 @@ def get_elastic_coeffs( deformed_states: list[SimState], stresses: torch.Tensor, base_pressure: torch.Tensor, - bravais_type: BravaisType = BravaisType.triclinic, + bravais_type: BravaisType = BravaisType.TRICLINIC, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]]: """Calculate elastic tensor from stress-strain relationships. @@ -924,15 +924,15 @@ def get_elastic_coeffs( """ # Deformation rules for different Bravais lattices deformation_rules: dict[BravaisType, DeformationRule] = { - BravaisType.cubic: DeformationRule([0, 3], regular_symmetry), - BravaisType.hexagonal: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), - BravaisType.trigonal: DeformationRule([0, 2, 3, 4, 5], trigonal_symmetry), - BravaisType.tetragonal: DeformationRule([0, 2, 3, 4, 5], tetragonal_symmetry), - BravaisType.orthorhombic: DeformationRule( + BravaisType.CUBIC: DeformationRule([0, 3], regular_symmetry), + BravaisType.HEXAGONAL: DeformationRule([0, 2, 3, 5], hexagonal_symmetry), + BravaisType.TRIGONAL: DeformationRule([0, 2, 3, 4, 5], trigonal_symmetry), + BravaisType.TETRAGONAL: DeformationRule([0, 2, 3, 4, 5], tetragonal_symmetry), + BravaisType.ORTHORHOMBIC: DeformationRule( [0, 1, 2, 3, 4, 5], orthorhombic_symmetry ), - BravaisType.monoclinic: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), - BravaisType.triclinic: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), + BravaisType.MONOCLINIC: DeformationRule([0, 1, 2, 3, 4, 5], monoclinic_symmetry), + BravaisType.TRICLINIC: DeformationRule([0, 1, 2, 3, 4, 5], triclinic_symmetry), } # Get symmetry handler for this Bravais lattice @@ -965,15 +965,15 @@ def get_elastic_coeffs( # Calculate elastic constants with pressure correction p = base_pressure pressure_corrections = { - BravaisType.cubic: torch.tensor([-p, p, -p]), - BravaisType.hexagonal: torch.tensor([-p, -p, p, p, -p]), - BravaisType.trigonal: torch.tensor([-p, -p, p, p, p, p, -p]), - BravaisType.tetragonal: torch.tensor([-p, -p, p, p, -p, -p, -p]), - BravaisType.orthorhombic: torch.tensor([-p, -p, -p, p, p, p, -p, -p, -p]), - BravaisType.monoclinic: torch.tensor( + BravaisType.CUBIC: torch.tensor([-p, p, -p]), + BravaisType.HEXAGONAL: torch.tensor([-p, -p, p, p, -p]), + BravaisType.TRIGONAL: torch.tensor([-p, -p, p, p, p, p, -p]), + BravaisType.TETRAGONAL: torch.tensor([-p, -p, p, p, -p, -p, -p]), + BravaisType.ORTHORHOMBIC: torch.tensor([-p, -p, -p, p, p, p, -p, -p, -p]), + BravaisType.MONOCLINIC: torch.tensor( [-p, -p, -p, p, p, p, -p, -p, -p, p, p, p, p] ), - BravaisType.triclinic: torch.tensor( + BravaisType.TRICLINIC: torch.tensor( [ -p, p, @@ -1036,7 +1036,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 # Initialize full tensor C = torch.zeros((6, 6), dtype=Cij.dtype, device=Cij.device) - if bravais_type == BravaisType.triclinic: + if bravais_type == BravaisType.TRICLINIC: if len(Cij) != 21: raise ValueError( f"Triclinic symmetry requires 21 independent constants, " @@ -1049,19 +1049,19 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[i, j] = C[j, i] = Cij[idx] idx += 1 - elif bravais_type == BravaisType.cubic: + elif bravais_type == BravaisType.CUBIC: C11, C12, C44 = Cij diag = torch.tensor([C11, C11, C11, C44, C44, C44]) C.diagonal().copy_(diag) C[0, 1] = C[1, 0] = C[0, 2] = C[2, 0] = C[1, 2] = C[2, 1] = C12 - elif bravais_type == BravaisType.hexagonal: + elif bravais_type == BravaisType.HEXAGONAL: C11, C12, C13, C33, C44 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, (C11 - C12) / 2])) C[0, 1] = C[1, 0] = C12 C[0, 2] = C[2, 0] = C[1, 2] = C[2, 1] = C13 - elif bravais_type == BravaisType.trigonal: + elif bravais_type == BravaisType.TRIGONAL: C11, C12, C13, C14, C15, C33, C44 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, (C11 - C12) / 2])) C[0, 1] = C[1, 0] = C12 @@ -1073,7 +1073,7 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[3, 5] = C[5, 3] = -C15 C[4, 5] = C[5, 4] = C14 - elif bravais_type == BravaisType.tetragonal: + elif bravais_type == BravaisType.TETRAGONAL: C11, C12, C13, C16, C33, C44, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C11, C33, C44, C44, C66])) C[0, 1] = C[1, 0] = C12 @@ -1081,14 +1081,14 @@ def get_elastic_tensor_from_coeffs( # noqa: C901, PLR0915 C[0, 5] = C[5, 0] = C16 C[1, 5] = C[5, 1] = -C16 - elif bravais_type == BravaisType.orthorhombic: + elif bravais_type == BravaisType.ORTHORHOMBIC: C11, C12, C13, C22, C23, C33, C44, C55, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C22, C33, C44, C55, C66])) C[0, 1] = C[1, 0] = C12 C[0, 2] = C[2, 0] = C13 C[1, 2] = C[2, 1] = C23 - elif bravais_type == BravaisType.monoclinic: + elif bravais_type == BravaisType.MONOCLINIC: C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66 = Cij C.diagonal().copy_(torch.tensor([C11, C22, C33, C44, C55, C66])) C[0, 1] = C[1, 0] = C12 @@ -1106,7 +1106,7 @@ def calculate_elastic_tensor( state: OptimState, model: ModelInterface, *, - bravais_type: BravaisType = BravaisType.triclinic, + bravais_type: BravaisType = BravaisType.TRICLINIC, max_strain_normal: float = 0.01, max_strain_shear: float = 0.06, n_deform: int = 5, diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 4dd07ca0f..655105e55 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -20,7 +20,9 @@ class AtomExtras(StrEnum): Stored in ``SimState._atom_extras``; leading dimension is ``n_atoms``. """ - SITE_CHARGES = "site_charges" + PARTIAL_CHARGES = "partial_charges" + BORN_EFFECTIVE_CHARGES = "born_effective_charges" + MAGNETIC_MOMENTS = "magnetic_moments" class SystemExtras(StrEnum): @@ -29,8 +31,16 @@ class SystemExtras(StrEnum): Stored in ``SimState._system_extras``; leading dimension is ``n_systems``. """ - CHARGE = "charge" - SPIN = "spin" + CHARGE = "charge" # TOTAL_CHARGE preferred for less ambiguity with partial charges + SPIN = "spin" # TOTAL_SPIN preferred + TOTAL_CHARGE = "total_charge" + TOTAL_SPIN = "total_spin" + EXTERNAL_E_FIELD = "external_E_field" + POLARIZABILITY = "polarizability" + TOTAL_POLARIZATION = "total_polarization" + EXTERNAL_H_FIELD = "external_H_field" + MAGNETIC_SUSCEPTIBILITY = "magnetic_susceptibility" + TOTAL_MAGNETIZATION = "total_magnetization" class BravaisType(StrEnum): @@ -44,13 +54,13 @@ class BravaisType(StrEnum): which determine the number of independent elastic constants. """ - cubic = "cubic" - hexagonal = "hexagonal" - trigonal = "trigonal" - tetragonal = "tetragonal" - orthorhombic = "orthorhombic" - monoclinic = "monoclinic" - triclinic = "triclinic" + CUBIC = "cubic" + HEXAGONAL = "hexagonal" + TRIGONAL = "trigonal" + TETRAGONAL = "tetragonal" + ORTHORHOMBIC = "orthorhombic" + MONOCLINIC = "monoclinic" + TRICLINIC = "triclinic" StateLike = Union[ From e86ca52c28cb2f5968acbc499102b670e7e4d058 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 4 Apr 2026 18:58:46 -0400 Subject: [PATCH 09/12] fea: configure ase atoms to ts io better. --- tests/models/test_fairchem.py | 10 +- tests/test_extras.py | 43 ---- tests/test_io.py | 451 ++++++++++------------------------ tests/test_optimizers.py | 8 +- tests/test_state.py | 2 +- torch_sim/io.py | 136 +++++----- torch_sim/state.py | 19 +- torch_sim/typing.py | 2 + 8 files changed, 225 insertions(+), 446 deletions(-) diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index d259e2ec1..25fa1ba38 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -263,10 +263,12 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None: mol.info["charge"] = charge mol.info["spin"] = spin - # Convert to SimState (should extract charge/spin) - state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE) - - # Verify charge/spin were extracted correctly + state = ts.io.atoms_to_state( + [mol], + device=DEVICE, + dtype=DTYPE, + system_extras={"charge": "charge", "spin": "spin"}, + ) assert state.charge is not None assert state.spin is not None assert state.charge[0].item() == charge diff --git a/tests/test_extras.py b/tests/test_extras.py index 8c1eef80c..3bc36b3df 100644 --- a/tests/test_extras.py +++ b/tests/test_extras.py @@ -4,10 +4,6 @@ import torch_sim as ts -DEVICE = torch.device("cpu") -DTYPE = torch.float64 - - class TestExtras: def test_system_extras_construction(self): """Extras can be passed at construction time.""" @@ -51,7 +47,6 @@ def test_post_init_validation_rejects_bad_shape(self): ) def test_construction_extras_cannot_shadow(self): - # Post-init validation should also catch shadowing during construction with pytest.raises(ValueError, match="shadows an existing attribute"): ts.SimState( positions=torch.zeros(2, 3), @@ -62,7 +57,6 @@ def test_construction_extras_cannot_shadow(self): _system_extras={"cell": torch.zeros(1, 3)}, ) - # store_model_extras def test_store_model_extras_canonical_keys_not_stored( self, si_double_sim_state: ts.SimState ): @@ -114,40 +108,3 @@ def test_store_model_extras_skips_scalars(self, si_double_sim_state: ts.SimState ) assert not state.has_extras("scalar") assert not state.has_extras("string") - - -def test_system_extras_atoms_roundtrip(): - state = ts.SimState( - positions=torch.zeros(2, 3), - masses=torch.ones(2), - cell=torch.eye(3).unsqueeze(0), - pbc=True, - atomic_numbers=torch.tensor([1, 1], dtype=torch.int), - _system_extras={"external_E_field": torch.tensor([[1.0, 0.0, 0.0]])}, - ) - atoms_list = state.to_atoms() - assert "external_E_field" in atoms_list[0].info - restored = ts.io.atoms_to_state( - atoms_list, - system_extras_keys=["external_E_field"], - ) - assert torch.allclose(restored.external_E_field, state.external_E_field) - - -def test_atom_extras_atoms_roundtrip(): - tags = torch.tensor([1.0, 2.0]) - state = ts.SimState( - positions=torch.zeros(2, 3), - masses=torch.ones(2), - cell=torch.eye(3).unsqueeze(0), - pbc=True, - atomic_numbers=torch.tensor([1, 1], dtype=torch.int), - _atom_extras={"tags": tags}, - ) - atoms_list = state.to_atoms() - assert "tags" in atoms_list[0].arrays - restored = ts.io.atoms_to_state( - atoms_list, - atom_extras_keys=["tags"], - ) - assert torch.allclose(restored.tags, state.tags) diff --git a/tests/test_io.py b/tests/test_io.py index 8e1de1ac2..46e300d72 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -2,66 +2,20 @@ import sys from typing import Any +import numpy as np import pytest import torch from ase import Atoms from ase.build import molecule -from phonopy.structure.atoms import PhonopyAtoms -from pymatgen.core import Structure import torch_sim as ts from tests.conftest import DEVICE, DTYPE from torch_sim.state import SimState -def test_single_structure_to_state(si_structure: Structure) -> None: - """Test conversion from pymatgen Structure to state tensors.""" - state = ts.io.structures_to_state(si_structure, DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) - assert all( - t.device.type == DEVICE.type for t in (state.positions, state.masses, state.cell) - ) - assert all( - t.dtype == torch.float64 for t in (state.positions, state.masses, state.cell) - ) - assert state.atomic_numbers.dtype == torch.int - - # Check shapes and values - assert state.positions.shape == (8, 3) - assert torch.allclose(state.masses, torch.full_like(state.masses, 28.0855)) # Si - assert torch.all(state.atomic_numbers == 14) # Si atomic number - assert torch.allclose( - state.cell, - torch.diag(torch.full((3,), 5.43, device=DEVICE, dtype=torch.float64)), - ) - - -def test_multiple_structures_to_state(si_structure: Structure) -> None: - """Test conversion from list of pymatgen Structure to state tensors.""" - state = ts.io.structures_to_state([si_structure, si_structure], DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) - assert state.positions.shape == (16, 3) - assert state.masses.shape == (16,) - assert state.cell.shape == (2, 3, 3) - assert torch.all(state.pbc) - assert state.atomic_numbers.shape == (16,) - assert state.system_idx is not None - assert state.system_idx.shape == (16,) - assert torch.all( - state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8) - ) - - def test_single_atoms_to_state(si_atoms: Atoms) -> None: - """Test conversion from ASE Atoms to state tensors.""" + """Test basic shape/dtype/device properties of atoms_to_state.""" state = ts.io.atoms_to_state(si_atoms, DEVICE, torch.float64) - - # Check basic properties assert isinstance(state, SimState) assert state.positions.shape == (8, 3) assert state.masses.shape == (8,) @@ -69,218 +23,101 @@ def test_single_atoms_to_state(si_atoms: Atoms) -> None: assert torch.all(state.pbc) assert state.atomic_numbers.shape == (8,) assert state.system_idx is not None - assert state.system_idx.shape == (8,) assert torch.all(state.system_idx == 0) - - -def test_multiple_atoms_to_state(si_atoms: Atoms) -> None: - """Test conversion from ASE Atoms to state tensors.""" - state = ts.io.atoms_to_state([si_atoms, si_atoms], DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) - assert state.positions.shape == (16, 3) - assert state.masses.shape == (16,) - assert state.cell.shape == (2, 3, 3) - assert torch.all(state.pbc) - assert state.atomic_numbers.shape == (16,) - assert state.system_idx is not None - assert state.system_idx.shape == (16,) - assert torch.all( - state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8), + assert all( + t.device.type == DEVICE.type for t in (state.positions, state.masses, state.cell) ) + assert all( + t.dtype == torch.float64 for t in (state.positions, state.masses, state.cell) + ) + assert state.atomic_numbers.dtype == torch.int @pytest.mark.parametrize( - ("charge", "spin", "expected_charge", "expected_spin"), + ("system_extras", "atom_extras", "expected_sys", "expected_atom"), [ - (1.0, 1.0, 1.0, 1.0), # Non-zero charge and spin - (0.0, 0.0, 0.0, 0.0), # Explicit zero charge and spin - (None, None, 0.0, 0.0), # No charge/spin set, defaults to zero + pytest.param(None, None, {}, {}, id="no-extras-by-default"), + pytest.param( + {"charge": "charge", "spin": "spin"}, + None, + {"charge": 3.0, "spin": 2.0}, + {}, + id="system-extras-identity-map", + ), + pytest.param( + {"total_charge": "charge"}, + None, + {"total_charge": 3.0}, + {}, + id="system-extras-rename", + ), + pytest.param( + None, + {"site_tags": "my_tags"}, + {}, + {"site_tags": [1.0, 2.0, 3.0]}, + id="atom-extras-rename", + ), ], ) -def test_atoms_to_state_with_charge_spin( - charge: float | None, - spin: float | None, - expected_charge: float, - expected_spin: float, +def test_extras_map_import( + system_extras: dict[str, str] | None, + atom_extras: dict[str, str] | None, + expected_sys: dict[str, float], + expected_atom: dict[str, list[float]], ) -> None: - """Test conversion from ASE Atoms with charge and spin to state tensors.""" + """ExtrasMap controls which keys are read and how they are renamed on import.""" mol = molecule("H2O") - if charge is not None: - mol.info["charge"] = charge - if spin is not None: - mol.info["spin"] = spin + mol.info["charge"] = 3.0 + mol.info["spin"] = 2.0 + mol.arrays["my_tags"] = np.array([1.0, 2.0, 3.0]) + state = ts.io.atoms_to_state( + [mol], DEVICE, DTYPE, system_extras=system_extras, atom_extras=atom_extras + ) + if not expected_sys and not expected_atom: + assert not state.system_extras + assert not state.atom_extras + for key, val in expected_sys.items(): + assert getattr(state, key)[0].item() == val + for key, vals in expected_atom.items(): + assert getattr(state, key).shape == (len(vals),) - state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) - # Check basic properties - assert isinstance(state, SimState) - assert state.charge is not None - assert state.spin is not None - assert state.charge.shape == (1,) - assert state.spin.shape == (1,) - assert state.charge[0].item() == expected_charge - assert state.spin[0].item() == expected_spin +def test_extras_map_missing_key_skipped() -> None: + """Missing ASE keys are silently skipped rather than defaulting to zero.""" + mol = molecule("H2O") + state = ts.io.atoms_to_state([mol], DEVICE, DTYPE, system_extras={"charge": "charge"}) + assert not state.system_extras -def test_multiple_atoms_to_state_with_charge_spin() -> None: - """Test conversion from multiple ASE Atoms with different charge/spin values.""" - mol1 = molecule("H2O") +def test_extras_map_multi_system() -> None: + """System extras work across multiple systems with correct per-system values.""" + mol1, mol2 = molecule("H2O"), molecule("CH4") mol1.info["charge"] = 1.0 - mol1.info["spin"] = 1.0 - - mol2 = molecule("CH4") mol2.info["charge"] = -1.0 - mol2.info["spin"] = 0.0 - - mol3 = molecule("NH3") - mol3.info["charge"] = 0.0 - mol3.info["spin"] = 2.0 - - state = ts.io.atoms_to_state([mol1, mol2, mol3], DEVICE, DTYPE) - - # Check basic properties - assert isinstance(state, SimState) - assert state.charge is not None - assert state.spin is not None - assert state.charge.shape == (3,) - assert state.spin.shape == (3,) + state = ts.io.atoms_to_state( + [mol1, mol2], DEVICE, DTYPE, system_extras={"charge": "charge"} + ) + assert state.charge.shape == (2,) assert state.charge[0].item() == 1.0 assert state.charge[1].item() == -1.0 - assert state.charge[2].item() == 0.0 - assert state.spin[0].item() == 1.0 - assert state.spin[1].item() == 0.0 - assert state.spin[2].item() == 2.0 - - -def test_state_to_structure(ar_supercell_sim_state: SimState) -> None: - """Test conversion from state tensors to list of pymatgen Structure.""" - structures = ts.io.state_to_structures(ar_supercell_sim_state) - assert len(structures) == 1 - assert isinstance(structures[0], Structure) - assert len(structures[0]) == 32 - - -def test_state_to_multiple_structures(ar_double_sim_state: SimState) -> None: - """Test conversion from state tensors to list of pymatgen Structure.""" - structures = ts.io.state_to_structures(ar_double_sim_state) - assert len(structures) == 2 - assert isinstance(structures[0], Structure) - assert isinstance(structures[1], Structure) - assert len(structures[0]) == 32 - assert len(structures[1]) == 32 - -def test_state_to_atoms(ar_supercell_sim_state: SimState) -> None: - """Test conversion from state tensors to list of ASE Atoms.""" - atoms = ts.io.state_to_atoms(ar_supercell_sim_state) - assert len(atoms) == 1 - assert isinstance(atoms[0], Atoms) - assert len(atoms[0]) == 32 - -def test_state_to_atoms_with_charge_spin() -> None: - """Test conversion from state with charge/spin to ASE Atoms preserves charge/spin.""" +def test_extras_map_export_roundtrip() -> None: + """System and atom extras round-trip through state_to_atoms with rename.""" mol = molecule("H2O") - mol.info["charge"] = 1.0 - mol.info["spin"] = 1.0 - - state = ts.io.atoms_to_state([mol], DEVICE, DTYPE) - atoms = ts.io.state_to_atoms(state) - - assert len(atoms) == 1 - assert isinstance(atoms[0], Atoms) - assert "charge" in atoms[0].info - assert "spin" in atoms[0].info - assert atoms[0].info["charge"] == 1 - assert atoms[0].info["spin"] == 1 - - -def test_state_to_multiple_atoms(ar_double_sim_state: SimState) -> None: - """Test conversion from state tensors to list of ASE Atoms.""" - atoms = ts.io.state_to_atoms(ar_double_sim_state) - assert len(atoms) == 2 - assert isinstance(atoms[0], Atoms) - assert isinstance(atoms[1], Atoms) - assert len(atoms[0]) == 32 - assert len(atoms[1]) == 32 - - -def test_to_atoms(ar_supercell_sim_state: SimState) -> None: - """Test conversion from SimState to list of ASE Atoms.""" - atoms = ts.io.state_to_atoms(ar_supercell_sim_state) - assert isinstance(atoms[0], Atoms) - - -def test_to_structures(ar_supercell_sim_state: SimState) -> None: - """Test conversion from SimState to list of Pymatgen Structure.""" - structures = ts.io.state_to_structures(ar_supercell_sim_state) - assert isinstance(structures[0], Structure) - - -def test_single_phonopy_to_state(si_phonopy_atoms: Any) -> None: - """Test conversion from PhonopyAtoms to state tensors.""" - state = ts.io.phonopy_to_state(si_phonopy_atoms, DEVICE, torch.float64) - - # Check basic properties - assert isinstance(state, SimState) - assert all( - t.device.type == DEVICE.type for t in (state.positions, state.masses, state.cell) + mol.info["charge"] = 5.0 + mol.arrays["my_tags"] = np.array([1.0, 2.0, 3.0]) + sys_map = {"total_charge": "charge"} + atom_map = {"site_tags": "my_tags"} + state = ts.io.atoms_to_state( + [mol], DEVICE, DTYPE, system_extras=sys_map, atom_extras=atom_map ) - assert all( - t.dtype == torch.float64 for t in (state.positions, state.masses, state.cell) - ) - assert state.atomic_numbers.dtype == torch.int - - # Check shapes and values - assert state.positions.shape == (8, 3) - assert torch.allclose(state.masses, torch.full_like(state.masses, 28.0855)) # Si - assert torch.all(state.atomic_numbers == 14) # Si atomic number - assert torch.allclose( - state.cell, - torch.diag(torch.full((3,), 5.43, device=DEVICE, dtype=torch.float64)), - ) - - -def test_multiple_phonopy_to_state(si_phonopy_atoms: Any) -> None: - """Test conversion from multiple PhonopyAtoms to state tensors.""" - state = ts.io.phonopy_to_state( - [si_phonopy_atoms, si_phonopy_atoms], DEVICE, torch.float64 - ) - - # Check basic properties - assert isinstance(state, SimState) - assert state.positions.shape == (16, 3) - assert state.masses.shape == (16,) - assert state.cell.shape == (2, 3, 3) - assert torch.all(state.pbc) - assert state.atomic_numbers.shape == (16,) - assert state.system_idx is not None - assert state.system_idx.shape == (16,) - assert torch.all( - state.system_idx - == torch.repeat_interleave(torch.tensor([0, 1], device=DEVICE), 8), - ) - - -def test_state_to_phonopy(ar_supercell_sim_state: SimState) -> None: - """Test conversion from state tensors to list of PhonopyAtoms.""" - phonopy_atoms = ts.io.state_to_phonopy(ar_supercell_sim_state) - assert len(phonopy_atoms) == 1 - assert isinstance(phonopy_atoms[0], PhonopyAtoms) - assert len(phonopy_atoms[0]) == 32 - - -def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: - """Test conversion from state tensors to list of PhonopyAtoms.""" - phonopy_atoms = ts.io.state_to_phonopy(ar_double_sim_state) - assert len(phonopy_atoms) == 2 - assert isinstance(phonopy_atoms[0], PhonopyAtoms) - assert isinstance(phonopy_atoms[1], PhonopyAtoms) - assert len(phonopy_atoms[0]) == 32 - assert len(phonopy_atoms[1]) == 32 + atoms = ts.io.state_to_atoms(state, system_extras=sys_map, atom_extras=atom_map) + assert atoms[0].info["charge"] == 5.0 + np.testing.assert_allclose(atoms[0].arrays["my_tags"], [1.0, 2.0, 3.0]) + atoms_no_map = ts.io.state_to_atoms(state) + assert "charge" not in atoms_no_map[0].info @pytest.mark.parametrize( @@ -295,7 +132,6 @@ def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: "cu_sim_state", "ar_double_sim_state", "mixed_double_sim_state", - # TODO: round trip benzene/non-pbc systems ], [ (ts.io.state_to_atoms, ts.io.atoms_to_state), @@ -307,103 +143,74 @@ def test_state_to_multiple_phonopy(ar_double_sim_state: SimState) -> None: def test_state_round_trip( sim_state_name: str, conversion_functions: tuple, request: pytest.FixtureRequest ) -> None: - """Test round-trip conversion from SimState through various formats and back. - - Args: - sim_state_name: Name of the sim_state fixture to test - conversion_functions: Tuple of (to_format, from_format) conversion functions - request: Pytest fixture request object to get dynamic fixtures - """ - # Get the sim_state fixture dynamically using the name + """Test round-trip conversion from SimState through various formats and back.""" sim_state: SimState = request.getfixturevalue(sim_state_name) to_format_fn, from_format_fn = conversion_functions assert sim_state.system_idx is not None uniq_systems = torch.unique(sim_state.system_idx) - - # Convert to intermediate format intermediate_format = to_format_fn(sim_state) assert len(intermediate_format) == len(uniq_systems) - - # Convert back to state round_trip_state: SimState = from_format_fn(intermediate_format, DEVICE, DTYPE) - - # Check that all properties match assert round_trip_state.system_idx is not None assert torch.allclose(sim_state.positions, round_trip_state.positions) assert torch.allclose(sim_state.cell, round_trip_state.cell) assert torch.all(sim_state.atomic_numbers == round_trip_state.atomic_numbers) assert torch.all(sim_state.system_idx == round_trip_state.system_idx) assert torch.equal(sim_state.pbc, round_trip_state.pbc) - if isinstance(intermediate_format[0], Atoms): - # TODO: masses round trip for pmg and phonopy masses is not exact - # since both use their own isotope masses based on species, - # not the ones in the state assert torch.allclose(sim_state.masses, round_trip_state.masses) - # Check charge/spin round trip - assert torch.allclose(sim_state.charge, round_trip_state.charge) - assert torch.allclose(sim_state.spin, round_trip_state.spin) - - -def test_state_to_atoms_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "ase", None) - monkeypatch.setitem(sys.modules, "ase.data", None) - - with pytest.raises( - ImportError, match="ASE is required for state_to_atoms conversion" - ): - ts.io.state_to_atoms(None) - - -def test_state_to_phonopy_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "phonopy", None) - monkeypatch.setitem(sys.modules, "phonopy.structure", None) - monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None) - - with pytest.raises( - ImportError, match="Phonopy is required for state_to_phonopy conversion" - ): - ts.io.state_to_phonopy(None) - - -def test_state_to_structures_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "pymatgen", None) - monkeypatch.setitem(sys.modules, "pymatgen.core", None) - monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None) - with pytest.raises( - ImportError, match="Pymatgen is required for state_to_structures conversion" - ): - ts.io.state_to_structures(None) - -def test_atoms_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "ase", None) - monkeypatch.setitem(sys.modules, "ase.data", None) - - with pytest.raises( - ImportError, match="ASE is required for atoms_to_state conversion" - ): - ts.io.atoms_to_state(None, None, None) - - -def test_phonopy_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "phonopy", None) - monkeypatch.setitem(sys.modules, "phonopy.structure", None) - monkeypatch.setitem(sys.modules, "phonopy.structure.atoms", None) - - with pytest.raises( - ImportError, match="Phonopy is required for phonopy_to_state conversion" - ): - ts.io.phonopy_to_state(None, None, None) - - -def test_structures_to_state_importerror(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setitem(sys.modules, "pymatgen", None) - monkeypatch.setitem(sys.modules, "pymatgen.core", None) - monkeypatch.setitem(sys.modules, "pymatgen.core.structure", None) - - with pytest.raises( - ImportError, match="Pymatgen is required for structures_to_state conversion" - ): - ts.io.structures_to_state(None, None, None) +@pytest.mark.parametrize( + ("monkeypatch_modules", "func", "args", "match"), + [ + ( + ["ase", "ase.data"], + ts.io.state_to_atoms, + (None,), + "ASE is required for state_to_atoms", + ), + ( + ["ase", "ase.data"], + ts.io.atoms_to_state, + (None, None, None), + "ASE is required for atoms_to_state", + ), + ( + ["phonopy", "phonopy.structure", "phonopy.structure.atoms"], + ts.io.state_to_phonopy, + (None,), + "Phonopy is required for state_to_phonopy", + ), + ( + ["phonopy", "phonopy.structure", "phonopy.structure.atoms"], + ts.io.phonopy_to_state, + (None, None, None), + "Phonopy is required for phonopy_to_state", + ), + ( + ["pymatgen", "pymatgen.core", "pymatgen.core.structure"], + ts.io.state_to_structures, + (None,), + "Pymatgen is required for state_to_structures", + ), + ( + ["pymatgen", "pymatgen.core", "pymatgen.core.structure"], + ts.io.structures_to_state, + (None, None, None), + "Pymatgen is required for structures_to_state", + ), + ], +) +def test_import_errors( + monkeypatch: pytest.MonkeyPatch, + monkeypatch_modules: list[str], + func: Any, + args: tuple, + match: str, +) -> None: + """All IO functions raise ImportError when their backend is unavailable.""" + for mod in monkeypatch_modules: + monkeypatch.setitem(sys.modules, mod, None) + with pytest.raises(ImportError, match=match): + func(*args) diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 3ab10533c..35f0b5665 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1464,8 +1464,12 @@ def test_optimizer_preserves_charge_spin( original_spin = torch.tensor( [6.0], device=ar_supercell_sim_state.device, dtype=ar_supercell_sim_state.dtype ) - ar_supercell_sim_state.charge = original_charge.clone() - ar_supercell_sim_state.spin = original_spin.clone() + + # NOTE the convenience method for setting extras is not used here because + # they are only available if the key is already in the extras dict. If not + # we need to set them explicitly. + ar_supercell_sim_state._system_extras["charge"] = original_charge.clone() # noqa: SLF001 + ar_supercell_sim_state._system_extras["spin"] = original_spin.clone() # noqa: SLF001 init_fn, step_fn = ts.OPTIM_REGISTRY[optimizer_fn] opt_state = init_fn( diff --git a/tests/test_state.py b/tests/test_state.py index ed1b3c29d..4cd98693d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -34,7 +34,7 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs) == {"cell", "charge", "spin"} + assert set(per_system_attrs) == {"cell"} global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) assert set(global_attrs) == {"pbc", "_rng"} diff --git a/torch_sim/io.py b/torch_sim/io.py index 029fd5a09..50d0dc702 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -12,6 +12,8 @@ * Batched conversions for multiple structures """ +from __future__ import annotations + from typing import TYPE_CHECKING import numpy as np @@ -26,6 +28,8 @@ from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure + from torch_sim.typing import ExtrasMap + @dcite( "10.1088/1361-648X/aa680e", @@ -33,27 +37,31 @@ path="ase", ) def state_to_atoms( - state: "ts.SimState", - system_extras_keys: list[str] | None = None, - atom_extras_keys: list[str] | None = None, -) -> list["Atoms"]: + state: ts.SimState, + *, + system_extras: ExtrasMap | None = None, + atom_extras: ExtrasMap | None = None, +) -> list[Atoms]: """Convert a SimState to a list of ASE Atoms objects. Args: - state (SimState): Batched state containing positions, cell, and atomic numbers - system_extras_keys: Keys for per-system extras to include in atoms.info - atom_extras_keys: Keys for per-atom extras to include in atoms.arrays + state: Batched state containing positions, cell, and atomic numbers. + system_extras: Map of ``{ts_key: ase_key}`` controlling which + ``_system_extras`` entries are written to ``atoms.info``. + ``None`` (default) means no extras are written. + atom_extras: Map of ``{ts_key: ase_key}`` controlling which + ``_atom_extras`` entries are written to ``atoms.arrays``. + ``None`` (default) means no extras are written. Returns: - list[Atoms]: ASE Atoms objects, one per system + list[Atoms]: ASE Atoms objects, one per system. Raises: - ImportError: If ASE is not installed + ImportError: If ASE is not installed. Notes: - Output positions and cell will be in Å - Output masses will be in amu - - Charge and spin are preserved in atoms.info if present in the state """ try: from ase import Atoms @@ -93,24 +101,17 @@ def state_to_atoms( symbols=symbols, positions=system_positions, cell=system_cell, pbc=pbc_for_sys ) - # Write system extras to atoms.info - # charge/spin stored as int scalars for FairChem compatibility - _sys_keys = ( - system_extras_keys - if system_extras_keys is not None - else list(state.system_extras) - ) - for key in _sys_keys: - val = state.system_extras[key][sys_idx].detach().cpu().numpy() - atoms.info[key] = val + if system_extras: + for ts_key, ase_key in system_extras.items(): + if ts_key in state.system_extras: + val = state.system_extras[ts_key][sys_idx].detach().cpu().numpy() + atoms.info[ase_key] = val - # Write atom extras to atoms.arrays - _atom_keys = ( - atom_extras_keys if atom_extras_keys is not None else list(state.atom_extras) - ) - for key in _atom_keys: - val = state.atom_extras[key][mask].detach().cpu().numpy() - atoms.arrays[key] = val + if atom_extras: + for ts_key, ase_key in atom_extras.items(): + if ts_key in state.atom_extras: + val = state.atom_extras[ts_key][mask].detach().cpu().numpy() + atoms.arrays[ase_key] = val atoms_list.append(atoms) @@ -122,7 +123,7 @@ def state_to_atoms( description="pymatgen: Python Materials Genomics", path="pymatgen", ) -def state_to_structures(state: "ts.SimState") -> list["Structure"]: +def state_to_structures(state: ts.SimState) -> list[Structure]: """Convert a SimState to a list of Pymatgen Structure objects. Args: @@ -198,7 +199,7 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: description="Phonopy: harmonic and quasi-harmonic phonon calculationss", path="phonopy", ) -def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: +def state_to_phonopy(state: ts.SimState) -> list[PhonopyAtoms]: """Convert a SimState to a list of PhonopyAtoms objects. Args: @@ -255,31 +256,34 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: description="ASE: Atomic Simulation Environment", path="ase", ) -def atoms_to_state( # noqa: C901 - atoms: "Atoms | list[Atoms]", +def atoms_to_state( + atoms: Atoms | list[Atoms], device: torch.device | None = None, dtype: torch.dtype | None = None, - system_extras_keys: list[str] | None = None, - atom_extras_keys: list[str] | None = None, -) -> "ts.SimState": + *, + system_extras: ExtrasMap | None = None, + atom_extras: ExtrasMap | None = None, +) -> ts.SimState: """Convert an ASE Atoms object or list of Atoms objects to a SimState. Args: - atoms (Atoms | list[Atoms]): Single ASE Atoms object or list of Atoms objects - device (torch.device): Device to create tensors on - dtype (torch.dtype): Data type for tensors (typically torch.float32 or - torch.float64) - system_extras_keys (list[str]): Optional list of keys to read from atoms.info - into _system_extras - atom_extras_keys (list[str]): Optional list of keys to read from atoms.arrays - into _atom_extras + atoms: Single ASE Atoms object or list of Atoms objects. + device: Device to create tensors on. + dtype: Data type for tensors (typically ``torch.float32`` or + ``torch.float64``). + system_extras: Map of ``{ts_key: ase_key}`` controlling which + ``atoms.info`` entries are read into ``_system_extras``. + ``None`` (default) means no extras are read. + atom_extras: Map of ``{ts_key: ase_key}`` controlling which + ``atoms.arrays`` entries are read into ``_atom_extras``. + ``None`` (default) means no extras are read. Returns: SimState: TorchSim SimState object. Raises: - ImportError: If ASE is not installed - ValueError: If systems have inconsistent periodic boundary conditions + ImportError: If ASE is not installed. + ValueError: If systems have inconsistent periodic boundary conditions. Notes: - Input positions and cell should be in Å @@ -320,31 +324,23 @@ def atoms_to_state( # noqa: C901 raise ValueError("All systems must have the same periodic boundary conditions") _system_extras: dict[str, torch.Tensor] = {} - - # charge and spin always default to 0 for backward compatibility - for key in ("charge", "spin"): - vals = np.array([float(at.info.get(key, 0.0)) for at in atoms_list]) - _system_extras[key] = torch.tensor(vals, dtype=dtype, device=device) - - if system_extras_keys: - for key in system_extras_keys: - if key in _system_extras: - continue - vals = [at.info.get(key) for at in atoms_list] - non_none_vals = [v for v in vals if v is not None] - if len(non_none_vals) == len(vals): - _system_extras[key] = torch.tensor( - np.stack(non_none_vals), dtype=dtype, device=device + if system_extras: + for ts_key, ase_key in system_extras.items(): + vals = [at.info.get(ase_key) for at in atoms_list] + non_none = [v for v in vals if v is not None] + if len(non_none) == len(vals): + _system_extras[ts_key] = torch.tensor( + np.array(non_none), dtype=dtype, device=device ) _atom_extras: dict[str, torch.Tensor] = {} - if atom_extras_keys: - for key in atom_extras_keys: - arrays = [at.arrays.get(key) for at in atoms_list] - non_none_arrays = [a for a in arrays if a is not None] - if len(non_none_arrays) == len(arrays): - _atom_extras[key] = torch.tensor( - np.concatenate(non_none_arrays), dtype=dtype, device=device + if atom_extras: + for ts_key, ase_key in atom_extras.items(): + arrays = [at.arrays.get(ase_key) for at in atoms_list] + non_none = [a for a in arrays if a is not None] + if len(non_none) == len(arrays): + _atom_extras[ts_key] = torch.tensor( + np.concatenate(non_none), dtype=dtype, device=device ) return ts.SimState( @@ -365,10 +361,10 @@ def atoms_to_state( # noqa: C901 path="pymatgen", ) def structures_to_state( - structure: "Structure | list[Structure]", + structure: Structure | list[Structure], device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> "ts.SimState": +) -> ts.SimState: """Create a SimState from pymatgen Structure(s). Args: @@ -446,10 +442,10 @@ def structures_to_state( path="phonopy", ) def phonopy_to_state( - phonopy_atoms: "PhonopyAtoms | list[PhonopyAtoms]", + phonopy_atoms: PhonopyAtoms | list[PhonopyAtoms], device: torch.device | None = None, dtype: torch.dtype | None = None, -) -> "ts.SimState": +) -> ts.SimState: """Create state tensors from a PhonopyAtoms object or list of PhonopyAtoms objects. Args: diff --git a/torch_sim/state.py b/torch_sim/state.py index 58890111b..fde62b5aa 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -17,7 +17,7 @@ from torch._prims_common import DeviceLikeType import torch_sim as ts -from torch_sim.typing import PRNGLike, StateLike +from torch_sim.typing import ExtrasMap, PRNGLike, StateLike if TYPE_CHECKING: @@ -618,13 +618,24 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: return cls(**attrs) - def to_atoms(self) -> list["Atoms"]: + def to_atoms( + self, + *, + system_extras: ExtrasMap | None = None, + atom_extras: ExtrasMap | None = None, + ) -> list["Atoms"]: """Convert the SimState to a list of ASE Atoms objects. + Args: + system_extras: Map of ``{ts_key: ase_key}`` for system extras. + atom_extras: Map of ``{ts_key: ase_key}`` for atom extras. + Returns: - list[Atoms]: A list of ASE Atoms objects, one per system + list[Atoms]: A list of ASE Atoms objects, one per system. """ - return ts.io.state_to_atoms(self) + return ts.io.state_to_atoms( + self, system_extras=system_extras, atom_extras=atom_extras + ) def to_structures(self) -> list["Structure"]: """Convert the SimState to a list of pymatgen Structure objects. diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 655105e55..a2221e9e9 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -63,6 +63,8 @@ class BravaisType(StrEnum): TRICLINIC = "triclinic" +ExtrasMap = dict[str, str] + StateLike = Union[ "Atoms", "Structure", From db6f0f421a465827613abd92778e65aeaeb05aad Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 5 Apr 2026 14:11:31 -0400 Subject: [PATCH 10/12] fix: all rather than any on retain graph --- torch_sim/models/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index c0ecc66a8..98d284335 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -265,7 +265,7 @@ def compute_forces(self, value: bool) -> None: # noqa: FBT001 @property def retain_graph(self) -> bool: """Whether any child model retains the computation graph.""" - return any(getattr(m, "retain_graph", False) for m in self._children()) + return all(getattr(m, "retain_graph", False) for m in self._children()) @retain_graph.setter def retain_graph(self, value: bool) -> None: From 1ba3ea4b6783a44fdc42b39108b14f3ab2491922 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 5 Apr 2026 18:56:09 -0400 Subject: [PATCH 11/12] fix units issue --- torch_sim/units.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torch_sim/units.py b/torch_sim/units.py index ac95cbeaa..19cd6c4fc 100644 --- a/torch_sim/units.py +++ b/torch_sim/units.py @@ -62,12 +62,6 @@ def __new__(cls, value: float) -> Self: eV_to_Hartree = 1.0 / Hartree_to_eV e2_per_Ang_to_eV = 14.399645478425668 - # Atomic-unit conversions (Bohr / Hartree <-> Angstrom / eV) - Bohr_to_Ang = 0.529177210903 - Ang_to_Bohr = 1.0 / Bohr_to_Ang - Hartree_to_eV = 27.211386245988 - eV_to_Hartree = 1.0 / Hartree_to_eV - uc = UnitConversion From c2892cf4164e87507bda3ac42680b22f23306bc7 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sun, 5 Apr 2026 21:02:14 -0400 Subject: [PATCH 12/12] patch orb forward to handle extras. Bump version to 0.6.0 for a new release after we merge this. --- pyproject.toml | 2 +- torch_sim/models/orb.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9ebc3752d..f4b9ab60b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "torch-sim-atomistic" -version = "0.5.2" +version = "0.6.0" description = "A pytorch toolkit for calculating material properties using MLIPs" authors = [ { name = "Abhijeet Gangan", email = "abhijeetgangan@g.ucla.edu" }, diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index 80587a372..5c910cf49 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -11,16 +11,38 @@ import warnings from typing import Any +import torch + try: from orb_models.forcefield.inference.orb_torchsim import OrbTorchSimModel + import torch_sim as ts + # Re-export with backward-compatible name class OrbModel(OrbTorchSimModel): """ORB model wrapper for torch-sim.""" + @staticmethod + def _normalize_charge_spin(state: "ts.SimState") -> "ts.SimState": + """Provide ORB's optional charge/spin inputs when they are missing.""" + charge = getattr(state, "charge", None) + spin = getattr(state, "spin", None) + if charge is not None and spin is not None: + return state + zeros = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + return ts.SimState.from_state( + state, + charge=charge if charge is not None else zeros, + spin=spin if spin is not None else zeros, + ) + def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]: """Run forward pass, detaching outputs unless retain_graph is True.""" + if args and isinstance(args[0], ts.SimState): + args = (self._normalize_charge_spin(args[0]), *args[1:]) + elif isinstance(kwargs.get("state"), ts.SimState): + kwargs["state"] = self._normalize_charge_spin(kwargs["state"]) output = super().forward(*args, **kwargs) return { # detach tensors as energy is not detached by default k: v.detach() if hasattr(v, "detach") else v for k, v in output.items()