Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,10 @@ jobs:
- name: Check out repo
uses: actions/checkout@v5

- name: Set up uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true

- name: Run prek
uses: j178/prek-action@v1
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ repos:
hooks:
- id: ty
name: ty check
entry: ty check
entry: env -u VIRTUAL_ENV UV_PROJECT_ENVIRONMENT=.venv_prek uv run --extra test --extra docs ty check
language: python
additional_dependencies: [uv]
types: [python]
pass_filenames: false
additional_dependencies: [ty, torch, ase, "mace-torch>=0.3.15"]
12 changes: 11 additions & 1 deletion docs/dev/dev_install.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,17 @@ prek run --all-files
```

The `prek` command will ensure that changes to the source code match the
TorchSim style guidelines by running the `ruff` code linters and the `ty` type checker automatically with each commit.
TorchSim style guidelines by running the `ruff` code linters and the `ty` type checker automatically with each commit. If you observe differences between running this locally
and in CI then the most likely cause is that you have a `uv.lock` file causing differences
for some dependencies that are unpinned.
To resolve this, you can run:

```bash
uv sync -U
```

to ensure that all dependencies are latest for the unpinned dependencies which is the
behavior that we see in CI.

## Running unit tests

Expand Down
16 changes: 14 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ test = [
"spglib>=2.6",
]
vesin = ["vesin[torch]>=0.5.3"]
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2026.3.23"]
io = ["ase>=3.26", "phonopy>=3.0.0", "pymatgen>=2026.3.23"]
symmetry = ["moyopy>=0.7.8"]
mace = ["mace-torch>=0.3.16"]
mattersim = ["mattersim>=1.2.4"]
mattersim = ["mattersim>=1.2.5"]
metatomic = ["metatomic-torchsim>=0.1.1", "metatomic-ase>=0.1.0", "upet>=0.2.0"]
orb = ["orb-models>=0.6.2"]
sevenn = ["sevenn[torchsim]>=0.12.1"]
Expand All @@ -60,6 +60,7 @@ nequix = ["nequix[torch-sim]>=0.4.5"]
fairchem = ["fairchem-core>=2.20.0"]
docs = [
"autodoc_pydantic==2.2.0",
"duecredit>=0.11",
"furo==2024.8.6",
"ipykernel==6.30.1",
"ipython==8.34.0",
Expand Down Expand Up @@ -185,11 +186,14 @@ unused-ignore-comment = "warn"
[[tool.ty.overrides]]
include = [
"tests/models/**/*.py",
"torch_sim/autobatching.py",
"torch_sim/constraints.py",
"torch_sim/io.py",
"torch_sim/models/**/*.py",
"torch_sim/neighbors/alchemiops.py",
"torch_sim/neighbors/vesin.py",
"torch_sim/optimizers/**/*.py",
"torch_sim/quantities.py",
"torch_sim/state.py",
"torch_sim/symmetrize.py",
"torch_sim/trajectory.py",
Expand All @@ -198,6 +202,12 @@ include = [
]

[tool.ty.overrides.rules]
index-out-of-bounds = "ignore"
invalid-argument-type = "ignore"
invalid-assignment = "ignore"
invalid-return-type = "ignore"
unsupported-operator = "ignore"
unresolved-attribute = "ignore"
unresolved-import = "ignore"

[[tool.ty.overrides]]
Expand All @@ -213,6 +223,7 @@ invalid-assignment = "ignore"
include = ["tests/**/*.py"]

[tool.ty.overrides.rules]
call-non-callable = "ignore"
invalid-argument-type = "ignore"
invalid-assignment = "ignore"
no-matching-overload = "ignore"
Expand All @@ -223,6 +234,7 @@ unresolved-import = "ignore"
include = ["docs/**/*.py", "docs/**/*.ipynb", "examples/**/*.py"]
[tool.ty.overrides.rules]
invalid-argument-type = "ignore"
invalid-assignment = "ignore"
not-iterable = "ignore"
not-subscriptable = "ignore"
unresolved-attribute = "ignore"
Expand Down
16 changes: 11 additions & 5 deletions tests/models/test_dispersion.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
"""Tests for the D3DispersionModel wrapper."""

import traceback # noqa: I001
from __future__ import annotations

import traceback

import pytest
import torch

from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test


try:
from nvalchemiops.torch.interactions.dispersion import D3Parameters

from torch_sim.models.dispersion import D3DispersionModel

_IMPORT_ERROR: str | None = None
except (ImportError, OSError, RuntimeError):
pytest.skip(
f"nvalchemiops not installed: {traceback.format_exc()}",
allow_module_level=True,
)
_IMPORT_ERROR = traceback.format_exc()

pytestmark = pytest.mark.skipif(
_IMPORT_ERROR is not None, reason=f"nvalchemiops not installed: {_IMPORT_ERROR}"
)


def _make_d3_params(device: torch.device = DEVICE) -> D3Parameters:
Expand Down
16 changes: 11 additions & 5 deletions tests/models/test_electrostatics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Tests for the electrostatics ModelInterface wrappers."""

import traceback # noqa: I001
from __future__ import annotations

import traceback

import pytest
import torch
Expand All @@ -10,13 +12,17 @@
from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test


try:
from torch_sim.models.electrostatics import DSFCoulombModel, EwaldModel, PMEModel

_IMPORT_ERROR: str | None = None
except (ImportError, OSError, RuntimeError):
pytest.skip(
f"nvalchemiops not installed: {traceback.format_exc()}",
allow_module_level=True,
)
_IMPORT_ERROR = traceback.format_exc()

pytestmark = pytest.mark.skipif(
_IMPORT_ERROR is not None, reason=f"nvalchemiops not installed: {_IMPORT_ERROR}"
)


def _make_charged_state(
Expand Down
14 changes: 9 additions & 5 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import traceback

import pytest
Expand All @@ -11,11 +13,13 @@

from torch_sim.models.fairchem import FairChemModel

_IMPORT_ERROR: str | None = None
except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(
f"FairChem not installed: {traceback.format_exc()}",
allow_module_level=True,
)
_IMPORT_ERROR = traceback.format_exc()

pytestmark = pytest.mark.skipif(
_IMPORT_ERROR is not None, reason=f"FairChem not installed: {_IMPORT_ERROR}"
)


@pytest.fixture
Expand All @@ -25,7 +29,7 @@ def eqv2_uma_model_pbc() -> FairChemModel:


test_fairchem_uma_model_outputs = pytest.mark.skipif(
get_token() is None,
_IMPORT_ERROR is not None or get_token() is None,
reason="Requires HuggingFace authentication for UMA model access",
)(
make_validate_model_outputs_test(
Expand Down
20 changes: 16 additions & 4 deletions tests/models/test_mace.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from __future__ import annotations

import random
import time
import traceback
import urllib.error
from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any

import pytest
import torch
from ase.atoms import Atoms

import torch_sim as ts


if TYPE_CHECKING:
from collections.abc import Callable

from ase.atoms import Atoms
from tests.conftest import DEVICE
from tests.models.conftest import (
make_model_calculator_consistency_test,
Expand All @@ -23,8 +29,14 @@
from mace.calculators.foundations_models import mace_mp, mace_off, mace_omol

from torch_sim.models.mace import MaceModel

_IMPORT_ERROR: str | None = None
except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True)
_IMPORT_ERROR = traceback.format_exc()

pytestmark = pytest.mark.skipif(
_IMPORT_ERROR is not None, reason=f"MACE not installed: {_IMPORT_ERROR}"
)

DTYPE = torch.float64
MAX_RETRIES = 3
Expand Down
18 changes: 13 additions & 5 deletions tests/models/test_mattersim.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import traceback

import pytest
Expand All @@ -7,19 +9,21 @@
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)
from torch_sim.testing import SIMSTATE_GENERATORS
from torch_sim.testing import SIMSTATE_GENERATORS, ModelTolerance


try:
from mattersim.forcefield import MatterSimCalculator, Potential

from torch_sim.models.mattersim import MatterSimModel

_IMPORT_ERROR: str | None = None
except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
pytest.skip(
f"mattersim not installed: {traceback.format_exc()}",
allow_module_level=True,
)
_IMPORT_ERROR = traceback.format_exc()

pytestmark = pytest.mark.skipif(
_IMPORT_ERROR is not None, reason=f"mattersim not installed: {_IMPORT_ERROR}"
)


model_name = "mattersim-v1.0.0-1m.pth"
Expand Down Expand Up @@ -59,6 +63,10 @@ def test_mattersim_initialization(pretrained_mattersim_model: Potential) -> None
model_fixture_name="mattersim_model",
calculator_fixture_name="mattersim_calculator",
sim_state_names=tuple(SIMSTATE_GENERATORS.keys()),
energy_rtol=ModelTolerance.LOOSE,
energy_atol=ModelTolerance.LOOSE,
force_rtol=ModelTolerance.LOOSE,
force_atol=ModelTolerance.STANDARD,
)

test_mattersim_model_outputs = make_validate_model_outputs_test(
Expand Down
22 changes: 15 additions & 7 deletions tests/models/test_metatomic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import traceback
from typing import TYPE_CHECKING

import pytest
import torch
Expand All @@ -8,20 +11,25 @@
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)
from torch_sim.testing import SIMSTATE_GENERATORS
from torch_sim.testing import SIMSTATE_GENERATORS, ModelTolerance


try:
if TYPE_CHECKING:
from metatomic.torch import AtomisticModel

try:
from metatomic_ase import MetatomicCalculator
from upet import get_upet

from torch_sim.models.metatomic import MetatomicModel

_IMPORT_ERROR: str | None = None
except ImportError:
pytest.skip(
f"metatomic not installed: {traceback.format_exc()}",
allow_module_level=True,
)
_IMPORT_ERROR = traceback.format_exc()

pytestmark = pytest.mark.skipif(
_IMPORT_ERROR is not None, reason=f"metatomic not installed: {_IMPORT_ERROR}"
)


@pytest.fixture
Expand Down Expand Up @@ -50,7 +58,7 @@ def test_metatomic_initialization() -> None:
model_fixture_name="metatomic_model",
calculator_fixture_name="metatomic_calculator",
sim_state_names=tuple(SIMSTATE_GENERATORS.keys()),
energy_atol=5e-5,
energy_atol=ModelTolerance.LOOSE,
dtype=torch.float32,
device=DEVICE,
)
Expand Down
20 changes: 12 additions & 8 deletions tests/models/test_nequip_framework.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import traceback
import urllib.request
from pathlib import Path
Expand All @@ -9,19 +11,22 @@
make_model_calculator_consistency_test,
make_validate_model_outputs_test,
)
from torch_sim.testing import SIMSTATE_BULK_GENERATORS
from torch_sim.testing import SIMSTATE_BULK_GENERATORS, ModelTolerance


try:
from nequip.ase import NequIPCalculator
from nequip.scripts.compile import main

from torch_sim.models.nequip_framework import NequIPFrameworkModel

_IMPORT_ERROR: str | None = None
except (ImportError, ModuleNotFoundError):
pytest.skip(
f"nequip not installed: {traceback.format_exc()}",
allow_module_level=True,
)
_IMPORT_ERROR = traceback.format_exc()

pytestmark = pytest.mark.skipif(
_IMPORT_ERROR is not None, reason=f"nequip not installed: {_IMPORT_ERROR}"
)


# Cache directory for compiled models (under tests/ for easy cleanup)
Expand Down Expand Up @@ -117,14 +122,13 @@ def nequip_calculator(compiled_ase_nequip_model_path: Path) -> NequIPCalculator:
)


# NOTE: we take [:-1] to skip benzene. This is because the stress calculation in NequIP
# for non-periodic systems gave infinity.
# NOTE: skip molecule sim states as stress in NequIP gave inf.
test_nequip_consistency = make_model_calculator_consistency_test(
test_name="nequip",
model_fixture_name="nequip_model",
calculator_fixture_name="nequip_calculator",
sim_state_names=tuple(SIMSTATE_BULK_GENERATORS.keys()),
energy_atol=5e-5,
energy_atol=ModelTolerance.LOOSE,
dtype=DTYPE,
device=DEVICE,
)
Expand Down
Loading
Loading