Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions examples/benchmarking/neighborlists.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n-repeats",
type=int,
default=3,
help="Number of timed repetitions (median is reported).",
default=7,
help="Number of timed repetitions (trimmed mean excluding min/max is reported).",
)
parsed_args = parser.parse_args()
if parsed_args.n_structures <= 0:
Expand Down Expand Up @@ -264,13 +264,15 @@ def _benchmark_backend(
torch.cuda.synchronize()
timings.append(time.perf_counter() - t0)

median_s = float(np.median(timings))
sorted_timings = sorted(timings)
trimmed = sorted_timings[1:-1] if len(sorted_timings) > 2 else sorted_timings
trimmed_mean_s = float(np.mean(trimmed))
return {
"nl_backend": backend,
"n_pairs": int(mapping.shape[1]),
"median_nl_s": round(median_s, 6),
"trimmed_mean_nl_s": round(trimmed_mean_s, 6),
"timings_s": [round(t, 6) for t in timings],
"atoms_per_s": round(n_atoms / median_s, 1) if median_s > 0 else 0,
"atoms_per_s": round(n_atoms / trimmed_mean_s, 1) if trimmed_mean_s > 0 else 0,
}


Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ vesin = ["vesin[torch]>=0.5.3"]
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"]
symmetry = ["moyopy>=0.7.8"]
mace = ["mace-torch>=0.3.15"]
mattersim = ["mattersim>=0.1.2"]
mattersim = ["mattersim>=1.2.2"]
metatomic = ["metatomic-torchsim>=0.1.1", "metatomic-ase>=0.1.0", "upet>=0.2.0"]
orb = ["orb-models>=0.6.2"]
sevenn = ["sevenn[torchsim]>=0.12.1"]
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
nequip = ["nequip>=0.17.0"]
nequip = ["nequip>=0.17.1"]
nequix = ["nequix[torch-sim]>=0.4.5"]
fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"]
docs = [
Expand Down
22 changes: 11 additions & 11 deletions tests/models/test_nequip_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,30 @@
# Cache directory for compiled models (under tests/ for easy cleanup)
NEQUIP_CACHE_DIR = Path(__file__).parent.parent / ".cache" / "nequip_compiled_models"

# Zenodo URL for NequIP-OAM-L model (more reliable than nequip.net for CI)
NEQUIP_OAM_L_ZENODO_URL = (
"https://zenodo.org/records/16980200/files/NequIP-OAM-L-0.1.nequip.zip?download=1"
# Zenodo URL for NequIP-OAM-S model (more reliable than nequip.net for CI)
NEQUIP_OAM_S_ZENODO_URL = (
"https://zenodo.org/records/18775904/files/NequIP-OAM-S-0.1.nequip.zip?download=1"
)
NEQUIP_OAM_L_ZIP_NAME = "NequIP-OAM-L-0.1.nequip.zip"
NEQUIP_OAM_S_ZIP_NAME = "NequIP-OAM-S-0.1.nequip.zip"


def _get_nequip_model_zip() -> Path:
"""Download NequIP-OAM-L model from Zenodo if not already cached."""
"""Download NequIP-OAM-S model from Zenodo if not already cached."""
NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True)
zip_path = NEQUIP_CACHE_DIR / NEQUIP_OAM_L_ZIP_NAME
zip_path = NEQUIP_CACHE_DIR / NEQUIP_OAM_S_ZIP_NAME

if not zip_path.exists():
urllib.request.urlretrieve(NEQUIP_OAM_L_ZENODO_URL, zip_path) # noqa: S310
urllib.request.urlretrieve(NEQUIP_OAM_S_ZENODO_URL, zip_path) # noqa: S310

return zip_path


@pytest.fixture(scope="session")
def compiled_ase_nequip_model_path() -> Path:
"""Compile NequIP OAM-L model from Zenodo for ASE (with persistent caching)."""
"""Compile NequIP OAM-S model from Zenodo for ASE (with persistent caching)."""
NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True)

output_model_name = f"mir-group__NequIP-OAM-L__0.1__{DEVICE.type}_ase.nequip.pt2"
output_model_name = f"mir-group__NequIP-OAM-S__0.1__{DEVICE.type}_ase.nequip.pt2"
output_path = NEQUIP_CACHE_DIR / output_model_name

# Only compile if not already cached
Expand All @@ -74,10 +74,10 @@ def compiled_ase_nequip_model_path() -> Path:

@pytest.fixture(scope="session")
def compiled_batch_nequip_model_path() -> Path:
"""Compile NequIP OAM-L model from Zenodo for batch (with persistent caching)."""
"""Compile NequIP OAM-S model from Zenodo for batch (with persistent caching)."""
NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True)

output_model_name = f"mir-group__NequIP-OAM-L__0.1__{DEVICE.type}_batch.nequip.pt2"
output_model_name = f"mir-group__NequIP-OAM-S__0.1__{DEVICE.type}_batch.nequip.pt2"
output_path = NEQUIP_CACHE_DIR / output_model_name

# Only compile if not already cached
Expand Down
22 changes: 17 additions & 5 deletions tests/models/test_orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,22 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator:
energy_atol=5e-5,
)

test_validate_conservative_model_outputs = make_validate_model_outputs_test(
model_fixture_name="orbv3_conservative_inf_omat_model",
test_validate_conservative_model_outputs = pytest.mark.xfail(
reason=(
"Upstream ORB conservative model incorrectly squeezes length-1 batch "
"dimensions; see https://github.com/orbital-materials/orb-models/pull/158"
),
strict=False,
)(
make_validate_model_outputs_test(
model_fixture_name="orbv3_conservative_inf_omat_model",
)
)

test_validate_direct_model_outputs = make_validate_model_outputs_test(
model_fixture_name="orbv3_direct_20_omat_model",
)
test_validate_direct_model_outputs = pytest.mark.xfail(
reason=(
"Upstream ORB direct model shows batch-dependent leakage; "
"see https://github.com/orbital-materials/orb-models/issues/159"
),
strict=False,
)(make_validate_model_outputs_test(model_fixture_name="orbv3_direct_20_omat_model"))
34 changes: 0 additions & 34 deletions torch_sim/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,6 @@
)


def _normalize_inputs(
cell: torch.Tensor, pbc: torch.Tensor, n_systems: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Normalize cell and PBC tensors to standard batch format.

Handles multiple input formats:
- cell: [3, 3], [n_systems, 3, 3], or [n_systems*3, 3]
- pbc: [3], [n_systems, 3], or [n_systems*3]

Returns:
(cell, pbc) normalized to ([n_systems, 3, 3], [n_systems, 3])
Both tensors are guaranteed to be contiguous.
"""
# Normalize cell
if cell.ndim == 2:
if cell.shape[0] == 3:
cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
else:
cell = cell.reshape(n_systems, 3, 3).contiguous()
else:
cell = cell.contiguous()

# Normalize PBC
if pbc.ndim == 1:
if pbc.shape[0] == 3:
pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous()
else:
pbc = pbc.reshape(n_systems, 3).contiguous()
else:
pbc = pbc.contiguous()

return cell, pbc


# Set default neighbor list based on what's available (priority order)
if ALCHEMIOPS_AVAILABLE:
# Alchemiops is fastest on NVIDIA GPUs
Expand Down
11 changes: 4 additions & 7 deletions torch_sim/neighbors/alchemiops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import torch

from torch_sim.neighbors.utils import normalize_inputs


_batch_naive_neighbor_list: object | None = None
_batch_cell_list: object | None = None
Expand Down Expand Up @@ -64,13 +66,10 @@ def alchemiops_nl_n2(
Returns:
(mapping, system_mapping, shifts_idx)
"""
from torch_sim.neighbors import _normalize_inputs

r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff
n_systems = int(system_idx.max().item()) + 1
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
cell, pbc = normalize_inputs(cell, pbc, n_systems)

# Call alchemiops neighbor list
if _batch_naive_neighbor_list is None:
raise RuntimeError("nvalchemiops neighbor list is unavailable")
res = _batch_naive_neighbor_list(
Expand Down Expand Up @@ -138,11 +137,9 @@ def alchemiops_nl_cell_list(
Returns:
(mapping, system_mapping, shifts_idx)
"""
from torch_sim.neighbors import _normalize_inputs

r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff
n_systems = int(system_idx.max().item()) + 1
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
cell, pbc = normalize_inputs(cell, pbc, n_systems)

# For non-periodic systems with zero cells, use a nominal identity cell
# to avoid division by zero in alchemiops warp kernels
Expand Down
31 changes: 3 additions & 28 deletions torch_sim/neighbors/torch_nl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,7 @@
import torch

from torch_sim import transforms


@torch.jit.script
def _normalize_inputs_jit(
cell: torch.Tensor, pbc: torch.Tensor, n_systems: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""JIT-compatible input normalization for torch_nl functions."""
# Normalize cell
if cell.ndim == 2:
if cell.shape[0] == 3:
cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
else:
cell = cell.reshape(n_systems, 3, 3).contiguous()
else:
cell = cell.contiguous()

# Normalize PBC
if pbc.ndim == 1:
if pbc.shape[0] == 3:
pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous()
else:
pbc = pbc.reshape(n_systems, 3).contiguous()
else:
pbc = pbc.contiguous()

return cell, pbc
from torch_sim.neighbors.utils import normalize_inputs


def strict_nl(
Expand Down Expand Up @@ -162,7 +137,7 @@ def torch_nl_n2(
- https://github.com/venkatkapil24/batch_nl
"""
n_systems = system_idx.max().item() + 1
cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems)
cell, pbc = normalize_inputs(cell, pbc, n_systems)
wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts(
positions, cell, system_idx, pbc
)
Expand Down Expand Up @@ -224,7 +199,7 @@ def torch_nl_linked_cell(
- https://github.com/felixmusil/torch_nl
"""
n_systems = system_idx.max().item() + 1
cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems)
cell, pbc = normalize_inputs(cell, pbc, n_systems)
wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts(
positions, cell, system_idx, pbc
)
Expand Down
34 changes: 34 additions & 0 deletions torch_sim/neighbors/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Utilities for neighbor list calculations."""

import torch


@torch.jit.script
def normalize_inputs(
cell: torch.Tensor, pbc: torch.Tensor, n_systems: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Normalize cell and PBC tensors to standard batch format.

Handles multiple input formats:
- cell: [3, 3], [n_systems, 3, 3], or [n_systems*3, 3]
- pbc: [3], [n_systems, 3], or [n_systems*3]

Returns:
(cell, pbc) normalized to ([n_systems, 3, 3], [n_systems, 3])
Both tensors are guaranteed to be contiguous.
"""
if cell.ndim == 2:
if cell.shape[0] == 3:
cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
else:
cell = cell.reshape(n_systems, 3, 3).contiguous()
else:
cell = cell.contiguous()
if pbc.ndim == 1:
if pbc.shape[0] == 3:
pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous()
else:
pbc = pbc.reshape(n_systems, 3).contiguous()
else:
pbc = pbc.contiguous()
return cell, pbc
10 changes: 4 additions & 6 deletions torch_sim/neighbors/vesin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import torch

from torch_sim.neighbors.utils import normalize_inputs


try:
from vesin import NeighborList as VesinNeighborList
Expand Down Expand Up @@ -70,16 +72,14 @@ def vesin_nl(
References:
- https://github.com/Luthaf/vesin
"""
from torch_sim.neighbors import _normalize_inputs

if VesinNeighborList is None:
raise RuntimeError(
"vesin is not installed. Install it with: [uv] pip install vesin"
)
device = positions.device
dtype = positions.dtype
n_systems = int(system_idx.max().item()) + 1
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
cell, pbc = normalize_inputs(cell, pbc, n_systems)

# Process each system's neighbor list separately
edge_indices = []
Expand Down Expand Up @@ -214,14 +214,12 @@ def vesin_nl_ts(
References:
https://github.com/Luthaf/vesin
"""
from torch_sim.neighbors import _normalize_inputs

if VesinNeighborListTorch is None:
raise RuntimeError("vesin[torch] package is not installed")
device = positions.device
dtype = positions.dtype
n_systems = int(system_idx.max().item()) + 1
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
cell, pbc = normalize_inputs(cell, pbc, n_systems)

# Process each system's neighbor list separately
edge_indices = []
Expand Down
Loading