
chore(upstream): Propose modular dependencies PR to jax-md #95

@ericchansen


Summary

Propose a PR to jax-md that splits its monolithic dependency list into core + optional extras, enabling lightweight installation for users who only need classical MM force fields (not neural network potentials).

Problem

jax-md currently requires 13 direct dependencies, many of which are only needed for neural network potentials:

| Dependency | Needed for core MM? | Needed for NN potentials? |
|---|---|---|
| jax, jaxlib | Yes | Yes |
| numpy | Yes | Yes |
| absl-py | Yes | Yes |
| frozendict | Yes | Yes |
| einops | Yes | Yes |
| ml_collections | Partially (energy.py) | Yes |
| dm-haiku | No (only nn.py, energy.py NN aliases) | Yes |
| flax | No (only _nn/gnome.py) | Yes |
| optax | No (only NN training) | Yes |
| jraph | No (only nn.py graph networks) | Yes |
| e3nn-jax | No (only equivariant NN) | Yes |
| pymatgen | No (crystal structure I/O) | Maybe |

The NN-only dependencies pull in a massive transitive tree including orbax-checkpoint -> uvloop (Unix-only), making jax-md uninstallable on Windows.
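To confirm which package pulls uvloop in on a given machine, the chain can be traced from installed package metadata. A minimal sketch using the standard-library `importlib.metadata`; `dependency_path` and `installed_requires` are hypothetical helper names, and environment markers other than extras (e.g. `sys_platform`) are deliberately not evaluated.

```python
import re
from collections import deque
from importlib import metadata

_REQ_NAME = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def _normalize(name: str) -> str:
    """PEP 503-style normalization, e.g. 'Orbax_Checkpoint' -> 'orbax-checkpoint'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def installed_requires(dist: str) -> list[str]:
    """Requirement names of an installed distribution, skipping extras-only ones."""
    try:
        reqs = metadata.requires(dist) or []
    except metadata.PackageNotFoundError:
        return []  # not installed in this environment
    names = []
    for req in reqs:
        if re.search(r"extra\s*==", req):
            continue  # only pulled in when an optional extra is requested
        match = _REQ_NAME.match(req)
        if match:
            names.append(match.group(1))
    return names


def dependency_path(root, target, requires=installed_requires):
    """Breadth-first search for the shortest chain root -> ... -> target, or None."""
    root, target = _normalize(root), _normalize(target)
    parent = {root: None}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node == target:
            path = []
            while node is not None:  # walk back up to the root
                path.append(node)
                node = parent[node]
            return path[::-1]
        for child in map(_normalize, requires(node)):
            if child not in parent:
                parent[child] = node
                queue.append(child)
    return None
```

On an affected machine, `dependency_path("jax-md", "uvloop")` returns the actual chain for the installed versions; the `requires` parameter exists so the search can also be run against a hand-written graph.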

Meanwhile, jax-md's mm_forcefields/ module (OPLSAA, ReaxFF) only needs jax, numpy, absl-py, frozendict, einops, and the core jax_md modules (space, smap, partition).

Proposed changes to jax-md

1. Split dependencies into core + extras

```toml
# pyproject.toml
[project]
dependencies = [
    "absl-py",
    "numpy",
    "jax>=0.5.0",
    "jaxlib>=0.5.0",
    "einops",
    "frozendict",
    "ml_collections",
]

[project.optional-dependencies]
nn = [
    "flax",
    "dm-haiku",
    "optax",
    "jraph",
    "e3nn-jax",
]
crystal = ["pymatgen"]
all = ["jax-md[nn,crystal]"]
```

2. Use lazy imports in __init__.py

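```python
# Current (breaks if flax is missing):
from jax_md import nn  # imports haiku, flax, jraph at module load

# Proposed (module-level __getattr__, PEP 562):
import importlib

def __getattr__(name):
    if name == 'nn':
        # importlib.import_module avoids re-entering this hook (a
        # `from jax_md import nn` here would recurse) and binds
        # jax_md.nn on the package, so later lookups skip __getattr__.
        return importlib.import_module('jax_md.nn')
    raise AttributeError(f"module 'jax_md' has no attribute {name!r}")
```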
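The lazy behaviour can be checked on a throwaway package before touching jax-md itself. In this self-contained sketch, `fakemd` is a hypothetical stand-in for `jax_md`, and its `nn.py` only sets a flag in place of the heavy haiku/flax/jraph imports. The hook uses `importlib.import_module` because the from-import machinery probes the attribute with `hasattr` first, so a `from ... import nn` inside `__getattr__` would call the hook again.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical package mirroring the proposed jax_md/__init__.py hook.
INIT_SRC = textwrap.dedent('''
    import importlib

    def __getattr__(name):
        if name == "nn":
            # Importing the submodule also binds it as fakemd.nn,
            # so this hook runs at most once for "nn".
            return importlib.import_module("fakemd.nn")
        raise AttributeError(f"module 'fakemd' has no attribute {name!r}")
''')

# Stands in for nn.py, whose real imports (haiku, flax, jraph) are heavy.
NN_SRC = "LOADED = True\n"


def check_lazy_import() -> tuple[bool, bool, bool]:
    """Return (nn loaded after import, nn loaded after access, unknown attr raises)."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = Path(tmp) / "fakemd"
        pkg.mkdir()
        (pkg / "__init__.py").write_text(INIT_SRC)
        (pkg / "nn.py").write_text(NN_SRC)
        importlib.invalidate_caches()
        sys.path.insert(0, tmp)
        try:
            fakemd = importlib.import_module("fakemd")
            before = "fakemd.nn" in sys.modules   # should still be False
            _ = fakemd.nn.LOADED                  # triggers __getattr__("nn")
            after = "fakemd.nn" in sys.modules    # now True
            missing_raises = not hasattr(fakemd, "does_not_exist")
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("fakemd.nn", None)
            sys.modules.pop("fakemd", None)
    return before, after, missing_raises
```

`check_lazy_import()` should return `(False, True, True)`: importing the package does not load `nn`, the first attribute access does, and unknown names still raise `AttributeError`.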

3. Guard NN imports in energy.py

```python
# Current:
import haiku as hk
from jax_md import nn
bp = nn.behler_parrinello
gnome = nn.gnome

# Proposed:
try:
    import haiku as hk
    from jax_md import nn
    bp = nn.behler_parrinello
    gnome = nn.gnome
    nequip = nn.nequip
except ImportError:
    hk = None
    bp = gnome = nequip = None
```

Why this matters

  1. Windows support: Without flax, jax-md installs cleanly on Windows (no uvloop chain)
  2. Lighter installs: Users who only need classical MM don't need ~15 ML packages
  3. Broader adoption: Projects like Q2MM only need jax_md.mm_forcefields and jax_md.energy (classical potentials)
  4. Follows best practices: Many scientific Python packages use optional extras (e.g., scipy, pandas)

Our use case (Q2MM)

We're building a quantum-guided molecular mechanics force field optimizer. We need:

  • jax_md.mm_forcefields.oplsaa -- OPLSAA energy functions with topology
  • jax_md.space, smap, partition -- core geometry/mapping primitives
  • jax.grad / jax.hessian on the energy function -- exact gradients and Hessians via automatic differentiation
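For context, the derivatives listed above are plain JAX transforms of a scalar energy function. A minimal sketch with a single harmonic bond written in `jax.numpy` (not jax_md's API); `bond_energy`, `k` and `r0` are illustrative values, not OPLSAA parameters.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, usual for force-field fitting


def bond_energy(positions, k=2.0, r0=1.0):
    """Harmonic bond between atoms 0 and 1: E = 0.5 * k * (r - r0)**2."""
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


positions = jnp.array([[0.0, 0.0, 0.0],
                       [1.5, 0.0, 0.0]])

forces = -jax.grad(bond_energy)(positions)     # F = -dE/dx, shape (2, 3)
hessian = jax.hessian(bond_energy)(positions)  # d2E/dx dx, shape (2, 3, 2, 3)
```

With the bond stretched to r = 1.5, the force is ±1 along x, and the x-x Hessian block of atom 1 equals k = 2 because only the bond direction carries curvature from the spring term.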

We do NOT need: neural network potentials, graph networks, equivariant NNs, or crystal structure I/O.

Currently we maintain our own ~100-line JAX energy implementation because jax-md won't install on Windows. This is duplicated effort that benefits nobody.

Action items

  • Fork jax-md and implement the changes above
  • Verify mm_forcefields works with core-only deps
  • Run jax-md's test suite to ensure no regressions
  • Submit PR to jax-md/jax-md with rationale
  • If accepted, update our [jax-md] extra to use the slim install
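For the "Verify mm_forcefields works with core-only deps" item, a CI job could first confirm that none of the proposed extras leaked into the environment. A sketch using `importlib.util.find_spec`, which locates a top-level module without importing it; the distribution-to-module mapping (e.g. dm-haiku to `haiku`) is our assumption about each package's import name, and `leaked_optional_packages` is a hypothetical helper.

```python
import importlib.util

# Proposed `nn` and `crystal` extras: distribution name -> top-level import name.
OPTIONAL_MODULES = {
    "flax": "flax",
    "dm-haiku": "haiku",
    "optax": "optax",
    "jraph": "jraph",
    "e3nn-jax": "e3nn_jax",
    "pymatgen": "pymatgen",
}


def leaked_optional_packages(modules: dict[str, str] = OPTIONAL_MODULES) -> list[str]:
    """Distributions from `modules` whose top-level module is importable here."""
    return sorted(
        dist for dist, mod in modules.items()
        if importlib.util.find_spec(mod) is not None
    )


def assert_core_only() -> None:
    """Fail loudly if a core-only environment contains an optional package."""
    leaked = leaked_optional_packages()
    if leaked:
        raise AssertionError(f"core-only environment has optional packages: {leaked}")
```

Once `assert_core_only()` passes, the same job would import `jax_md.mm_forcefields.oplsaa` and run the relevant part of jax-md's test suite.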


Metadata


Assignees

No one assigned

    Labels

    upstream: Upstream dependency work

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
