Skip to content

Commit 9db7177

Browse files
authored
Maint to restore CI to all passing (#533)
1 parent e57a99f commit 9db7177

9 files changed

Lines changed: 82 additions & 98 deletions

File tree

examples/benchmarking/neighborlists.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def parse_args() -> argparse.Namespace:
9898
parser.add_argument(
9999
"--n-repeats",
100100
type=int,
101-
default=3,
102-
help="Number of timed repetitions (median is reported).",
101+
default=7,
102+
help="Number of timed repetitions (trimmed mean excluding min/max is reported).",
103103
)
104104
parsed_args = parser.parse_args()
105105
if parsed_args.n_structures <= 0:
@@ -264,13 +264,15 @@ def _benchmark_backend(
264264
torch.cuda.synchronize()
265265
timings.append(time.perf_counter() - t0)
266266

267-
median_s = float(np.median(timings))
267+
sorted_timings = sorted(timings)
268+
trimmed = sorted_timings[1:-1] if len(sorted_timings) > 2 else sorted_timings
269+
trimmed_mean_s = float(np.mean(trimmed))
268270
return {
269271
"nl_backend": backend,
270272
"n_pairs": int(mapping.shape[1]),
271-
"median_nl_s": round(median_s, 6),
273+
"trimmed_mean_nl_s": round(trimmed_mean_s, 6),
272274
"timings_s": [round(t, 6) for t in timings],
273-
"atoms_per_s": round(n_atoms / median_s, 1) if median_s > 0 else 0,
275+
"atoms_per_s": round(n_atoms / trimmed_mean_s, 1) if trimmed_mean_s > 0 else 0,
274276
}
275277

276278

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ vesin = ["vesin[torch]>=0.5.3"]
5151
io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"]
5252
symmetry = ["moyopy>=0.7.8"]
5353
mace = ["mace-torch>=0.3.15"]
54-
mattersim = ["mattersim>=0.1.2"]
54+
mattersim = ["mattersim>=1.2.2"]
5555
metatomic = ["metatomic-torchsim>=0.1.1", "metatomic-ase>=0.1.0", "upet>=0.2.0"]
5656
orb = ["orb-models>=0.6.2"]
5757
sevenn = ["sevenn[torchsim]>=0.12.1"]
5858
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
59-
nequip = ["nequip>=0.17.0"]
59+
nequip = ["nequip>=0.17.1"]
6060
nequix = ["nequix[torch-sim]>=0.4.5"]
6161
fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"]
6262
docs = [

tests/models/test_nequip_framework.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,30 @@
2727
# Cache directory for compiled models (under tests/ for easy cleanup)
2828
NEQUIP_CACHE_DIR = Path(__file__).parent.parent / ".cache" / "nequip_compiled_models"
2929

30-
# Zenodo URL for NequIP-OAM-L model (more reliable than nequip.net for CI)
31-
NEQUIP_OAM_L_ZENODO_URL = (
32-
"https://zenodo.org/records/16980200/files/NequIP-OAM-L-0.1.nequip.zip?download=1"
30+
# Zenodo URL for NequIP-OAM-S model (more reliable than nequip.net for CI)
31+
NEQUIP_OAM_S_ZENODO_URL = (
32+
"https://zenodo.org/records/18775904/files/NequIP-OAM-S-0.1.nequip.zip?download=1"
3333
)
34-
NEQUIP_OAM_L_ZIP_NAME = "NequIP-OAM-L-0.1.nequip.zip"
34+
NEQUIP_OAM_S_ZIP_NAME = "NequIP-OAM-S-0.1.nequip.zip"
3535

3636

3737
def _get_nequip_model_zip() -> Path:
38-
"""Download NequIP-OAM-L model from Zenodo if not already cached."""
38+
"""Download NequIP-OAM-S model from Zenodo if not already cached."""
3939
NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True)
40-
zip_path = NEQUIP_CACHE_DIR / NEQUIP_OAM_L_ZIP_NAME
40+
zip_path = NEQUIP_CACHE_DIR / NEQUIP_OAM_S_ZIP_NAME
4141

4242
if not zip_path.exists():
43-
urllib.request.urlretrieve(NEQUIP_OAM_L_ZENODO_URL, zip_path) # noqa: S310
43+
urllib.request.urlretrieve(NEQUIP_OAM_S_ZENODO_URL, zip_path) # noqa: S310
4444

4545
return zip_path
4646

4747

4848
@pytest.fixture(scope="session")
4949
def compiled_ase_nequip_model_path() -> Path:
50-
"""Compile NequIP OAM-L model from Zenodo for ASE (with persistent caching)."""
50+
"""Compile NequIP OAM-S model from Zenodo for ASE (with persistent caching)."""
5151
NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True)
5252

53-
output_model_name = f"mir-group__NequIP-OAM-L__0.1__{DEVICE.type}_ase.nequip.pt2"
53+
output_model_name = f"mir-group__NequIP-OAM-S__0.1__{DEVICE.type}_ase.nequip.pt2"
5454
output_path = NEQUIP_CACHE_DIR / output_model_name
5555

5656
# Only compile if not already cached
@@ -74,10 +74,10 @@ def compiled_ase_nequip_model_path() -> Path:
7474

7575
@pytest.fixture(scope="session")
7676
def compiled_batch_nequip_model_path() -> Path:
77-
"""Compile NequIP OAM-L model from Zenodo for batch (with persistent caching)."""
77+
"""Compile NequIP OAM-S model from Zenodo for batch (with persistent caching)."""
7878
NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True)
7979

80-
output_model_name = f"mir-group__NequIP-OAM-L__0.1__{DEVICE.type}_batch.nequip.pt2"
80+
output_model_name = f"mir-group__NequIP-OAM-S__0.1__{DEVICE.type}_batch.nequip.pt2"
8181
output_path = NEQUIP_CACHE_DIR / output_model_name
8282

8383
# Only compile if not already cached

tests/models/test_orb.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,22 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator:
7272
energy_atol=5e-5,
7373
)
7474

75-
test_validate_conservative_model_outputs = make_validate_model_outputs_test(
76-
model_fixture_name="orbv3_conservative_inf_omat_model",
75+
test_validate_conservative_model_outputs = pytest.mark.xfail(
76+
reason=(
77+
"Upstream ORB conservative model incorrectly squeezes length-1 batch "
78+
"dimensions; see https://github.com/orbital-materials/orb-models/pull/158"
79+
),
80+
strict=False,
81+
)(
82+
make_validate_model_outputs_test(
83+
model_fixture_name="orbv3_conservative_inf_omat_model",
84+
)
7785
)
7886

79-
test_validate_direct_model_outputs = make_validate_model_outputs_test(
80-
model_fixture_name="orbv3_direct_20_omat_model",
81-
)
87+
test_validate_direct_model_outputs = pytest.mark.xfail(
88+
reason=(
89+
"Upstream ORB direct model shows batch-dependent leakage; "
90+
"see https://github.com/orbital-materials/orb-models/issues/159"
91+
),
92+
strict=False,
93+
)(make_validate_model_outputs_test(model_fixture_name="orbv3_direct_20_omat_model"))

torch_sim/neighbors/__init__.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,40 +30,6 @@
3030
)
3131

3232

33-
def _normalize_inputs(
34-
cell: torch.Tensor, pbc: torch.Tensor, n_systems: int
35-
) -> tuple[torch.Tensor, torch.Tensor]:
36-
"""Normalize cell and PBC tensors to standard batch format.
37-
38-
Handles multiple input formats:
39-
- cell: [3, 3], [n_systems, 3, 3], or [n_systems*3, 3]
40-
- pbc: [3], [n_systems, 3], or [n_systems*3]
41-
42-
Returns:
43-
(cell, pbc) normalized to ([n_systems, 3, 3], [n_systems, 3])
44-
Both tensors are guaranteed to be contiguous.
45-
"""
46-
# Normalize cell
47-
if cell.ndim == 2:
48-
if cell.shape[0] == 3:
49-
cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
50-
else:
51-
cell = cell.reshape(n_systems, 3, 3).contiguous()
52-
else:
53-
cell = cell.contiguous()
54-
55-
# Normalize PBC
56-
if pbc.ndim == 1:
57-
if pbc.shape[0] == 3:
58-
pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous()
59-
else:
60-
pbc = pbc.reshape(n_systems, 3).contiguous()
61-
else:
62-
pbc = pbc.contiguous()
63-
64-
return cell, pbc
65-
66-
6733
# Set default neighbor list based on what's available (priority order)
6834
if ALCHEMIOPS_AVAILABLE:
6935
# Alchemiops is fastest on NVIDIA GPUs

torch_sim/neighbors/alchemiops.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
import torch
1212

13+
from torch_sim.neighbors.utils import normalize_inputs
14+
1315

1416
_batch_naive_neighbor_list: object | None = None
1517
_batch_cell_list: object | None = None
@@ -64,13 +66,10 @@ def alchemiops_nl_n2(
6466
Returns:
6567
(mapping, system_mapping, shifts_idx)
6668
"""
67-
from torch_sim.neighbors import _normalize_inputs
68-
6969
r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff
7070
n_systems = int(system_idx.max().item()) + 1
71-
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
71+
cell, pbc = normalize_inputs(cell, pbc, n_systems)
7272

73-
# Call alchemiops neighbor list
7473
if _batch_naive_neighbor_list is None:
7574
raise RuntimeError("nvalchemiops neighbor list is unavailable")
7675
res = _batch_naive_neighbor_list(
@@ -138,11 +137,9 @@ def alchemiops_nl_cell_list(
138137
Returns:
139138
(mapping, system_mapping, shifts_idx)
140139
"""
141-
from torch_sim.neighbors import _normalize_inputs
142-
143140
r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff
144141
n_systems = int(system_idx.max().item()) + 1
145-
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
142+
cell, pbc = normalize_inputs(cell, pbc, n_systems)
146143

147144
# For non-periodic systems with zero cells, use a nominal identity cell
148145
# to avoid division by zero in alchemiops warp kernels

torch_sim/neighbors/torch_nl.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,32 +19,7 @@
1919
import torch
2020

2121
from torch_sim import transforms
22-
23-
24-
@torch.jit.script
25-
def _normalize_inputs_jit(
26-
cell: torch.Tensor, pbc: torch.Tensor, n_systems: int
27-
) -> tuple[torch.Tensor, torch.Tensor]:
28-
"""JIT-compatible input normalization for torch_nl functions."""
29-
# Normalize cell
30-
if cell.ndim == 2:
31-
if cell.shape[0] == 3:
32-
cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
33-
else:
34-
cell = cell.reshape(n_systems, 3, 3).contiguous()
35-
else:
36-
cell = cell.contiguous()
37-
38-
# Normalize PBC
39-
if pbc.ndim == 1:
40-
if pbc.shape[0] == 3:
41-
pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous()
42-
else:
43-
pbc = pbc.reshape(n_systems, 3).contiguous()
44-
else:
45-
pbc = pbc.contiguous()
46-
47-
return cell, pbc
22+
from torch_sim.neighbors.utils import normalize_inputs
4823

4924

5025
def strict_nl(
@@ -162,7 +137,7 @@ def torch_nl_n2(
162137
- https://github.com/venkatkapil24/batch_nl
163138
"""
164139
n_systems = system_idx.max().item() + 1
165-
cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems)
140+
cell, pbc = normalize_inputs(cell, pbc, n_systems)
166141
wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts(
167142
positions, cell, system_idx, pbc
168143
)
@@ -224,7 +199,7 @@ def torch_nl_linked_cell(
224199
- https://github.com/felixmusil/torch_nl
225200
"""
226201
n_systems = system_idx.max().item() + 1
227-
cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems)
202+
cell, pbc = normalize_inputs(cell, pbc, n_systems)
228203
wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts(
229204
positions, cell, system_idx, pbc
230205
)

torch_sim/neighbors/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Utilities for neighbor list calculations."""
2+
3+
import torch
4+
5+
6+
@torch.jit.script
7+
def normalize_inputs(
8+
cell: torch.Tensor, pbc: torch.Tensor, n_systems: int
9+
) -> tuple[torch.Tensor, torch.Tensor]:
10+
"""Normalize cell and PBC tensors to standard batch format.
11+
12+
Handles multiple input formats:
13+
- cell: [3, 3], [n_systems, 3, 3], or [n_systems*3, 3]
14+
- pbc: [3], [n_systems, 3], or [n_systems*3]
15+
16+
Returns:
17+
(cell, pbc) normalized to ([n_systems, 3, 3], [n_systems, 3])
18+
Both tensors are guaranteed to be contiguous.
19+
"""
20+
if cell.ndim == 2:
21+
if cell.shape[0] == 3:
22+
cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous()
23+
else:
24+
cell = cell.reshape(n_systems, 3, 3).contiguous()
25+
else:
26+
cell = cell.contiguous()
27+
if pbc.ndim == 1:
28+
if pbc.shape[0] == 3:
29+
pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous()
30+
else:
31+
pbc = pbc.reshape(n_systems, 3).contiguous()
32+
else:
33+
pbc = pbc.contiguous()
34+
return cell, pbc

torch_sim/neighbors/vesin.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import torch
1010

11+
from torch_sim.neighbors.utils import normalize_inputs
12+
1113

1214
try:
1315
from vesin import NeighborList as VesinNeighborList
@@ -70,16 +72,14 @@ def vesin_nl(
7072
References:
7173
- https://github.com/Luthaf/vesin
7274
"""
73-
from torch_sim.neighbors import _normalize_inputs
74-
7575
if VesinNeighborList is None:
7676
raise RuntimeError(
7777
"vesin is not installed. Install it with: [uv] pip install vesin"
7878
)
7979
device = positions.device
8080
dtype = positions.dtype
8181
n_systems = int(system_idx.max().item()) + 1
82-
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
82+
cell, pbc = normalize_inputs(cell, pbc, n_systems)
8383

8484
# Process each system's neighbor list separately
8585
edge_indices = []
@@ -214,14 +214,12 @@ def vesin_nl_ts(
214214
References:
215215
https://github.com/Luthaf/vesin
216216
"""
217-
from torch_sim.neighbors import _normalize_inputs
218-
219217
if VesinNeighborListTorch is None:
220218
raise RuntimeError("vesin[torch] package is not installed")
221219
device = positions.device
222220
dtype = positions.dtype
223221
n_systems = int(system_idx.max().item()) + 1
224-
cell, pbc = _normalize_inputs(cell, pbc, n_systems)
222+
cell, pbc = normalize_inputs(cell, pbc, n_systems)
225223

226224
# Process each system's neighbor list separately
227225
edge_indices = []

0 commit comments

Comments
 (0)