diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 23b583ab..7f3402b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -159,7 +159,7 @@ jobs: if [[ "${{ matrix.example }}" == *"fairchem"* ]]; then uv pip install huggingface_hub --system fi - uv run --with . ${{ matrix.example }} + uv run --with-editable . ${{ matrix.example }} all-required-pass: if: always() diff --git a/README.md b/README.md index ea4abcb7..70ffba22 100644 --- a/README.md +++ b/README.md @@ -157,10 +157,10 @@ If you use TorchSim in your research, please cite our [publication](https://iops ## Due Credit -We aim to recognize all [duecredit](https://github.com/duecredit/duecredit) for the decades of work that TorchSim builds on top of, an automated list of references can be obtained for the package by running `DUECREDIT_ENABLE=yes uv run --with . --extra docs --extra test python -m duecredit <(printf 'import pytest\nraise SystemExit(pytest.main(["-q"]))\n')`. This list is incomplete and we welcome PRs to help improve our citation coverage. +We aim to recognize all [duecredit](https://github.com/duecredit/duecredit) for the decades of work that TorchSim builds on top of, an automated list of references can be obtained for the package by running `DUECREDIT_ENABLE=yes uv run --with-editable . --extra docs --extra test python -m duecredit <(printf 'import pytest\nraise SystemExit(pytest.main(["-q"]))\n')`. This list is incomplete and we welcome PRs to help improve our citation coverage. To collect citations for a specific tutorial run, for example autobatching, use: ```sh -DUECREDIT_ENABLE=yes uv run --with . --extra docs --extra test python -m duecredit examples/tutorials/autobatching_tutorial.py +DUECREDIT_ENABLE=yes uv run --with-editable . 
--extra docs --extra test python -m duecredit examples/tutorials/autobatching_tutorial.py ``` diff --git a/examples/benchmarking/neighborlists.py b/examples/benchmarking/neighborlists.py index 40e02dc8..6341ee17 100644 --- a/examples/benchmarking/neighborlists.py +++ b/examples/benchmarking/neighborlists.py @@ -14,7 +14,7 @@ Directly times each torch-sim NL backend without any model evaluation. Example: - uv run --with . examples/benchmarking/neighborlists.py \ + uv run --with-editable . examples/benchmarking/neighborlists.py \ --source wbm --n-structures 100 --device cpu """ diff --git a/examples/readme.md b/examples/readme.md index 25ce0ab4..caffc439 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -35,25 +35,25 @@ If you'd like to execute the scripts or examples locally, you can run them with: curl -LsSf https://astral.sh/uv/install.sh | sh # pick any of the examples -uv run --with . examples/scripts/1_introduction.py -uv run --with . examples/scripts/2_structural_optimization.py -uv run --with . examples/scripts/3_dynamics.py -uv run --with . examples/scripts/4_high_level_api.py +uv run --with-editable . examples/scripts/1_introduction.py +uv run --with-editable . examples/scripts/2_structural_optimization.py +uv run --with-editable . examples/scripts/3_dynamics.py +uv run --with-editable . examples/scripts/4_high_level_api.py # or any of the tutorials -uv run --with . examples/tutorials/diff_sim.py +uv run --with-editable . examples/tutorials/diff_sim.py ``` ## Benchmarking Scripts The `examples/benchmarking/` folder contains standalone benchmark scripts. 
They declare their own dependencies via [PEP 723 inline script metadata](https://peps.python.org/pep-0723/) -and should be run with `uv run --with .` so that the local `torch-sim` package +and should be run with `uv run --with-editable .` so that the local `torch-sim` package is available alongside the script's isolated dependency environment: ```sh # Neighbor-list backend benchmark on WBM or MP structures -uv run --with . examples/benchmarking/neighborlists.py \ +uv run --with-editable . examples/benchmarking/neighborlists.py \ --source wbm --n-structures 100 --device cpu # Scaling benchmark: static, relax, NVE, NVT diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index bdc41f1e..9e9a2429 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -598,6 +598,23 @@ def _calculate_n2_lattice_shifts( ) # (n_shifts, 3) +def _pad_batched_positions( + positions: torch.Tensor, + n_atoms: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pad flat per-system positions and return the atom mask and offsets.""" + device = positions.device + n_systems = n_atoms.shape[0] + n_max = int(n_atoms.max().item()) + offsets = torch.zeros(n_systems, dtype=torch.long, device=device) + if n_systems > 1: + offsets[1:] = torch.cumsum(n_atoms[:-1], dim=0) + pos_list = [positions[offsets[i] : offsets[i] + n_atoms[i]] for i in range(n_systems)] + batch_positions = pad_sequence(pos_list, batch_first=True, padding_value=0.0) + atom_mask = torch.arange(n_max, device=device).unsqueeze(0) < n_atoms.unsqueeze(1) + return batch_positions, atom_mask, offsets + + def build_naive_neighborhood( positions: torch.Tensor, cell: torch.Tensor, @@ -645,21 +662,10 @@ def build_naive_neighborhood( """ device = positions.device dtype = positions.dtype - n_systems = n_atoms.shape[0] - n_max = int(n_atoms.max().item()) cell = cell.view(-1, 3, 3) pbc = pbc.view(-1, 3).to(torch.bool) - - # --- pad positions into (n_systems, n_max, 3) --- - offsets = 
torch.zeros(n_systems, dtype=torch.long, device=device) - offsets[1:] = torch.cumsum(n_atoms[:-1], dim=0) - - # split flat positions into per-system tensors, then pad - pos_list = [positions[offsets[i] : offsets[i] + n_atoms[i]] for i in range(n_systems)] - batch_positions = pad_sequence(pos_list, batch_first=True, padding_value=0.0) - # mask: True for real atoms, False for padding - batch_mask = torch.arange(n_max, device=device).unsqueeze(0) < n_atoms.unsqueeze(1) + batch_positions, batch_mask, offsets = _pad_batched_positions(positions, n_atoms) # --- compute lattice shifts --- lattice_shifts = _calculate_n2_lattice_shifts(cell, pbc, cutoff) # (n_shifts, 3) @@ -732,16 +738,17 @@ def ravel_3d(idx_3d: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: elements in a flattened representation. Args: - idx_3d (torch.Tensor): A tensor of shape [-1, 3] + idx_3d (torch.Tensor): A tensor of shape [..., 3] representing the 3D indices to be converted. shape (torch.Tensor): A tensor of shape [3] representing the dimensions of the array. Returns: - torch.Tensor: A tensor containing the linear indices + torch.Tensor: A tensor of shape [...] + containing the linear indices corresponding to the input 3D indices. """ - return idx_3d[:, 2] + shape[2] * (idx_3d[:, 1] + shape[1] * idx_3d[:, 0]) + return idx_3d[..., 2] + shape[2] * (idx_3d[..., 1] + shape[1] * idx_3d[..., 0]) def unravel_3d(idx_linear: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: @@ -752,22 +759,19 @@ def unravel_3d(idx_linear: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: The conversion is based on the provided shape of the array. Args: - idx_linear (torch.Tensor): A tensor of shape [-1] + idx_linear (torch.Tensor): A tensor of shape [...] representing the linear indices to be converted. shape (torch.Tensor): A tensor of shape [3] representing the dimensions of the array. 
Returns: - torch.Tensor: A tensor of shape [-1, 3] + torch.Tensor: A tensor of shape [..., 3] containing the 3D indices corresponding to the input linear indices. """ - idx_3d = idx_linear.new_empty((idx_linear.shape[0], 3)) - idx_3d[:, 2] = torch.remainder(idx_linear, shape[2]) - idx_3d[:, 1] = torch.remainder( - torch.div(idx_linear, shape[2], rounding_mode="floor"), shape[1] - ) - idx_3d[:, 0] = torch.div(idx_linear, shape[1] * shape[2], rounding_mode="floor") - return idx_3d + z = torch.remainder(idx_linear, shape[2]) + y = torch.remainder(torch.div(idx_linear, shape[2], rounding_mode="floor"), shape[1]) + x = torch.div(idx_linear, shape[1] * shape[2], rounding_mode="floor") + return torch.stack([x, y, z], dim=-1) def get_linear_bin_idx( @@ -999,7 +1003,7 @@ def linked_cell( # noqa: PLR0915 return neigh_atom, neigh_shift_idx -def build_linked_cell_neighborhood( +def build_linked_cell_neighborhood_serial( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, @@ -1077,6 +1081,295 @@ def build_linked_cell_neighborhood( ) +def _neighbor_bin_shifts_3d(device: torch.device) -> torch.Tensor: + """Return the 27 neighboring 3D bin offsets.""" + dd = torch.tensor([0, 1, -1], dtype=torch.long, device=device) + return torch.cartesian_prod(dd, dd, dd) + + +def _within_bin_position(sorted_bin: torch.Tensor) -> torch.Tensor: + """Compute within-bin position for a sorted bin index tensor. + + Given (S, N) of sorted bin indices, returns (S, N) where each element + is the 0-based position of that atom within its bin (0 for first atom + in the bin, 1 for second, etc.). + + Uses a vectorized segment-cumsum: cumsum of ones, minus the cumulative + count at each group boundary broadcast to the group. 
+ """ + S, N = sorted_bin.shape + device = sorted_bin.device + same = torch.ones(S, N, dtype=torch.long, device=device) + same[:, 1:] = (sorted_bin[:, 1:] == sorted_bin[:, :-1]).long() + cum = torch.cumsum(same, dim=1) # (S, N) — running count within each row + boundary = same == 0 # True at group starts (except position 0) + boundary_val = torch.where(boundary, cum - 1, torch.zeros_like(cum)) + correction = torch.cummax(boundary_val, dim=1).values + return cum - correction - 1 + + +def _build_linked_cell_images_batched( + batch_pos: torch.Tensor, + atom_mask: torch.Tensor, + shifts_idx_unique: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build shifted images and validity mask for the batched linked-cell path.""" + n_structure, n_max = batch_pos.shape[:2] + n_shifts = shifts_idx_unique.shape[0] + cart_shifts = torch.matmul( + shifts_idx_unique.to(batch_pos.dtype), cell + ) # (n_systems, n_shifts, 3) + shift_ok = ((shifts_idx_unique == 0).unsqueeze(0) | pbc.unsqueeze(1)).all(dim=-1) + # `shift_ok` is (n_systems, n_shifts): valid shared lattice shifts per system. + # Flatten the (shift, atom) grid into one image axis for later sorting/binning. 
+ images_flat = (batch_pos.unsqueeze(1) + cart_shifts.unsqueeze(2)).reshape( + n_structure, n_shifts * n_max, 3 + ) # (n_systems, n_shifts * max_atoms, 3) + image_valid = ( + (atom_mask.unsqueeze(1) & shift_ok.unsqueeze(-1)) + .expand(-1, n_shifts, -1) + .reshape(n_structure, n_shifts * n_max) + ) # (n_systems, n_shifts * max_atoms) + return cart_shifts, images_flat, image_valid, shift_ok + + +def _bin_linked_cell_images_batched( + batch_pos: torch.Tensor, + atom_mask: torch.Tensor, + images_flat: torch.Tensor, + image_valid: torch.Tensor, + cart_shifts: torch.Tensor, + shift_ok: torch.Tensor, + cutoff: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Assign shifted images to bins and build the per-bin image lookup table.""" + device = batch_pos.device + dtype = batch_pos.dtype + n_structure = batch_pos.shape[0] + n_img = images_flat.shape[1] + big_val = torch.finfo(dtype).max / 2 + pos_for_min = torch.where( + atom_mask.unsqueeze(-1), batch_pos, torch.full_like(batch_pos, big_val) + ) + pos_for_max = torch.where( + atom_mask.unsqueeze(-1), batch_pos, torch.full_like(batch_pos, -big_val) + ) + shift_for_min = torch.where( + shift_ok.unsqueeze(-1), cart_shifts, torch.full_like(cart_shifts, big_val) + ) + shift_for_max = torch.where( + shift_ok.unsqueeze(-1), cart_shifts, torch.full_like(cart_shifts, -big_val) + ) + b_min = pos_for_min.min(dim=1).values + shift_for_min.min(dim=1).values + b_max = pos_for_max.max(dim=1).values + shift_for_max.max(dim=1).values + # Rebuild the same cutoff-sized box construction used in the single-structure path. 
+ images_shifted = images_flat - b_min.unsqueeze(1) + 1e-5 + box_length = b_max - b_min + 1e-3 + n_bins_s_per_sys = torch.maximum( + torch.ceil(box_length / cutoff), + torch.ones(n_structure, 3, device=device, dtype=dtype), + ).to(torch.long) # (n_systems, 3) + n_bins_s = n_bins_s_per_sys.max(dim=0).values # (3,) + n_bins = int(n_bins_s.prod().item()) + box_diag_per_sys = n_bins_s_per_sys.to(dtype) * cutoff + scaled_pos = images_shifted / box_diag_per_sys.unsqueeze(1) + bin_3d = torch.floor(scaled_pos * n_bins_s_per_sys.to(dtype).unsqueeze(1)).to( + torch.long + ) + bin_3d = bin_3d.clamp( + min=torch.zeros(3, device=device, dtype=torch.long), + max=(n_bins_s - 1), + ) + bin_linear = ravel_3d(bin_3d, n_bins_s) # (n_systems, n_shifts * max_atoms) + # Sort-by-bin lets us scatter images into a dense (bin, slot) lookup table. + safe_bin = torch.where(image_valid, bin_linear, torch.full_like(bin_linear, n_bins)) + sorted_bin, sorted_order = torch.sort(safe_bin, dim=1) + sorted_valid = image_valid.gather(1, sorted_order) + sorted_bin_clamped = torch.where( + sorted_valid, sorted_bin, torch.zeros_like(sorted_bin) + ) + within_pos = _within_bin_position(sorted_bin_clamped) + within_pos = torch.where(sorted_valid, within_pos, torch.zeros_like(within_pos)) + counts = torch.zeros(n_structure, n_bins, device=device, dtype=torch.long) + counts.scatter_add_(1, sorted_bin_clamped, sorted_valid.long()) + max_apb = max(int(counts.max().item()), 1) + sentinel = n_img + bin_id_j = torch.full( + (n_structure, n_bins * max_apb + 1), sentinel, dtype=torch.long, device=device + ) + flat_target = sorted_bin_clamped * max_apb + within_pos + scatter_mask = sorted_valid & (within_pos < max_apb) + trash_idx = n_bins * max_apb + safe_target = torch.where( + scatter_mask, flat_target, torch.full_like(flat_target, trash_idx) + ) + src_vals = torch.where( + scatter_mask, sorted_order, torch.full_like(sorted_order, sentinel) + ) + bin_id_j.scatter_(1, safe_target, src_vals) + bin_id_j = bin_id_j[:, 
: n_bins * max_apb].view( + n_structure, n_bins, max_apb + ) # (n_systems, n_bins, max_atoms_per_bin) + return bin_linear, bin_id_j, n_bins_s + + +def _gather_linked_cell_candidates_batched( + bin_linear: torch.Tensor, + bin_id_j: torch.Tensor, + atom_mask: torch.Tensor, + n_bins_s: torch.Tensor, + shifts_idx_unique: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Gather per-atom candidate neighbor images from surrounding bins.""" + device = bin_linear.device + n_structure, n_max = atom_mask.shape + max_apb = bin_id_j.shape[2] + sentinel = bin_linear.shape[1] + zero_shift_idx = (shifts_idx_unique == 0).all(dim=1).nonzero(as_tuple=True)[0][0] + orig_start = int(zero_shift_idx.item()) * n_max + bin_index_i = bin_linear[:, orig_start : orig_start + n_max] + bin_shifts_27 = _neighbor_bin_shifts_3d(device) + n_nb = bin_shifts_27.shape[0] + # Each central atom only needs images from its own bin and the 26 adjacent bins. + i_bins_3d = unravel_3d(bin_index_i, n_bins_s) + neigh_bins_3d = i_bins_3d.unsqueeze(2) + bin_shifts_27.view( + 1, 1, n_nb, 3 + ) # (n_systems, max_atoms, 27, 3) + neigh_ok = ((neigh_bins_3d >= 0) & (neigh_bins_3d < n_bins_s.view(1, 1, 1, 3))).all( + dim=-1 + ) & atom_mask.unsqueeze(2) + neigh_bins_lin = ravel_3d( + neigh_bins_3d.clamp( + min=torch.zeros(3, device=device, dtype=torch.long), + max=(n_bins_s - 1), + ), + n_bins_s, + ) + gather_idx = ( + neigh_bins_lin.reshape(n_structure, -1).unsqueeze(-1).expand(-1, -1, max_apb) + ) # (n_systems, max_atoms * 27, max_atoms_per_bin) + candidates = bin_id_j.gather(1, gather_idx).reshape( + n_structure, n_max, n_nb, max_apb + ) # (n_systems, max_atoms, 27, max_atoms_per_bin) + candidate_valid = neigh_ok.unsqueeze(-1) & ( + candidates != sentinel + ) # (n_systems, max_atoms, 27, max_atoms_per_bin) + return ( + candidates.reshape(n_structure, n_max, n_nb * max_apb), + candidate_valid.reshape(n_structure, n_max, n_nb * max_apb), + orig_start, + ) + + +def _finalize_linked_cell_pairs_batched( + 
candidates: torch.Tensor, + candidate_valid: torch.Tensor, + offsets: torch.Tensor, + shifts_idx_unique: torch.Tensor, + orig_start: int, + *, + self_interaction: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert candidate image indices into the final neighbor list outputs.""" + device = candidates.device + _, n_max, _ = candidates.shape + pair_valid = candidate_valid + if not self_interaction: + local_image_idx = torch.arange(n_max, device=device).view(1, -1, 1) + orig_start + pair_valid = pair_valid & (candidates != local_image_idx) + # Compact the valid (system, i, candidate-slot) triples once, then index directly. + s_flat, i_flat, k_flat = pair_valid.nonzero(as_tuple=True) + j_flat = candidates[s_flat, i_flat, k_flat] + j_atom = j_flat % n_max + shift_out = shifts_idx_unique[j_flat // n_max] + mapping = torch.stack([i_flat + offsets[s_flat], j_atom + offsets[s_flat]], dim=0) + sort_idx = torch.argsort(mapping[0]) + return mapping[:, sort_idx], s_flat[sort_idx], shift_out[sort_idx] + + +def build_linked_cell_neighborhood_batched( + positions: torch.Tensor, + cell: torch.Tensor, + pbc: torch.Tensor, + cutoff: float, + n_atoms: torch.Tensor, + self_interaction: bool = False, # noqa: FBT001, FBT002 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fully batched linked-cell neighbor list construction. + + Drop-in replacement for build_linked_cell_neighborhood that processes all + structures simultaneously using padded tensors, eliminating Python for-loops + over structures. The algorithm mirrors the single-structure ``linked_cell`` + but operates on all structures at once via padded 3D/4D tensors. + + Args: + positions: Flat atomic positions (N_total, 3). + cell: Unit cell matrices, broadcastable to (n_structure, 3, 3). + pbc: PBC flags, broadcastable to (n_structure, 3). + cutoff: Neighbor cutoff distance. + n_atoms: Number of atoms per structure (n_structure,). + self_interaction: Whether to include self-pairs. 
def build_linked_cell_neighborhood(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    n_atoms: torch.Tensor,
    self_interaction: bool = False,  # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a linked-cell neighbor list, choosing the serial or batched backend.

    A single structure is handled by ``build_linked_cell_neighborhood_serial``;
    two or more structures go through ``build_linked_cell_neighborhood_batched``.
    All arguments are forwarded unchanged.

    Args:
        positions: Flat atomic positions (N_total, 3).
        cell: Unit cell matrices, broadcastable to (n_structure, 3, 3).
        pbc: PBC flags, broadcastable to (n_structure, 3).
        cutoff: Neighbor cutoff distance.
        n_atoms: Number of atoms per structure (n_structure,).
        self_interaction: Whether to include self-pairs.

    Returns:
        (mapping, system_mapping, cell_shifts_idx) as returned by the
        selected backend.
    """
    # Count structures from `n_atoms`, not `cell.shape[0]`: `cell` may arrive
    # as a bare (3, 3) matrix or as a (1, 3, 3) tensor meant to be broadcast,
    # so its leading dimension does not reliably give the structure count.
    n_structure = n_atoms.shape[0]
    builder = (
        build_linked_cell_neighborhood_serial
        if n_structure == 1
        else build_linked_cell_neighborhood_batched
    )
    return builder(positions, cell, pbc, cutoff, n_atoms, self_interaction)