Summary
Propose a PR to jax-md that splits its monolithic dependency list into core + optional extras, enabling lightweight installation for users who only need classical MM force fields (not neural network potentials).
Problem
jax-md currently requires 13 direct dependencies, many of which are only needed for neural network potentials:
| Dependency | Needed for core MM? | Needed for NN potentials? |
|---|---|---|
| jax, jaxlib | Yes | Yes |
| numpy | Yes | Yes |
| absl-py | Yes | Yes |
| frozendict | Yes | Yes |
| einops | Yes | Yes |
| ml_collections | Partially (`energy.py`) | Yes |
| dm-haiku | No (only `nn.py` and the NN aliases in `energy.py`) | Yes |
| flax | No (only `_nn/gnome.py`) | Yes |
| optax | No (only NN training) | Yes |
| jraph | No (only `nn.py` graph networks) | Yes |
| e3nn-jax | No (only equivariant NNs) | Yes |
| pymatgen | No (crystal structure I/O) | Maybe |
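To reproduce the left column of the table from an installed copy, the requirement strings exposed by `importlib.metadata.requires` can be grouped by their `extra == "..."` marker. This is a minimal stdlib-only sketch: `group_requirements` is a hypothetical helper, and its regex only recognizes `extra` markers, discarding any other environment markers. With the current packaging every entry would be expected to land under `None` (core); after the proposed split, the NN packages would move under `"nn"`.

```python
import re
from collections import defaultdict
from importlib.metadata import PackageNotFoundError, requires

# Matches markers such as: extra == "nn"  (single or double quotes).
_EXTRA_RE = re.compile(r"""extra\s*==\s*["']([^"']+)["']""")


def group_requirements(requirements):
    """Group Requires-Dist strings into core deps (key None) and per-extra deps.

    Environment markers other than `extra` are discarded in this sketch.
    """
    groups = defaultdict(list)
    for req in requirements or []:
        spec, _, marker = req.partition(";")
        match = _EXTRA_RE.search(marker)
        groups[match.group(1) if match else None].append(spec.strip())
    return dict(groups)


if __name__ == "__main__":
    try:
        print(group_requirements(requires("jax-md")))
    except PackageNotFoundError:
        print("jax-md is not installed in this environment")
```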
The NN-only dependencies pull in a large transitive tree. In particular, flax depends on orbax-checkpoint, which depends on uvloop (Unix-only), so jax-md cannot be installed on Windows.
Meanwhile, jax-md's mm_forcefields/ module (OPLSAA, ReaxFF) only needs jax, numpy, absl-py, frozendict, einops, and the core jax_md modules (space, smap, partition).
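The claim above can be checked mechanically: import the module in a fresh interpreter and look at which top-level packages end up in `sys.modules`. A sketch under stated assumptions: `forbidden_imports` is a hypothetical helper, and the import names used with it (`haiku` for dm-haiku, `e3nn_jax` for e3nn-jax) are those packages' usual module names. Note that importing `jax_md.mm_forcefields` always runs `jax_md/__init__.py` first, so the check covers the package-level imports too.

```python
import subprocess
import sys


def forbidden_imports(module, forbidden):
    """Import `module` in a fresh interpreter; return which names in
    `forbidden` appear as top-level packages in its sys.modules.

    Raises subprocess.CalledProcessError if the import itself fails,
    e.g. because a required package is not installed.
    """
    probe = (
        "import importlib, sys\n"
        f"importlib.import_module({module!r})\n"
        "print('\\n'.join(sorted({m.split('.')[0] for m in sys.modules})))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", probe],
        capture_output=True,
        text=True,
        check=True,
    )
    loaded = set(result.stdout.split())
    return sorted(set(forbidden) & loaded)
```

For example, `forbidden_imports("jax_md.mm_forcefields", ["flax", "haiku", "optax", "jraph", "e3nn_jax"])`: in a core-only environment, a `CalledProcessError` means the import chain still requires an NN package; in a full environment, the returned list shows which NN packages the import drags in. A subprocess is used so that modules already imported by the caller cannot hide a leak.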
Proposed changes to jax-md
1. Split dependencies into core + extras
```toml
# pyproject.toml
[project]
dependencies = [
    "absl-py",
    "numpy",
    "jax>=0.5.0",
    "jaxlib>=0.5.0",
    "einops",
    "frozendict",
    "ml_collections",
]

[project.optional-dependencies]
nn = [
    "flax",
    "dm-haiku",
    "optax",
    "jraph",
    "e3nn-jax",
]
crystal = ["pymatgen"]
all = ["jax-md[nn,crystal]"]
```
2. Use lazy imports in __init__.py
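```python
# Current (breaks if flax is missing):
from jax_md import nn  # imports haiku, flax, jraph at module load

# Proposed (PEP 562 module-level __getattr__):
import importlib

def __getattr__(name):
    if name == 'nn':
        # Use import_module rather than `from jax_md import nn`: the latter
        # calls hasattr(jax_md, 'nn') internally, re-entering this function.
        return importlib.import_module(f"{__name__}.{name}")
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```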
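The lazy-import pattern relies on PEP 562 module-level `__getattr__`. The sketch below builds a throwaway package on disk (the name `lazypkg_demo` and its one-line `nn.py` stand-in are hypothetical) to show that the submodule is imported only on first attribute access, and that unknown names still raise `AttributeError`.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# __init__.py for a throwaway package; mirrors the proposed jax_md/__init__.py.
INIT_SOURCE = textwrap.dedent(
    """
    import importlib

    _LAZY_SUBMODULES = {"nn"}

    def __getattr__(name):
        # PEP 562: only called when normal attribute lookup fails.
        if name in _LAZY_SUBMODULES:
            return importlib.import_module(f"{__name__}.{name}")
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    """
)


def demo(name="lazypkg_demo"):
    """Return (loaded_at_import, loaded_after_access, value, has_missing)."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg_dir = Path(tmp) / name
        pkg_dir.mkdir()
        (pkg_dir / "__init__.py").write_text(INIT_SOURCE)
        # Stand-in for a heavy submodule (in jax_md, nn would import flax/haiku).
        (pkg_dir / "nn.py").write_text("LOADED = True\n")
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            pkg = importlib.import_module(name)
            loaded_at_import = f"{name}.nn" in sys.modules
            value = pkg.nn.LOADED  # first access runs __getattr__ -> import
            loaded_after_access = f"{name}.nn" in sys.modules
            has_missing = hasattr(pkg, "does_not_exist")
        finally:
            sys.path.remove(tmp)
            for mod in (f"{name}.nn", name):
                sys.modules.pop(mod, None)
    return loaded_at_import, loaded_after_access, value, has_missing


if __name__ == "__main__":
    print(demo())
```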
3. Guard NN imports in energy.py
```python
# Current:
import haiku as hk
from jax_md import nn

bp = nn.behler_parrinello
gnome = nn.gnome

# Proposed:
try:
    import haiku as hk
    from jax_md import nn
    bp = nn.behler_parrinello
    gnome = nn.gnome
    nequip = nn.nequip
except ImportError:
    hk = None
    bp = gnome = nequip = None
```
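Setting the aliases to `None` means a user without the extras gets a generic error (for example an `AttributeError` on `None`) far from the real cause. One alternative, sketched here with a hypothetical `MissingExtra` class, is a placeholder that raises an `ImportError` naming the extra to install. It handles both attribute access and calls, because whether `bp`, `gnome` and `nequip` are modules or functions in jax-md is left as an assumption here.

```python
class MissingExtra:
    """Placeholder for an NN alias whose optional dependencies are absent.

    Any public attribute access or call raises an ImportError that names
    the pip extra to install.
    """

    def __init__(self, name, extra="nn"):
        self._name = name
        self._extra = extra

    def _error(self):
        return ImportError(
            f"jax_md.energy.{self._name} needs optional dependencies; "
            f"install them with: pip install 'jax-md[{self._extra}]'"
        )

    def __getattr__(self, attr):
        # Private/dunder lookups (copy, pickle, protocol probes) get a normal
        # AttributeError so that introspection keeps working.
        if attr.startswith("_"):
            raise AttributeError(attr)
        raise self._error()

    def __call__(self, *args, **kwargs):
        raise self._error()

    def __repr__(self):
        return f"<missing optional extra {self._extra!r} for {self._name!r}>"
```

In the `except ImportError:` branch this would read `bp = MissingExtra("behler_parrinello")`, `gnome = MissingExtra("gnome")` and `nequip = MissingExtra("nequip")`. The trade-off is that any code testing `bp is None` would need to change.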
Why this matters
- Windows support: Without flax, jax-md installs cleanly on Windows (no uvloop chain)
- Lighter installs: Users who only need classical MM don't need the NN-only packages (dm-haiku, flax, optax, jraph, e3nn-jax) or their transitive dependencies
- Broader adoption: Projects like Q2MM only need jax_md.mm_forcefields and jax_md.energy (classical potentials)
- Follows best practices: Many scientific Python packages use optional extras (e.g., scipy, pandas)
Our use case (Q2MM)
We're building a quantum-guided molecular mechanics force field optimizer. We need:
- jax_md.mm_forcefields.oplsaa -- OPLSAA energy functions with topology
- jax_md.space, smap, partition -- core geometry/mapping primitives
- jax.grad / jax.hessian on the energy function -- analytical gradients
We do NOT need: neural network potentials, graph networks, equivariant NNs, or crystal structure I/O.
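For context on the third item, here is a sketch of what analytical gradients and Hessians of a classical term look like in plain JAX. The harmonic bond energy and its parameter names are illustrative assumptions, not jax-md's OPLSAA implementation; inside jax-md the bond vector would normally come from a displacement function from `jax_md.space` (for example `space.free()`) rather than direct subtraction.

```python
import jax
import jax.numpy as jnp


def harmonic_bond_energy(positions, bonds, k, r0):
    """E = sum_b 0.5 * k_b * (|x_i - x_j| - r0_b)^2 over bonds b = (i, j)."""
    dr = positions[bonds[:, 0]] - positions[bonds[:, 1]]
    r = jnp.linalg.norm(dr, axis=-1)
    return jnp.sum(0.5 * k * (r - r0) ** 2)


# Two atoms, one bond stretched from r0 = 1.0 to r = 1.5.
positions = jnp.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
bonds = jnp.array([[0, 1]])
k = jnp.array([2.0])
r0 = jnp.array([1.0])

energy = harmonic_bond_energy(positions, bonds, k, r0)
# jax.grad / jax.hessian differentiate w.r.t. the first argument (positions).
forces = -jax.grad(harmonic_bond_energy)(positions, bonds, k, r0)
hessian = jax.hessian(harmonic_bond_energy)(positions, bonds, k, r0)  # (N, 3, N, 3)

print(energy, forces, hessian.shape)
```

The Hessian comes out as an `(N, 3, N, 3)` tensor; reshaping it to `(3N, 3N)` gives the usual Cartesian Hessian matrix.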
Currently we maintain our own ~100-line JAX energy implementation because jax-md won't install on Windows. This is duplicated effort that benefits nobody.
Action items
- … mm_forcefields works with core-only deps
- … [jax-md] extra to use the slim install
Related