diff --git a/examples/benchmarking/neighborlists.py b/examples/benchmarking/neighborlists.py index 40e02dc8..7d6c6ee1 100644 --- a/examples/benchmarking/neighborlists.py +++ b/examples/benchmarking/neighborlists.py @@ -98,8 +98,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--n-repeats", type=int, - default=3, - help="Number of timed repetitions (median is reported).", + default=7, + help="Number of timed repetitions (trimmed mean excluding min/max is reported).", ) parsed_args = parser.parse_args() if parsed_args.n_structures <= 0: @@ -264,13 +264,15 @@ def _benchmark_backend( torch.cuda.synchronize() timings.append(time.perf_counter() - t0) - median_s = float(np.median(timings)) + sorted_timings = sorted(timings) + trimmed = sorted_timings[1:-1] if len(sorted_timings) > 2 else sorted_timings + trimmed_mean_s = float(np.mean(trimmed)) return { "nl_backend": backend, "n_pairs": int(mapping.shape[1]), - "median_nl_s": round(median_s, 6), + "trimmed_mean_nl_s": round(trimmed_mean_s, 6), "timings_s": [round(t, 6) for t in timings], - "atoms_per_s": round(n_atoms / median_s, 1) if median_s > 0 else 0, + "atoms_per_s": round(n_atoms / trimmed_mean_s, 1) if trimmed_mean_s > 0 else 0, } diff --git a/pyproject.toml b/pyproject.toml index 3f60469b..9ebc3752 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,12 +51,12 @@ vesin = ["vesin[torch]>=0.5.3"] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] symmetry = ["moyopy>=0.7.8"] mace = ["mace-torch>=0.3.15"] -mattersim = ["mattersim>=0.1.2"] +mattersim = ["mattersim>=1.2.2"] metatomic = ["metatomic-torchsim>=0.1.1", "metatomic-ase>=0.1.0", "upet>=0.2.0"] orb = ["orb-models>=0.6.2"] sevenn = ["sevenn[torchsim]>=0.12.1"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] -nequip = ["nequip>=0.17.0"] +nequip = ["nequip>=0.17.1"] nequix = ["nequix[torch-sim]>=0.4.5"] fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] docs = [ diff --git a/tests/models/test_nequip_framework.py b/tests/models/test_nequip_framework.py index 51f73200..e148dae3 100644 --- a/tests/models/test_nequip_framework.py +++ b/tests/models/test_nequip_framework.py @@ -27,30 +27,30 @@ # Cache directory for compiled models (under tests/ for easy cleanup) NEQUIP_CACHE_DIR = Path(__file__).parent.parent / ".cache" / "nequip_compiled_models" -# Zenodo URL for NequIP-OAM-L model (more reliable than nequip.net for CI) -NEQUIP_OAM_L_ZENODO_URL = ( - "https://zenodo.org/records/16980200/files/NequIP-OAM-L-0.1.nequip.zip?download=1" +# Zenodo URL for NequIP-OAM-S model (more reliable than nequip.net for CI) +NEQUIP_OAM_S_ZENODO_URL = ( + "https://zenodo.org/records/18775904/files/NequIP-OAM-S-0.1.nequip.zip?download=1" ) -NEQUIP_OAM_L_ZIP_NAME = "NequIP-OAM-L-0.1.nequip.zip" +NEQUIP_OAM_S_ZIP_NAME = "NequIP-OAM-S-0.1.nequip.zip" def _get_nequip_model_zip() -> Path: - """Download NequIP-OAM-L model from Zenodo if not already cached.""" + """Download NequIP-OAM-S model from Zenodo if not already cached.""" NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True) - zip_path = NEQUIP_CACHE_DIR / NEQUIP_OAM_L_ZIP_NAME + zip_path = NEQUIP_CACHE_DIR / NEQUIP_OAM_S_ZIP_NAME if not zip_path.exists(): - urllib.request.urlretrieve(NEQUIP_OAM_L_ZENODO_URL, zip_path) # noqa: S310 + urllib.request.urlretrieve(NEQUIP_OAM_S_ZENODO_URL, zip_path) # noqa: S310 return zip_path @pytest.fixture(scope="session") def compiled_ase_nequip_model_path() -> Path: - """Compile NequIP OAM-L model from Zenodo for ASE (with persistent caching).""" + """Compile NequIP OAM-S model from Zenodo for ASE (with persistent caching).""" NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True) - output_model_name = f"mir-group__NequIP-OAM-L__0.1__{DEVICE.type}_ase.nequip.pt2" + output_model_name = f"mir-group__NequIP-OAM-S__0.1__{DEVICE.type}_ase.nequip.pt2" output_path = NEQUIP_CACHE_DIR / output_model_name # Only compile if not already cached @@ -74,10 +74,10 @@ def compiled_ase_nequip_model_path() -> Path: @pytest.fixture(scope="session") def compiled_batch_nequip_model_path() -> Path: - """Compile NequIP OAM-L model from Zenodo for batch (with persistent caching).""" + """Compile NequIP OAM-S model from Zenodo for batch (with persistent caching).""" NEQUIP_CACHE_DIR.mkdir(parents=True, exist_ok=True) - output_model_name = f"mir-group__NequIP-OAM-L__0.1__{DEVICE.type}_batch.nequip.pt2" + output_model_name = f"mir-group__NequIP-OAM-S__0.1__{DEVICE.type}_batch.nequip.pt2" output_path = NEQUIP_CACHE_DIR / output_model_name # Only compile if not already cached diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 6bdf1376..de23cd17 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -72,10 +72,22 @@ def orbv3_direct_20_omat_calculator() -> ORBCalculator: energy_atol=5e-5, ) -test_validate_conservative_model_outputs = make_validate_model_outputs_test( - model_fixture_name="orbv3_conservative_inf_omat_model", +test_validate_conservative_model_outputs = pytest.mark.xfail( + reason=( + "Upstream ORB conservative model incorrectly squeezes length-1 batch " + "dimensions; see https://github.com/orbital-materials/orb-models/pull/158" + ), + strict=False, +)( + make_validate_model_outputs_test( + model_fixture_name="orbv3_conservative_inf_omat_model", + ) ) -test_validate_direct_model_outputs = make_validate_model_outputs_test( - model_fixture_name="orbv3_direct_20_omat_model", -) +test_validate_direct_model_outputs = pytest.mark.xfail( + reason=( + "Upstream ORB direct model shows batch-dependent leakage; " + "see https://github.com/orbital-materials/orb-models/issues/159" + ), + strict=False, +)(make_validate_model_outputs_test(model_fixture_name="orbv3_direct_20_omat_model")) diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index 10b5a1db..5d08c589 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -30,40 +30,6 @@ ) -def _normalize_inputs( - cell: torch.Tensor, pbc: torch.Tensor, n_systems: int -) -> tuple[torch.Tensor, torch.Tensor]: - """Normalize cell and PBC tensors to standard batch format. - - Handles multiple input formats: - - cell: [3, 3], [n_systems, 3, 3], or [n_systems*3, 3] - - pbc: [3], [n_systems, 3], or [n_systems*3] - - Returns: - (cell, pbc) normalized to ([n_systems, 3, 3], [n_systems, 3]) - Both tensors are guaranteed to be contiguous. - """ - # Normalize cell - if cell.ndim == 2: - if cell.shape[0] == 3: - cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous() - else: - cell = cell.reshape(n_systems, 3, 3).contiguous() - else: - cell = cell.contiguous() - - # Normalize PBC - if pbc.ndim == 1: - if pbc.shape[0] == 3: - pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous() - else: - pbc = pbc.reshape(n_systems, 3).contiguous() - else: - pbc = pbc.contiguous() - - return cell, pbc - - # Set default neighbor list based on what's available (priority order) if ALCHEMIOPS_AVAILABLE: # Alchemiops is fastest on NVIDIA GPUs diff --git a/torch_sim/neighbors/alchemiops.py b/torch_sim/neighbors/alchemiops.py index bb3dcc45..f9759c03 100644 --- a/torch_sim/neighbors/alchemiops.py +++ b/torch_sim/neighbors/alchemiops.py @@ -10,6 +10,8 @@ import torch +from torch_sim.neighbors.utils import normalize_inputs + _batch_naive_neighbor_list: object | None = None _batch_cell_list: object | None = None @@ -64,13 +66,10 @@ def alchemiops_nl_n2( Returns: (mapping, system_mapping, shifts_idx) """ - from torch_sim.neighbors import _normalize_inputs - r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff n_systems = int(system_idx.max().item()) + 1 - cell, pbc = _normalize_inputs(cell, pbc, n_systems) + cell, pbc = normalize_inputs(cell, pbc, n_systems) - # Call alchemiops neighbor list if _batch_naive_neighbor_list is None: raise RuntimeError("nvalchemiops neighbor list is unavailable") res = _batch_naive_neighbor_list( @@ -138,11 +137,9 @@ def alchemiops_nl_cell_list( Returns: (mapping, system_mapping, shifts_idx) """ - from torch_sim.neighbors import _normalize_inputs - r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff n_systems = int(system_idx.max().item()) + 1 - cell, pbc = _normalize_inputs(cell, pbc, n_systems) + cell, pbc = normalize_inputs(cell, pbc, n_systems) # For non-periodic systems with zero cells, use a nominal identity cell # to avoid division by zero in alchemiops warp kernels diff --git a/torch_sim/neighbors/torch_nl.py b/torch_sim/neighbors/torch_nl.py index 687861e2..50bf94c7 100644 --- a/torch_sim/neighbors/torch_nl.py +++ b/torch_sim/neighbors/torch_nl.py @@ -19,32 +19,7 @@ import torch from torch_sim import transforms - - -@torch.jit.script -def _normalize_inputs_jit( - cell: torch.Tensor, pbc: torch.Tensor, n_systems: int -) -> tuple[torch.Tensor, torch.Tensor]: - """JIT-compatible input normalization for torch_nl functions.""" - # Normalize cell - if cell.ndim == 2: - if cell.shape[0] == 3: - cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous() - else: - cell = cell.reshape(n_systems, 3, 3).contiguous() - else: - cell = cell.contiguous() - - # Normalize PBC - if pbc.ndim == 1: - if pbc.shape[0] == 3: - pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous() - else: - pbc = pbc.reshape(n_systems, 3).contiguous() - else: - pbc = pbc.contiguous() - - return cell, pbc +from torch_sim.neighbors.utils import normalize_inputs def strict_nl( @@ -162,7 +137,7 @@ def torch_nl_n2( - https://github.com/venkatkapil24/batch_nl """ n_systems = system_idx.max().item() + 1 - cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) + cell, pbc = normalize_inputs(cell, pbc, n_systems) wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts( positions, cell, system_idx, pbc ) @@ -224,7 +199,7 @@ def torch_nl_linked_cell( - https://github.com/felixmusil/torch_nl """ n_systems = system_idx.max().item() + 1 - cell, pbc = _normalize_inputs_jit(cell, pbc, n_systems) + cell, pbc = normalize_inputs(cell, pbc, n_systems) wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts( positions, cell, system_idx, pbc ) diff --git a/torch_sim/neighbors/utils.py b/torch_sim/neighbors/utils.py new file mode 100644 index 00000000..8b91bda1 --- /dev/null +++ b/torch_sim/neighbors/utils.py @@ -0,0 +1,34 @@ +"""Utilities for neighbor list calculations.""" + +import torch + + +@torch.jit.script +def normalize_inputs( + cell: torch.Tensor, pbc: torch.Tensor, n_systems: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Normalize cell and PBC tensors to standard batch format. + + Handles multiple input formats: + - cell: [3, 3], [n_systems, 3, 3], or [n_systems*3, 3] + - pbc: [3], [n_systems, 3], or [n_systems*3] + + Returns: + (cell, pbc) normalized to ([n_systems, 3, 3], [n_systems, 3]) + Both tensors are guaranteed to be contiguous. + """ + if cell.ndim == 2: + if cell.shape[0] == 3: + cell = cell.unsqueeze(0).expand(n_systems, -1, -1).contiguous() + else: + cell = cell.reshape(n_systems, 3, 3).contiguous() + else: + cell = cell.contiguous() + if pbc.ndim == 1: + if pbc.shape[0] == 3: + pbc = pbc.unsqueeze(0).expand(n_systems, -1).contiguous() + else: + pbc = pbc.reshape(n_systems, 3).contiguous() + else: + pbc = pbc.contiguous() + return cell, pbc diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index 009fe9bb..8d58d78d 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -8,6 +8,8 @@ import torch +from torch_sim.neighbors.utils import normalize_inputs + try: from vesin import NeighborList as VesinNeighborList @@ -70,8 +72,6 @@ def vesin_nl( References: - https://github.com/Luthaf/vesin """ - from torch_sim.neighbors import _normalize_inputs - if VesinNeighborList is None: raise RuntimeError( "vesin is not installed. Install it with: [uv] pip install vesin" @@ -79,7 +79,7 @@ def vesin_nl( device = positions.device dtype = positions.dtype n_systems = int(system_idx.max().item()) + 1 - cell, pbc = _normalize_inputs(cell, pbc, n_systems) + cell, pbc = normalize_inputs(cell, pbc, n_systems) # Process each system's neighbor list separately edge_indices = [] @@ -214,14 +214,12 @@ def vesin_nl_ts( References: https://github.com/Luthaf/vesin """ - from torch_sim.neighbors import _normalize_inputs - if VesinNeighborListTorch is None: raise RuntimeError("vesin[torch] package is not installed") device = positions.device dtype = positions.dtype n_systems = int(system_idx.max().item()) + 1 - cell, pbc = _normalize_inputs(cell, pbc, n_systems) + cell, pbc = normalize_inputs(cell, pbc, n_systems) # Process each system's neighbor list separately edge_indices = []