
perf: Restructure JAX objective as end-to-end differentiable loss (JAX-ReaxFF pattern) #176

@ericchansen

Description


Summary

Restructure the JAX objective evaluation to follow the JAX-ReaxFF pattern: a single jax.jit-compiled, jax.vmap-vectorized loss function that evaluates all molecules in one kernel launch, enabling end-to-end differentiation from parameters → loss.

Motivation

GPU benchmarks show that the CPU is 1.6x faster than the GPU for q2mm's current workloads (see GPU Acceleration docs). Research into how JAX-ReaxFF achieves 10-100x GPU speedups identified three architectural differences that explain the gap:

Current q2mm architecture (GPU-unfriendly)

  1. Per-molecule Python loop -- ObjectiveFunction._compute_residuals() iterates over molecules in Python, calling engine.hessian() once per molecule. Each call is a separate GPU kernel launch.
  2. No cross-molecule batching -- Hessians for 9 structures are computed in 9 separate kernels rather than one vmapped call.
  3. SciPy drives the optimizer -- Each scipy.optimize.minimize call invokes the objective as a black-box function, preventing JAX from tracing through the full optimization step.

JAX-ReaxFF architecture (GPU-friendly)

  1. vmap over all geometries -- jax.vmap(calculate_energy_and_charges, in_axes=(0,0,0,None)) evaluates all molecules in a single kernel launch.
  2. End-to-end differentiable loss -- One JIT-compiled function from parameters to total loss, enabling jax.grad(loss)(params) for analytical gradients.
  3. JAX-native optimization -- Gradient-based optimizers (L-BFGS) run inside JAX's traced computation graph.
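To make the contrast concrete, here is a minimal sketch with a toy spring energy standing in for a real force field; the function names are hypothetical and are not q2mm or JAX-ReaxFF code. The loop version mirrors the per-molecule pattern. The vmapped version batches the geometry axis with `in_axes=(0, None)` while broadcasting the shared parameter, and because it is one jitted function from parameter to scalar, `jax.grad` applies end to end.

```python
import jax
import jax.numpy as jnp

def energy(coords, k):
    """Toy energy for ONE geometry: springs (r0 = 1.0) between consecutive atoms.

    coords: (n_atoms, 3); k: force constant shared by all geometries.
    """
    d = coords[1:] - coords[:-1]
    r = jnp.linalg.norm(d, axis=-1)
    return jnp.sum(k * (r - 1.0) ** 2)

def total_energy_loop(all_coords, k):
    # Per-molecule pattern: one Python-level call per geometry, each
    # dispatching its own device operations.
    return sum(energy(c, k) for c in all_coords)

@jax.jit
def total_energy_vmap(all_coords, k):
    # Batched pattern: map over axis 0 of the coordinates, broadcast k (None).
    return jnp.sum(jax.vmap(energy, in_axes=(0, None))(all_coords, k))

# Single compiled function from parameter to scalar, so jax.grad works directly.
grad_k = jax.jit(jax.grad(total_energy_vmap, argnums=1))
```

Both versions compute the same number; the difference is that the vmapped one hands XLA the whole batch at once.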

Sources:

Proposed Changes

Phase 0: Investigate float32 viability (quick win)

The current code forces float64 globally (_jax_common.py sets jax_enable_x64=True). On the RTX 5090, whose FP64:FP32 throughput ratio is 1:64, this limits peak arithmetic to about 1.6 TFLOPS (FP64) instead of 104.8 TFLOPS (FP32), a 64x penalty.

Analysis shows float64 may not be necessary for the current harmonic-only JaxEngine:

  • The unit conversion (kcal/mol/A^2 to Hartree/Bohr^2) is linear -- it does not change relative precision. Float32's ~7 significant digits give the same accuracy at any scale.
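  • For harmonic terms (E = k(r-r0)^2), the second derivative with respect to the bond coordinate r is exactly 2k; converting to Cartesian coordinates adds only smooth geometric chain-rule terms, so there is no catastrophic cancellation in the autodiff tape.
  • A harmonic frequency scales as the square root of the Hessian eigenvalue, so resolving 1 cm^-1 at a 1000 cm^-1 mode (0.1% in frequency) needs roughly 0.2% relative precision in the eigenvalue. Float32 provides ~1.2e-7 relative precision, about 4 orders of magnitude more than needed.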
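The precision requirement follows from first-order error propagation through the harmonic relation between an eigenvalue λ of the mass-weighted Hessian and the wavenumber ν̃ (c is the speed of light, ε₃₂ = 2^-23 is float32 machine epsilon). Treating the relative eigenvalue error as the quantity float32 must control is a simplifying assumption; in practice eigenvalue errors scale with the largest Hessian entries:

```latex
\begin{aligned}
\tilde{\nu} &= \frac{\sqrt{\lambda}}{2\pi c}
  \quad\Longrightarrow\quad
  \frac{\Delta\tilde{\nu}}{\tilde{\nu}} = \frac{1}{2}\,\frac{\Delta\lambda}{\lambda} \\
\frac{\Delta\lambda}{\lambda} &= 2 \cdot \frac{1~\mathrm{cm^{-1}}}{1000~\mathrm{cm^{-1}}} = 2\times10^{-3} \\
\frac{2\times10^{-3}}{\varepsilon_{32}} &\approx \frac{2\times10^{-3}}{1.19\times10^{-7}} \approx 1.7\times10^{4}
\end{aligned}
```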

Float64 would become necessary for:

  • VdW terms (1/r^12 - 1/r^6 near-cancellation in autodiff)
  • Morse potentials, cross-terms
  • Very soft modes (~10 cm^-1) near zero eigenvalues

Action: Run the CH3F and rh-enamide benchmarks with jax_enable_x64=False and compare frequencies against the float64 baseline. If frequencies agree to within 0.1 cm^-1, float32 is viable and unlocks up to 64x more peak GPU throughput on consumer hardware.
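A miniature version of this comparison can be sketched on a toy diatomic before touching the real benchmarks. The sketch below is hypothetical and is not q2mm's benchmark code. Rather than setting jax_enable_x64=False, it keeps x64 enabled and feeds float32 arrays, which JAX then evaluates in float32, so both precisions run side by side in one process. The real experiment should still flip the flag as described, since library code may create float64 arrays internally.

```python
import jax

# Enable float64 so both precisions coexist; float32 inputs stay float32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

def bond_energy(coords, k, r0):
    # Single harmonic bond between atoms 0 and 1 (toy model, arbitrary units).
    r = jnp.linalg.norm(coords[1] - coords[0])
    return k * (r - r0) ** 2

def stretch_frequency(coords, masses, k, r0):
    n = coords.shape[0]
    hess = jax.hessian(bond_energy)(coords, k, r0).reshape(3 * n, 3 * n)
    inv_sqrt_m = 1.0 / jnp.sqrt(jnp.repeat(masses, 3))
    evals = jnp.linalg.eigvalsh(hess * inv_sqrt_m[:, None] * inv_sqrt_m[None, :])
    return jnp.sqrt(evals[-1])  # stiffest mode = the bond stretch

def compare_precisions(coords, masses, k, r0):
    """Return (float64 frequency, float32 frequency, dtype of the float32 run)."""
    f64 = stretch_frequency(jnp.asarray(coords, jnp.float64),
                            jnp.asarray(masses, jnp.float64), k, r0)
    f32 = stretch_frequency(jnp.asarray(coords, jnp.float32),
                            jnp.asarray(masses, jnp.float32), k, r0)
    return float(f64), float(f32), f32.dtype
```

For a harmonic diatomic the stretch eigenvalue is 2k/μ, so the float64 result can also be checked against the closed form.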

Phase 1: Batch Hessians across molecules

  • Pad molecules to uniform atom count
  • Use jax.vmap(hessian_fn) to compute all Hessians in one kernel launch
  • Estimated improvement: 2-3x for multi-molecule systems
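The padding-plus-vmap step can be sketched as follows, assuming a hypothetical per-molecule layout (padded coordinates, a padded bond index list, per-bond type indices into shared k and r0 arrays, and a boolean mask). Only a harmonic stretch term is modeled. Padded bonds are neutralized with a "double where" so that zero-length padded bonds never reach sqrt(0), which would otherwise turn gradients into NaN.

```python
import jax
import jax.numpy as jnp

def harmonic_energy(coords, bonds, bond_types, mask, k, r0):
    """Toy bond-stretch energy for ONE padded molecule (hypothetical layout).

    coords: (n_atoms_max, 3); bonds: (n_bonds_max, 2) atom indices;
    bond_types: (n_bonds_max,) indices into k and r0; mask: True for real bonds.
    """
    d = coords[bonds[:, 1]] - coords[bonds[:, 0]]
    d2 = jnp.sum(d * d, axis=-1)
    # Double where: padded bonds see sqrt(1.0) instead of sqrt(0.0), whose
    # infinite derivative would otherwise produce NaN gradients.
    r = jnp.sqrt(jnp.where(mask, d2, 1.0))
    e = k[bond_types] * (r - r0[bond_types]) ** 2
    return jnp.sum(jnp.where(mask, e, 0.0))

# Cartesian Hessian of one molecule: shape (n_atoms_max, 3, n_atoms_max, 3).
hessian_one = jax.hessian(harmonic_energy, argnums=0)

# Map over axis 0 of the per-molecule arrays; broadcast the shared k and r0.
batched_hessian = jax.jit(jax.vmap(hessian_one, in_axes=(0, 0, 0, 0, None, None)))

def pad_molecule(coords, bonds, bond_types, n_atoms_max, n_bonds_max):
    """Pad one molecule to fixed sizes; padded atoms sit at the origin, unbonded."""
    coords_p = jnp.zeros((n_atoms_max, 3)).at[: coords.shape[0]].set(coords)
    bonds_p = jnp.zeros((n_bonds_max, 2), jnp.int32).at[: bonds.shape[0]].set(bonds)
    types_p = jnp.zeros((n_bonds_max,), jnp.int32).at[: bond_types.shape[0]].set(bond_types)
    mask = jnp.arange(n_bonds_max) < bonds.shape[0]
    return coords_p, bonds_p, types_p, mask
```

Molecules padded with pad_molecule can be stacked with jnp.stack along a new leading axis and passed to batched_hessian together with k and r0. Rows and columns belonging to padded atoms come out as exact zeros, so they would need to be dropped (or masked) before any eigendecomposition.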

Phase 2: End-to-end differentiable frequency loss

  • Create a single JIT-compiled function: params -> energy_fn -> hessian -> eigenvalues -> frequencies -> residuals -> loss
  • Use jax.grad(loss) for analytical parameter gradients (replacing finite differences in GRAD steps)
  • This eliminates the Python loop and enables full GPU kernel fusion
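The Phase 2 chain can be sketched for a single toy molecule as below. The functions, the shared (k, r0) pair, and the choice to keep only the stiffest modes are illustrative assumptions, not q2mm's objective; frequencies are left in arbitrary units (no cm^-1 conversion). The gradient comes from jax.grad through jax.hessian and jnp.linalg.eigvalsh, with no finite differences.

```python
import jax
import jax.numpy as jnp

def bond_energy(coords, bonds, k, r0):
    # Toy harmonic bond-stretch energy with one shared (k, r0) pair.
    d = coords[bonds[:, 1]] - coords[bonds[:, 0]]
    r = jnp.linalg.norm(d, axis=-1)
    return jnp.sum(k * (r - r0) ** 2)

def frequencies(params, coords, bonds, masses, n_modes):
    """params -> Hessian -> eigenvalues -> frequencies (arbitrary units)."""
    k, r0 = params
    n = coords.shape[0]
    hess = jax.hessian(bond_energy)(coords, bonds, k, r0).reshape(3 * n, 3 * n)
    inv_sqrt_m = 1.0 / jnp.sqrt(jnp.repeat(masses, 3))
    evals = jnp.linalg.eigvalsh(hess * inv_sqrt_m[:, None] * inv_sqrt_m[None, :])
    # eigvalsh sorts ascending; keep the n_modes stiffest modes, discarding
    # translations, rotations and the soft modes of this bond-only toy model.
    return jnp.sqrt(evals[-n_modes:])

@jax.jit
def loss(params, coords, bonds, masses, ref_freqs):
    # residuals -> scalar loss; n_modes comes from the (static) reference shape.
    freqs = frequencies(params, coords, bonds, masses, ref_freqs.shape[0])
    return jnp.sum((freqs - ref_freqs) ** 2)

# Analytical gradient of the whole chain with respect to (k, r0).
grad_fn = jax.jit(jax.grad(loss))
```

One design caveat: differentiating through eigenvalues is well behaved for the eigenvalues themselves, but a real loss that also used eigenvectors would need care around degenerate modes.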

Phase 3: JAX-native optimizer

  • Replace scipy.optimize.minimize with a JAX-traced optimizer (e.g., JAXopt L-BFGS)
  • The entire optimization loop becomes a single compiled computation
  • This is the step that enables true GPU acceleration
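A compiled optimization loop can be sketched without extra dependencies using JAX's built-in jax.scipy.optimize.minimize with method="BFGS". This swaps in BFGS for the JAXopt L-BFGS named above. The toy problem (fitting one force constant per hypothetical bond type to reference stretch frequencies of several diatomics, in arbitrary units) is an illustration only. The whole fit, including every optimizer iteration, runs inside one jax.jit-compiled function.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def diatomic_frequency(k, r0, coords, masses):
    """Stretch frequency (arbitrary units) of one toy diatomic from its Hessian."""
    def energy(x):
        return k * (jnp.linalg.norm(x[1] - x[0]) - r0) ** 2
    hess = jax.hessian(energy)(coords).reshape(6, 6)
    inv_sqrt_m = 1.0 / jnp.sqrt(jnp.repeat(masses, 3))
    evals = jnp.linalg.eigvalsh(hess * inv_sqrt_m[:, None] * inv_sqrt_m[None, :])
    return jnp.sqrt(evals[-1])  # stiffest mode = the bond stretch

# Batch over molecules: k, r0, coords and masses all carry a leading molecule axis.
batched_frequency = jax.vmap(diatomic_frequency)

def loss(k_by_type, coords, masses, types, r0_by_type, ref_freqs):
    freqs = batched_frequency(k_by_type[types], r0_by_type[types], coords, masses)
    return jnp.sum((freqs - ref_freqs) ** 2)

@jax.jit
def fit(k_init, coords, masses, types, r0_by_type, ref_freqs):
    # The BFGS iterations (and their line searches) are traced into this function.
    result = minimize(loss, k_init,
                      args=(coords, masses, types, r0_by_type, ref_freqs),
                      method="BFGS")
    return result.x
```

Swapping in JAXopt L-BFGS or an Optax optimizer would keep the same structure: a pure loss function of a parameter vector, with the optimizer loop inside the compiled region.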

Additional Considerations

  • Molecule padding overhead -- vmap requires uniform array shapes. Padding small molecules (5 atoms) to match large ones (62 atoms) wastes compute. Grouping molecules into size buckets (one compiled shape per bucket) could mitigate this; JAX's dynamic-shape support is still experimental, so it is the less reliable option.
  • Backward compatibility -- The existing ObjectiveFunction API should continue to work for non-JAX engines (OpenMM, Tinker). The JAX-native path would be an alternative code path activated when using JaxEngine.
  • Datacenter GPUs -- Even with all optimizations, an A100 (FP64:FP32 = 1:2, 9.7 TFLOPS FP64) would be far more suitable than the RTX 5090 (1:64) for float64-heavy workloads.
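Size bucketing can be sketched as a small host-side helper. The bucket edges, the helper names, and the 5/5/8/62-atom example set are hypothetical. The work estimate counts dense Hessian entries, (3N)^2 per molecule, which is only a rough proxy: autodiff Hessians and eigendecompositions scale worse than that.

```python
from collections import defaultdict

def bucket_by_size(atom_counts, bucket_edges):
    """Group molecule indices so each group is padded only to its bucket edge.

    atom_counts: atoms per molecule; bucket_edges: ascending padded sizes
    (hypothetical choice). Returns {padded_size: [molecule indices]}; each key
    corresponds to one array shape, i.e. one compilation.
    """
    buckets = defaultdict(list)
    for i, n in enumerate(atom_counts):
        edge = next((e for e in bucket_edges if n <= e), None)
        if edge is None:
            raise ValueError(f"molecule {i} has {n} atoms, above the largest bucket")
        buckets[edge].append(i)
    return dict(buckets)

def padded_hessian_work(atom_counts, padded_sizes):
    """Padded vs. unpadded dense-Hessian entry count, a rough cost proxy."""
    padded = sum((3 * p) ** 2 for p in padded_sizes)
    unpadded = sum((3 * n) ** 2 for n in atom_counts)
    return padded / unpadded
```

For the example set, padding everything to 62 atoms costs about 3.9x the unpadded Hessian entries, while buckets at 8 and 64 atoms cost about 1.08x, at the price of one extra compilation.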

Related Issues

Success Criteria

  • Float32 viability tested and documented (Phase 0)
  • Multi-molecule frequency evaluation uses vmap (no Python loop)
  • Analytical parameter gradients via jax.grad for frequency objective
  • GPU benchmark shows speedup over CPU for rh-enamide (9 molecules, 94 params)
  • All existing tests continue to pass
  • Benchmark results documented with reproducible commands

Metadata


Assignees

No one assigned

    Labels

    backend: Backend-specific (OpenMM, Tinker, JAX, etc.)
    performance: Performance and optimization improvements

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
