Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/tests_and_linters.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ name: Tests and Linters 🧪

on:
push:
branches:
- 'main'
- 'new-release*'
branches: [main]
pull_request:

jobs:
linters:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ repos:
- flake8-cognitive-complexity

- repo: https://github.com/compilerla/conventional-pre-commit
rev: v2.3.0
rev: v4.3.0
hooks:
- id: conventional-pre-commit
name: "Commit linter"
Expand Down
9 changes: 9 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Changelog

## Release 0.1.7

- Fixing issues with Periodic Boundary Conditions (PBCs) during inference.
- Supporting PBCs passed from `ase.Atoms` during simulation with the ASE engine.
Passing an orthorhombic box from configuration is still supported in both simulation
engines, but might become discouraged in future releases.
- Fixing a few bugs related to batched simulations occurring in cases of
reallocation of neighbor lists.

## Release 0.1.6

- Fixing incorrect instructions for GPU-compatible installation: most shells require
Expand Down
30 changes: 25 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# 🪩 MLIP: Machine Learning Interatomic Potentials 🚀
# 🪩 MLIP: Machine Learning Interatomic Potentials

[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)
[![Python 3.11](https://img.shields.io/badge/python-3.11%20%7C%203.12%20%7C%203.13-blue)](https://www.python.org/downloads/release/python-3110/)
Expand Down Expand Up @@ -28,11 +28,13 @@ experienced with MLIP and JAX, and (3) a focus on **high inference speeds** that
running long MD simulations on large systems which we believe is necessary in order to
bring MLIP to large-scale industrial application.
See our [inference speed benchmark](#-inference-time-benchmarks) below.
With our library, we observe a 10x speedup on 138 atoms and up to 4x speed up
on 1205 atoms over equivalent implementations relying on Torch and ASE.

See the [Installation](#-installation) section for details on how to install
MLIP-JAX and the example Google Colab notebooks linked below for a quick way
🎙️ For further information on the design principles and story behind the *mlip* library,
also check out our [Let's Talk Research podcast episode](https://youtu.be/xsCclme6RmY)
on the topic.

See the [Installation](#-installation) section for details on how to install *mlip* and the
example Jupyter notebooks linked below for a quick way
to get started. For detailed instructions, visit our extensive
[code documentation](https://instadeepai.github.io/mlip/).

Expand Down Expand Up @@ -167,3 +169,21 @@ S. Attias, M. Maarand, Y. Khanfir, E. Toledo, F. Falcioni, M. Bluntzer,
S. Acosta-Gutiérrez and J. Tilly, *Machine Learning Interatomic Potentials:
library for efficient training, model development and simulation of molecular systems*,
arXiv, 2025, arXiv:2505.22397.

The BibTeX formatted citation:

```
@misc{brunken2025mlip,
title={Machine Learning Interatomic Potentials: library for efficient training,
model development and simulation of molecular systems},
author={Christoph Brunken and Olivier Peltre and Heloise Chomet and
Lucien Walewski and Manus McAuliffe and Valentin Heyraud and Solal Attias
and Martin Maarand and Yessine Khanfir and Edan Toledo and Fabio Falcioni
and Marie Bluntzer and Silvia Acosta-Gutiérrez and Jules Tilly},
year={2025},
eprint={2505.22397},
archivePrefix={arXiv},
primaryClass={physics.chem-ph},
url={https://arxiv.org/abs/2505.22397},
}
```
16 changes: 16 additions & 0 deletions docs/source/user_guide/simulations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,19 @@ that allows to input a list of `ase.Atoms` objects and returns a list of
# Example: Get energy and forces for 7-th structure (indexing starts at 0)
energy = predictions[7].energy
forces = predictions[7].forces


Periodic Boundary Conditions
----------------------------

Generic periodic boundary conditions (PBCs) are currently only supported by
the :py:class:`ASESimulationEngine <mlip.simulation.ase.ase_simulation_engine.ASESimulationEngine>`,
which are read from the `cell` attribute of the `ase.Atoms` to be simulated.

Orthorhombic PBCs (90° angles) can otherwise be specified for both simulation engines
via the `box` attribute of
:py:class:`SimulationConfig <mlip.simulation.configs.simulation_config.SimulationConfig>`,
which can either be `None`, a float, or a list of three floats.
This is currently due to a limitation of
`jax_md.space.periodic <https://jax-md.readthedocs.io/en/main/jax_md.space.html#jax_md.space.periodic>`_,
but we may support non-orthorhombic lattices with Jax-MD too in future versions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mlip"
version = "0.1.6"
version = "0.1.7"
description = "Machine Learning Interatomic Potentials in JAX"
license-files = [
"LICENSE"
Expand Down
2 changes: 2 additions & 0 deletions src/mlip/inference/batched_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def _prepare_graphs(
atomic_numbers=atoms.numbers,
atomic_species=np.asarray([z_table.z_to_index(z) for z in atoms.numbers]),
positions=atoms.get_positions(),
cell=np.array(atoms.get_cell()),
pbc=atoms.pbc,
)
for atoms in structures
]
Expand Down
6 changes: 6 additions & 0 deletions src/mlip/simulation/ase/ase_simulation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def _initialize(
logger.debug("Initialization of simulation completed.")

def _init_box(self) -> None:
"""Update the PBC parameters of the underlying `ase.Atoms`"""
# Pass if atoms already have PBC and cell, best source of truth
if self.atoms.cell is not None and self.atoms.pbc is not None:
return
# Support cubic periodic box from config for Jax-MD consistency.
# To be discouraged once both engines support arbitrary lattices.
if isinstance(self._config.box, float):
self.atoms.cell = np.eye(3) * self._config.box
self.atoms.pbc = True
Expand Down
6 changes: 6 additions & 0 deletions src/mlip/simulation/configs/ase_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ class ASESimulationConfig(SimulationConfig):
w.r.t. the sum of the force norms. See the
ASE docs for more information. If not set,
the ASE default will be used.

Note:
This simulation engine supports generic PBCs and lattice parameters read from
the `cell` attribute of the `ase.Atoms` to be simulated. Setting the `box`
parameter from configuration is discouraged in this case, but kept for
consistency with `JaxMDSimulationEngine` engine for now.
"""

log_interval: PositiveInt | None = None
Expand Down
8 changes: 5 additions & 3 deletions src/mlip/simulation/configs/simulation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ class SimulationConfig(pydantic.BaseModel):
state. This means information about every N-th
snapshot is stored in the simulation state available to the
loggers (N being the snapshot interval). Defaults to 1.
box: The simulation box. If ``None``, no periodic boundary conditions are
applied (this is the default). It can be set to either a float or a list
of three floats, describing the dimensions of the box.
box: The optional simulation box. If `None`, no periodic boundary conditions
(PBCs) are applied. It can be set to either a float or a list three floats,
to enforce orthorhombic PBCs. Note that the `ASESimulationEngine` supports
generic PBCs by reading the `cell` attribute of the `ase.Atoms` to
simulate, in which case the `box` parameter should not be used.
edge_capacity_multiplier: Factor to multiply the number of edges by to
obtain the edge capacity including padding. Defaults
to 1.25.
Expand Down
44 changes: 41 additions & 3 deletions src/mlip/simulation/jax_md/jax_md_simulation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def _initialize(

logger.debug("Initialization of simulation begins...")
self._config = config
self._atoms = atoms
self._force_field = force_field

positions = tree_map(lambda a: a.get_positions(), atoms)
self._num_atoms = tree_map(lambda p: p.shape[0], positions)
if isinstance(self._num_atoms, list) and 0 in self._num_atoms:
Expand Down Expand Up @@ -233,14 +236,24 @@ def _reallocate_neighbors(self) -> None:
lambda s, n: s.set(neighbors=n),
self._internal_state.system_state,
new_neighbors,
is_leaf=lambda x: is_system_state(x) or is_neighbor_list(x),
)
)
self._update_base_graph_in_pure_sim_step_fun(new_neighbors)
logger.debug("Reallocation of neighbor lists completed.")

def _init_box(self) -> None:
if self._config.box is None:
# TODO: test jax_md.periodic_general() for arbitrary lattices. For now, we
# check that the ase.Atoms do not have PBCs or cell, since Jax-MD only
# supports orthorhombic boxes that are passed from config.
has_pbc = (
any(atoms.pbc.any() for atoms in self._atoms)
if isinstance(self._atoms, list)
else self._atoms.pbc.any()
)
if self._config.box is None and not has_pbc:
self._displacement_fun, self._shift_fun = jax_md.space.free()
else:
elif self._config.box is not None:
box = (
np.array(self._config.box)
if isinstance(self._config.box, list)
Expand All @@ -249,6 +262,11 @@ def _init_box(self) -> None:
self._displacement_fun, self._shift_fun = jax_md.space.periodic(
box, wrapped=False
)
else:
raise NotImplementedError(
"Jax-MD can only be used with cubic boxes passed from config for now. "
"To avoid this error, you can set atoms.pbc to False."
)

@staticmethod
def _get_model_calculate_fun(
Expand Down Expand Up @@ -281,7 +299,7 @@ def calc_func(

forces_split_idx = None
if is_batched_sim:
sizes = np.delete(graph.n_node, 1)
sizes = np.delete(graph.n_node, -1)
forces_split_idx = [int(sum(sizes[:i])) for i in range(1, len(sizes))]

return functools.partial(
Expand Down Expand Up @@ -576,3 +594,23 @@ def _init_base_graph(
return batched_graph._replace(edges=saved_edges)

return graph

def _update_base_graph_in_pure_sim_step_fun(
self, neighbors: jax_md.partition.NeighborList
) -> None:
"""After reallocation of neighbors, the simulation step function needs to
be updated because the `graph.n_edge` attribute has changed."""
senders = tree_map(lambda n: n.idx[1, :], neighbors, is_leaf=is_neighbor_list)
receivers = tree_map(lambda n: n.idx[0, :], neighbors, is_leaf=is_neighbor_list)
new_base_graph = self._init_base_graph(
self._atoms, senders, receivers, self._force_field.allowed_atomic_numbers
)
model_calculate_fun = self._get_model_calculate_fun(
new_base_graph,
self._force_field,
is_batched_sim=isinstance(self._atoms, list),
)
_, sim_apply_fun = init_simulation_algorithm(
model_calculate_fun, self._shift_fun, self._config
)
self._pure_simulation_step_fun.keywords["apply_fun"] = sim_apply_fun