diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3c17bc5a8..b51ab6559 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,5 +13,10 @@ jobs: - name: Check out repo uses: actions/checkout@v5 + - name: Set up uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + - name: Run prek uses: j178/prek-action@v1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 74a334ad4..a14a75748 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,8 +40,8 @@ repos: hooks: - id: ty name: ty check - entry: ty check + entry: env -u VIRTUAL_ENV UV_PROJECT_ENVIRONMENT=.venv_prek uv run --extra test --extra docs ty check language: python + additional_dependencies: [uv] types: [python] pass_filenames: false - additional_dependencies: [ty, torch, ase, "mace-torch>=0.3.15"] diff --git a/docs/dev/dev_install.md b/docs/dev/dev_install.md index f4779c87d..a48531c08 100644 --- a/docs/dev/dev_install.md +++ b/docs/dev/dev_install.md @@ -32,7 +32,17 @@ prek run --all-files ``` The `prek` command will ensure that changes to the source code match the -TorchSim style guidelines by running the `ruff` code linters and the `ty` type checker automatically with each commit. +TorchSim style guidelines by running the `ruff` code linters and the `ty` type checker automatically with each commit. If you observe differences between running this locally +and in CI then the most likely cause is that you have a `uv.lock` file causing differences +for some dependencies that are unpinned. +To resolve this, you can run: + +```bash +uv sync -U +``` + +to ensure that all dependencies are latest for the unpinned dependencies which is the +behavior that we see in CI. ## Running unit tests diff --git a/pyproject.toml b/pyproject.toml index 95749eeea..9f473cd5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,10 +48,10 @@ test = [ "spglib>=2.6", ] vesin = ["vesin[torch]>=0.5.3"] -io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2026.3.23"] +io = ["ase>=3.26", "phonopy>=3.0.0", "pymatgen>=2026.3.23"] symmetry = ["moyopy>=0.7.8"] mace = ["mace-torch>=0.3.16"] -mattersim = ["mattersim>=1.2.4"] +mattersim = ["mattersim>=1.2.5"] metatomic = ["metatomic-torchsim>=0.1.1", "metatomic-ase>=0.1.0", "upet>=0.2.0"] orb = ["orb-models>=0.6.2"] sevenn = ["sevenn[torchsim]>=0.12.1"] @@ -60,6 +60,7 @@ nequix = ["nequix[torch-sim]>=0.4.5"] fairchem = ["fairchem-core>=2.20.0"] docs = [ "autodoc_pydantic==2.2.0", + "duecredit>=0.11", "furo==2024.8.6", "ipykernel==6.30.1", "ipython==8.34.0", @@ -185,11 +186,14 @@ unused-ignore-comment = "warn" [[tool.ty.overrides]] include = [ "tests/models/**/*.py", + "torch_sim/autobatching.py", "torch_sim/constraints.py", "torch_sim/io.py", "torch_sim/models/**/*.py", "torch_sim/neighbors/alchemiops.py", "torch_sim/neighbors/vesin.py", + "torch_sim/optimizers/**/*.py", + "torch_sim/quantities.py", "torch_sim/state.py", "torch_sim/symmetrize.py", "torch_sim/trajectory.py", @@ -198,6 +202,12 @@ include = [ ] [tool.ty.overrides.rules] +index-out-of-bounds = "ignore" +invalid-argument-type = "ignore" +invalid-assignment = "ignore" +invalid-return-type = "ignore" +unsupported-operator = "ignore" +unresolved-attribute = "ignore" unresolved-import = "ignore" [[tool.ty.overrides]] @@ -213,6 +223,7 @@ invalid-assignment = "ignore" include = ["tests/**/*.py"] [tool.ty.overrides.rules] +call-non-callable = "ignore" invalid-argument-type = "ignore" invalid-assignment = "ignore" no-matching-overload = "ignore" @@ -223,6 +234,7 @@ unresolved-import = "ignore" include = ["docs/**/*.py", "docs/**/*.ipynb", "examples/**/*.py"] [tool.ty.overrides.rules] invalid-argument-type = "ignore" +invalid-assignment = "ignore" not-iterable = "ignore" not-subscriptable = "ignore" unresolved-attribute = "ignore" diff --git a/tests/models/test_dispersion.py b/tests/models/test_dispersion.py index 4dc5ddc76..5e7fe5b37 100644 --- a/tests/models/test_dispersion.py +++ b/tests/models/test_dispersion.py @@ -1,6 +1,8 @@ """Tests for the D3DispersionModel wrapper.""" -import traceback # noqa: I001 +from __future__ import annotations + +import traceback import pytest import torch @@ -8,15 +10,19 @@ from tests.conftest import DEVICE, DTYPE from tests.models.conftest import make_validate_model_outputs_test + try: from nvalchemiops.torch.interactions.dispersion import D3Parameters from torch_sim.models.dispersion import D3DispersionModel + + _IMPORT_ERROR: str | None = None except (ImportError, OSError, RuntimeError): - pytest.skip( - f"nvalchemiops not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"nvalchemiops not installed: {_IMPORT_ERROR}" +) def _make_d3_params(device: torch.device = DEVICE) -> D3Parameters: diff --git a/tests/models/test_electrostatics.py b/tests/models/test_electrostatics.py index e4e3feaa6..3621bea69 100644 --- a/tests/models/test_electrostatics.py +++ b/tests/models/test_electrostatics.py @@ -1,6 +1,8 @@ """Tests for the electrostatics ModelInterface wrappers.""" -import traceback # noqa: I001 +from __future__ import annotations + +import traceback import pytest import torch @@ -10,13 +12,17 @@ from tests.conftest import DEVICE, DTYPE from tests.models.conftest import make_validate_model_outputs_test + try: from torch_sim.models.electrostatics import DSFCoulombModel, EwaldModel, PMEModel + + _IMPORT_ERROR: str | None = None except (ImportError, OSError, RuntimeError): - pytest.skip( - f"nvalchemiops not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"nvalchemiops not installed: {_IMPORT_ERROR}" +) def _make_charged_state( diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index 7937c6921..77603b935 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import traceback import pytest @@ -11,11 +13,13 @@ from torch_sim.models.fairchem import FairChemModel + _IMPORT_ERROR: str | None = None except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip( - f"FairChem not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"FairChem not installed: {_IMPORT_ERROR}" +) @pytest.fixture @@ -25,7 +29,7 @@ def eqv2_uma_model_pbc() -> FairChemModel: test_fairchem_uma_model_outputs = pytest.mark.skipif( - get_token() is None, + _IMPORT_ERROR is not None or get_token() is None, reason="Requires HuggingFace authentication for UMA model access", )( make_validate_model_outputs_test( diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index 885f5ce81..b2b75bd7a 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import random import time import traceback import urllib.error -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import pytest import torch -from ase.atoms import Atoms import torch_sim as ts + + +if TYPE_CHECKING: + from collections.abc import Callable + + from ase.atoms import Atoms from tests.conftest import DEVICE from tests.models.conftest import ( make_model_calculator_consistency_test, @@ -23,8 +29,14 @@ from mace.calculators.foundations_models import mace_mp, mace_off, mace_omol from torch_sim.models.mace import MaceModel + + _IMPORT_ERROR: str | None = None except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"MACE not installed: {_IMPORT_ERROR}" +) DTYPE = torch.float64 MAX_RETRIES = 3 diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index c5dcd69e3..f0a985b71 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import traceback import pytest @@ -7,7 +9,7 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) -from torch_sim.testing import SIMSTATE_GENERATORS +from torch_sim.testing import SIMSTATE_GENERATORS, ModelTolerance try: @@ -15,11 +17,13 @@ from torch_sim.models.mattersim import MatterSimModel + _IMPORT_ERROR: str | None = None except (ImportError, OSError, RuntimeError, AttributeError, ValueError): - pytest.skip( - f"mattersim not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"mattersim not installed: {_IMPORT_ERROR}" +) model_name = "mattersim-v1.0.0-1m.pth" @@ -59,6 +63,10 @@ def test_mattersim_initialization(pretrained_mattersim_model: Potential) -> None model_fixture_name="mattersim_model", calculator_fixture_name="mattersim_calculator", sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), + energy_rtol=ModelTolerance.LOOSE, + energy_atol=ModelTolerance.LOOSE, + force_rtol=ModelTolerance.LOOSE, + force_atol=ModelTolerance.STANDARD, ) test_mattersim_model_outputs = make_validate_model_outputs_test( diff --git a/tests/models/test_metatomic.py b/tests/models/test_metatomic.py index c853f55af..80ff527b3 100644 --- a/tests/models/test_metatomic.py +++ b/tests/models/test_metatomic.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import traceback +from typing import TYPE_CHECKING import pytest import torch @@ -8,20 +11,25 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) -from torch_sim.testing import SIMSTATE_GENERATORS +from torch_sim.testing import SIMSTATE_GENERATORS, ModelTolerance -try: +if TYPE_CHECKING: from metatomic.torch import AtomisticModel + +try: from metatomic_ase import MetatomicCalculator from upet import get_upet from torch_sim.models.metatomic import MetatomicModel + + _IMPORT_ERROR: str | None = None except ImportError: - pytest.skip( - f"metatomic not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"metatomic not installed: {_IMPORT_ERROR}" +) @pytest.fixture @@ -50,7 +58,7 @@ def test_metatomic_initialization() -> None: model_fixture_name="metatomic_model", calculator_fixture_name="metatomic_calculator", sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), - energy_atol=5e-5, + energy_atol=ModelTolerance.LOOSE, dtype=torch.float32, device=DEVICE, ) diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index e148dae3d..1a82a6051 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import traceback import urllib.request from pathlib import Path @@ -9,7 +11,7 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) -from torch_sim.testing import SIMSTATE_BULK_GENERATORS +from torch_sim.testing import SIMSTATE_BULK_GENERATORS, ModelTolerance try: @@ -17,11 +19,14 @@ from nequip.scripts.compile import main from torch_sim.models.nequip_framework import NequIPFrameworkModel + + _IMPORT_ERROR: str | None = None except (ImportError, ModuleNotFoundError): - pytest.skip( - f"nequip not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"nequip not installed: {_IMPORT_ERROR}" +) # Cache directory for compiled models (under tests/ for easy cleanup) @@ -117,14 +122,13 @@ def nequip_calculator(compiled_ase_nequip_model_path: Path) -> NequIPCalculator: ) -# NOTE: we take [:-1] to skip benzene. This is because the stress calculation in NequIP -# for non-periodic systems gave infinity. +# NOTE: skip molecule sim states as stress in NequIP gave inf. test_nequip_consistency = make_model_calculator_consistency_test( test_name="nequip", model_fixture_name="nequip_model", calculator_fixture_name="nequip_calculator", sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), - energy_atol=5e-5, + energy_atol=ModelTolerance.LOOSE, dtype=DTYPE, device=DEVICE, ) diff --git a/tests/models/test_nequix.py b/tests/models/test_nequix.py index f272e5a5c..7fef7092c 100644 --- a/tests/models/test_nequix.py +++ b/tests/models/test_nequix.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import traceback import pytest @@ -7,18 +9,21 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) -from torch_sim.testing import SIMSTATE_BULK_GENERATORS +from torch_sim.testing import SIMSTATE_BULK_GENERATORS, ModelTolerance try: from nequix.calculator import NequixCalculator from torch_sim.models.nequix import NequixModel + + _IMPORT_ERROR: str | None = None except (ImportError, ModuleNotFoundError): - pytest.skip( - f"nequix not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"nequix not installed: {_IMPORT_ERROR}" +) @pytest.fixture(scope="session") @@ -42,7 +47,7 @@ def nequix_calculator() -> NequixCalculator: model_fixture_name="nequix_model", calculator_fixture_name="nequix_calculator", sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()), - force_atol=5e-5, + force_atol=ModelTolerance.LOOSE, dtype=DTYPE, device=DEVICE, ) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 1e7483cc1..a44dab20a 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import traceback import pytest @@ -7,7 +9,7 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) -from torch_sim.testing import SIMSTATE_GENERATORS +from torch_sim.testing import SIMSTATE_GENERATORS, ModelTolerance try: @@ -16,8 +18,13 @@ from torch_sim.models.orb import OrbModel + _IMPORT_ERROR: str | None = None except ImportError: - pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"ORB not installed: {_IMPORT_ERROR}" +) @pytest.fixture @@ -59,8 +66,8 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: model_fixture_name="orbv3_conservative_inf_omat_model", calculator_fixture_name="orbv3_conservative_inf_omat_calculator", sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), - energy_rtol=5e-5, - energy_atol=5e-5, + energy_rtol=ModelTolerance.LOOSE, + energy_atol=ModelTolerance.LOOSE, ) test_orb_direct_consistency = make_model_calculator_consistency_test( @@ -68,8 +75,8 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: model_fixture_name="orbv3_direct_20_omat_model", calculator_fixture_name="orbv3_direct_20_omat_calculator", sim_state_names=tuple(SIMSTATE_GENERATORS.keys()), - energy_rtol=5e-5, - energy_atol=5e-5, + energy_rtol=ModelTolerance.LOOSE, + energy_atol=ModelTolerance.LOOSE, ) test_validate_conservative_model_outputs = make_validate_model_outputs_test( diff --git a/tests/models/test_sevennet.py b/tests/models/test_sevennet.py index b5e759e41..658678935 100644 --- a/tests/models/test_sevennet.py +++ b/tests/models/test_sevennet.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import traceback +from typing import TYPE_CHECKING import pytest import torch @@ -11,18 +14,22 @@ from torch_sim.testing import SIMSTATE_BULK_GENERATORS +if TYPE_CHECKING: + from sevenn.nn.sequential import AtomGraphSequential + try: import sevenn.util from sevenn.calculator import SevenNetCalculator - from sevenn.nn.sequential import AtomGraphSequential from torch_sim.models.sevennet import SevenNetModel + _IMPORT_ERROR: str | None = None except ImportError: - pytest.skip( - f"sevenn not installed: {traceback.format_exc()}", - allow_module_level=True, - ) + _IMPORT_ERROR = traceback.format_exc() + +pytestmark = pytest.mark.skipif( + _IMPORT_ERROR is not None, reason=f"sevenn not installed: {_IMPORT_ERROR}" +) model_name = "sevennet-mf-ompa" diff --git a/torch_sim/_citations.py b/torch_sim/_citations.py index 3cfc15c0a..27b0c349b 100644 --- a/torch_sim/_citations.py +++ b/torch_sim/_citations.py @@ -8,7 +8,7 @@ if due is not None: - due.cite( + due.cite( # ty: ignore[unresolved-attribute] BibTeX( """@article{cohen2025torchsim, title={TorchSim: An efficient atomistic simulation engine in PyTorch}, @@ -28,7 +28,7 @@ path="torch_sim", cite_module=True, ) - due.cite( + due.cite( # ty: ignore[unresolved-attribute] BibTeX( """@inproceedings{paszke2019pytorch, title={PyTorch: An Imperative Style, High-Performance Deep Learning Library}, diff --git a/torch_sim/_duecredit.py b/torch_sim/_duecredit.py index 2975f3cbf..e20421f7f 100644 --- a/torch_sim/_duecredit.py +++ b/torch_sim/_duecredit.py @@ -46,12 +46,12 @@ def _disable_duecredit(exc: Exception) -> None: try: - from duecredit import BibTeX, Doi, Text, Url, due # type: ignore[unresolved-import] + from duecredit import BibTeX, Doi, Text, Url, due except Exception as e: # noqa: BLE001 if not isinstance(e, ImportError): _disable_duecredit(e) - due = InactiveDueCreditCollector() - BibTeX = Doi = Url = Text = _donothing_func + due = InactiveDueCreditCollector() # ty: ignore[invalid-assignment] + BibTeX = Doi = Url = Text = _donothing_func # ty: ignore[invalid-assignment] def dcite( @@ -63,4 +63,4 @@ def dcite( ) if path is not None: kwargs["path"] = path - return due.dcite(Doi(doi), **kwargs) + return due.dcite(Doi(doi), **kwargs) # ty: ignore[unresolved-attribute] diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index eddc65c23..d43241664 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -94,11 +94,11 @@ def to_constant_volume_bins( # noqa: C901 # placed in the output bin: the dict key, the original tuple, or the weight. is_dict = isinstance(items, dict) if is_dict: - entries = [(weight, k) for k, weight in items.items()] # ty: ignore[unresolved-attribute] + entries = [(weight, k) for k, weight in items.items()] # list of objects: dispatch on how to extract the weight from each item elif weight_pos is not None: # weight lives at a fixed tuple/list position; payload is the original item - entries = [(item[weight_pos], item) for item in items] # ty: ignore[not-subscriptable] + entries = [(item[weight_pos], item) for item in items] elif key is not None: # custom extractor for arbitrary item types; payload is the original item entries = [(key(item), item) for item in items] @@ -912,7 +912,7 @@ def _get_next_states(self) -> list[T]: metric = calculate_memory_scalers( state, self.memory_scales_with, self.cutoff )[0] - if metric > self.max_memory_scaler: # ty: ignore[unsupported-operator] + if metric > self.max_memory_scaler: raise ValueError( f"State {metric=} is greater than max_metric {self.max_memory_scaler}" ", please set a larger max_metric or run smaller systems metric." diff --git a/torch_sim/models/electrostatics.py b/torch_sim/models/electrostatics.py index 2ea9863b9..092f5bfee 100644 --- a/torch_sim/models/electrostatics.py +++ b/torch_sim/models/electrostatics.py @@ -254,7 +254,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] "energy": energy.to(self._dtype).detach(), } if self._compute_forces: - forces = out[1] * UnitConversion.e2_per_Ang_to_eV # type: ignore[index] + forces = out[1] * UnitConversion.e2_per_Ang_to_eV results["forces"] = forces.to(self._dtype).detach() if self._compute_stress: volumes = state.volume.unsqueeze(-1).unsqueeze(-1) @@ -383,7 +383,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] "energy": energy.to(self._dtype).detach(), } if self._compute_forces: - forces = out[1] * UnitConversion.e2_per_Ang_to_eV # type: ignore[index] + forces = out[1] * UnitConversion.e2_per_Ang_to_eV results["forces"] = forces.to(self._dtype).detach() if self._compute_stress: volumes = state.volume.unsqueeze(-1).unsqueeze(-1) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index b4e23b95e..d92da8d41 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -173,7 +173,7 @@ def fire_step( step_func_kwargs["max_step"] = max_step step_func = {vv_fire_key: _vv_fire_step, ase_fire_key: _ase_fire_step}[fire_flavor] - return step_func(state, **step_func_kwargs) # ty: ignore[invalid-argument-type] + return step_func(state, **step_func_kwargs) def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 3f902b1a6..bc29ed597 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -396,7 +396,7 @@ def integrate[T: SimState]( # noqa: C901 if isinstance(batch_iterator, BinningAutoBatcher): reordered = batch_iterator.restore_original_order(final_states) # ty: ignore[invalid-argument-type] - result = ts.concatenate_states(reordered) # ty: ignore[invalid-argument-type] + result = ts.concatenate_states(reordered) logger.info("integrate: complete, %d systems returned", result.n_systems) return result diff --git a/torch_sim/state.py b/torch_sim/state.py index 6584480a3..2c3e28b89 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -1083,9 +1083,9 @@ def _filter_attrs_by_index( new_system_idx = system_remap[torch.where(system_mask)[0]] for c in filtered_attrs["_constraints"]: if hasattr(c, "atom_idx") and isinstance(c.atom_idx, torch.Tensor): - c.atom_idx = new_atom_idx[c.atom_idx] # ty: ignore[invalid-assignment] + c.atom_idx = new_atom_idx[c.atom_idx] if hasattr(c, "system_idx") and isinstance(c.system_idx, torch.Tensor): - c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment] + c.system_idx = new_system_idx[c.system_idx] for name, val in get_attrs_for_scope(state, "per-atom"): if name in state.atom_extras: @@ -1195,7 +1195,7 @@ def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901 new_constraints.append(sub) system_attrs["_constraints"] = new_constraints - states.append(type(state)(**system_attrs)) # ty: ignore[invalid-argument-type] + states.append(type(state)(**system_attrs)) return states diff --git a/torch_sim/testing.py b/torch_sim/testing.py index 01e1ef092..f73534ad6 100644 --- a/torch_sim/testing.py +++ b/torch_sim/testing.py @@ -26,6 +26,7 @@ def test_my_model_consistency(sim_state_name, my_model, my_calculator): """ from collections.abc import Callable +from enum import Enum from typing import TYPE_CHECKING, Final import torch @@ -40,6 +41,37 @@ def test_my_model_consistency(sim_state_name, my_model, my_calculator): from torch_sim.models.interface import ModelInterface +class ModelTolerance(float, Enum): + """Named tolerance levels for model tests. + + Centralizes the absolute/relative tolerance values used across model tests, + spanning model-vs-calculator consistency, autograd-vs-direct force checks, + neighbor-list equivalence, symmetry, and exact-agreement assertions. Choose + the loosest level that still catches the regressions you care about. + """ + + COARSE = 1e-3 + VERY_LOOSE = 1e-4 + LOOSE = 5e-5 + STANDARD = 1e-5 + TIGHT = 1e-6 + VERY_TIGHT = 1e-7 + STRICT = 1e-10 + EXACT = 1e-13 + + def __repr__(self) -> str: + """Return a string representation of the tolerance level.""" + return f"{type(self).__name__}.{self.name} ({float(self.value):g})" + + __str__ = __repr__ + + def __format__(self, format_spec: str) -> str: + """Format the tolerance level as a string.""" + if format_spec: + return float(self.value).__format__(format_spec) + return self.__repr__() + + def make_cu_sim_state( device: torch.device | None = None, dtype: torch.dtype | None = None ) -> ts.SimState: @@ -481,12 +513,12 @@ def assert_model_calculator_consistency( model: "ModelInterface", calculator: "Calculator", sim_state: ts.SimState, - energy_rtol: float = 1e-5, - energy_atol: float = 1e-5, - force_rtol: float = 1e-5, - force_atol: float = 1e-5, - stress_rtol: float = 1e-5, - stress_atol: float = 1e-5, + energy_rtol: float = ModelTolerance.STANDARD, + energy_atol: float = ModelTolerance.STANDARD, + force_rtol: float = ModelTolerance.STANDARD, + force_atol: float = ModelTolerance.STANDARD, + stress_rtol: float = ModelTolerance.STANDARD, + stress_atol: float = ModelTolerance.STANDARD, ) -> None: """Assert consistency between model and calculator implementations.