Summary
Restructure the JAX objective evaluation to follow the JAX-ReaxFF pattern: a single jax.jit-compiled, jax.vmap-vectorized loss function that evaluates all molecules in one kernel launch, enabling end-to-end differentiation from parameters → loss.
Motivation
GPU benchmarks show CPU is 1.6x faster than GPU for q2mm's current workloads (see GPU Acceleration docs). Research into how JAX-ReaxFF achieves 10--100x GPU speedup identified three architectural differences that explain the gap:
Current q2mm architecture (GPU-unfriendly)
- Per-molecule Python loop -- ObjectiveFunction._compute_residuals() iterates over molecules in Python, calling engine.hessian() once per molecule. Each call is a separate GPU kernel launch.
- No cross-molecule batching -- Hessians for 9 structures are computed in 9 separate kernels rather than one vmapped call.
- Scipy drives the optimizer -- Each scipy.optimize.minimize call invokes the objective as a black-box function, preventing JAX from tracing through the full optimization step.
JAX-ReaxFF architecture (GPU-friendly)
- vmap over all geometries -- jax.vmap(calculate_energy_and_charges, in_axes=(0,0,0,None)) evaluates all molecules in a single kernel launch.
- End-to-end differentiable loss -- One JIT-compiled function from parameters to total loss, enabling jax.grad(loss)(params) for analytical gradients.
- JAX-native optimization -- Gradient-based optimizers (L-BFGS) run inside JAX's traced computation graph.
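To make the pattern concrete, here is a minimal sketch of a vmapped, JIT-compiled loss with analytical gradients. It is not q2mm or JAX-ReaxFF code: bond_energy, the two-parameter (k, r0) model, and the four toy geometries are all hypothetical stand-ins chosen only to keep the example self-contained.

```python
import jax
import jax.numpy as jnp

def bond_energy(coords, params):
    """Harmonic energy k * (r - r0)^2 of one two-atom geometry (toy model)."""
    k, r0 = params[0], params[1]
    r = jnp.linalg.norm(coords[1] - coords[0])
    return k * (r - r0) ** 2

@jax.jit
def loss(params, coords_batch, ref_energies):
    # vmap maps over the leading (geometry) axis; params are shared (None).
    energies = jax.vmap(bond_energy, in_axes=(0, None))(coords_batch, params)
    return jnp.mean((energies - ref_energies) ** 2)

# Hypothetical data: four geometries with bond lengths 0.9 to 1.2 along x.
lengths = jnp.array([0.9, 1.0, 1.1, 1.2])
coords_batch = jnp.zeros((4, 2, 3)).at[:, 1, 0].set(lengths)
true_params = jnp.array([2.0, 1.0])
ref_energies = jax.vmap(bond_energy, in_axes=(0, None))(coords_batch, true_params)

# One traced function from parameters to loss, so jax.grad gives exact gradients.
params = jnp.array([1.5, 1.05])
grads = jax.grad(loss)(params, coords_batch, ref_energies)
print("loss:", loss(params, coords_batch, ref_energies))
print("d(loss)/d(k, r0):", grads)
```

The whole batch is evaluated by a single compiled call, and the gradient comes from the same trace rather than from repeated objective evaluations.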
Sources:
Proposed Changes
Phase 0: Investigate float32 viability (quick win)
The current code forces float64 globally (_jax_common.py sets jax_enable_x64=True), which limits the RTX 5090 to 1.6 TFLOPS instead of its full 104.8 TFLOPS -- a 64x penalty.
Analysis shows float64 may not be necessary for the current harmonic-only JaxEngine:
- The unit conversion (kcal/mol/A^2 to Hartree/Bohr^2) is linear -- it does not change relative precision. Float32's ~7 significant digits give the same accuracy at any scale.
- For harmonic terms (E = k(r-r0)^2), the Hessian is exactly 2k -- no catastrophic cancellation in the autodiff tape.
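- To resolve 1 cm^-1 at a 1000 cm^-1 mode, you need ~0.2% relative precision in the Hessian eigenvalue: frequency scales as the square root of the eigenvalue, so a 0.1% frequency tolerance allows a 0.2% eigenvalue error. Float32 provides 1.2e-7 relative precision -- about 4 orders of magnitude more than needed.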
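The precision budget in the last bullet can be checked with a few lines of arithmetic. This sketch assumes only that frequency scales as the square root of the Hessian eigenvalue, and it compares against a single float32 rounding error; accumulated roundoff in a real Hessian and eigensolve will use up some of that headroom.

```python
import numpy as np

target_cm = 1.0    # desired frequency resolution, cm^-1
mode_cm = 1000.0   # frequency of the mode being resolved, cm^-1

# nu ~ sqrt(lambda)  =>  d(nu)/nu = 0.5 * d(lambda)/lambda
rel_freq_tol = target_cm / mode_cm
rel_eig_tol = 2.0 * rel_freq_tol

eps32 = float(np.finfo(np.float32).eps)  # float32 machine epsilon
headroom = rel_eig_tol / eps32

print(f"required relative eigenvalue precision: {rel_eig_tol:.1e}")
print(f"float32 machine epsilon:                {eps32:.2e}")
print(f"headroom: {headroom:.1e} (~{np.log10(headroom):.1f} orders of magnitude)")
```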
Float64 would become necessary for:
- VdW terms (1/r^12 - 1/r^6 near-cancellation in autodiff)
- Morse potentials, cross-terms
- Very soft modes (~10 cm^-1) near zero eigenvalues
Action: Run CH3F and rh-enamide benchmarks with jax_enable_x64=False and compare frequency accuracy against float64 baseline. If frequencies agree to <0.1 cm^-1, float32 is viable and unlocks 64x more GPU throughput on consumer hardware.
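A quick way to prototype this comparison before touching the real benchmarks is to run the same Hessian-to-frequency pipeline in both precisions on a toy system. The sketch below is an assumption-laden stand-in, not the CH3F or rh-enamide benchmark: chain_energy is a hypothetical 1D harmonic chain with unit masses and arbitrary units. It keeps jax_enable_x64=True and selects precision through the input dtype, which avoids toggling the global flag mid-process.

```python
import jax
jax.config.update("jax_enable_x64", True)  # allow float64 next to float32
import jax.numpy as jnp

def chain_energy(x, k, r0):
    """Harmonic energy of a 1D chain: sum_b k_b * (x[b+1] - x[b] - r0_b)^2."""
    d = x[1:] - x[:-1]
    return jnp.sum(k * (d - r0) ** 2)

def frequencies(x, k, r0):
    """sqrt of the nonzero Hessian eigenvalues (unit masses, arbitrary units)."""
    hess = jax.hessian(chain_energy)(x, k, r0)
    eig = jnp.linalg.eigvalsh(hess)  # ascending; eig[0] is the translation mode
    return jnp.sqrt(eig[1:])

def run(dtype):
    # Same hypothetical system, with every input cast to the requested dtype.
    x = jnp.array([0.0, 1.1, 2.0, 3.2], dtype=dtype)
    k = jnp.array([300.0, 450.0, 350.0], dtype=dtype)
    r0 = jnp.array([1.0, 1.0, 1.1], dtype=dtype)
    return frequencies(x, k, r0)

f64 = run(jnp.float64)
f32 = run(jnp.float32)
rel_err = jnp.max(jnp.abs(f32.astype(jnp.float64) - f64) / f64)
print("float64:", f64)
print("float32:", f32)
print("max relative difference:", rel_err)
```

For the real benchmarks, the relative difference would be converted to cm^-1 and checked against the 0.1 cm^-1 threshold above.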
Phase 1: Batch Hessians across molecules
- Pad molecules to uniform atom count
- Use jax.vmap(hessian_fn) to compute all Hessians in one kernel launch
- Estimated improvement: 2--3x for multi-molecule systems
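A rough sketch of what padding plus vmap could look like, assuming a bonds-only harmonic energy. The array layout (padded coords, a bond index table, and a boolean bond_mask) and the names padded_energy and padded_hessian are hypothetical, not existing q2mm APIs. The jnp.where guard matters: a padded bond has zero length, and differentiating sqrt at zero would otherwise put NaNs into the Hessian.

```python
import jax
import jax.numpy as jnp

def padded_energy(coords, bonds, k, r0, bond_mask):
    """Harmonic bond energy of one padded molecule.

    coords: (N_max, 3); bonds: (M_max, 2) int; k, r0: (M_max,); bond_mask: (M_max,) bool.
    """
    d = coords[bonds[:, 1]] - coords[bonds[:, 0]]
    d2 = jnp.sum(d * d, axis=-1)
    # Substitute 1.0 for padded bonds so sqrt stays differentiable there.
    r = jnp.sqrt(jnp.where(bond_mask, d2, 1.0))
    return jnp.sum(jnp.where(bond_mask, k * (r - r0) ** 2, 0.0))

def padded_hessian(coords, bonds, k, r0, bond_mask):
    n = coords.shape[0]
    h = jax.hessian(padded_energy)(coords, bonds, k, r0, bond_mask)  # (N, 3, N, 3)
    return h.reshape(3 * n, 3 * n)

# One compiled call for the whole batch instead of a Python loop over molecules.
batched_hessian = jax.jit(jax.vmap(padded_hessian))

# Hypothetical batch: a diatomic (padded to 3 atoms, 2 bonds) and a triatomic.
coords = jnp.array([
    [[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.0, 0.0, 0.0]],  # third atom is padding
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.2, 0.0]],
])
bonds = jnp.array([[[0, 1], [0, 0]], [[0, 1], [1, 2]]])
k = jnp.array([[300.0, 0.0], [300.0, 400.0]])
r0 = jnp.array([[1.0, 0.0], [1.0, 1.1]])
bond_mask = jnp.array([[True, False], [True, True]])

H = batched_hessian(coords, bonds, k, r0, bond_mask)
print(H.shape)  # (batch, 3 * N_max, 3 * N_max)
```

Rows and columns belonging to padded atoms come out as zeros, so downstream eigenvalue code would need to ignore the extra zero modes they introduce.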
Phase 2: End-to-end differentiable frequency loss
- Create a single JIT-compiled function: params -> energy_fn -> hessian -> eigenvalues -> frequencies -> residuals -> loss
- Use jax.grad(loss) for analytical parameter gradients (replacing finite-difference in GRAD steps)
- This eliminates the Python loop and enables full GPU kernel fusion
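The chain params -> hessian -> eigenvalues -> frequencies -> loss can be sketched end to end on a toy system. Everything here is illustrative: chain_energy is a hypothetical 1D harmonic chain with unit masses, frequencies are in arbitrary units, and the single translational zero mode is dropped before the square root because its gradient is undefined there. A real implementation would also need mass weighting and projection of rotational modes.

```python
import jax
jax.config.update("jax_enable_x64", True)  # mirrors the current global setting
import jax.numpy as jnp

def chain_energy(x, k, r0):
    d = x[1:] - x[:-1]
    return jnp.sum(k * (d - r0) ** 2)

def frequencies(k, x, r0):
    hess = jax.hessian(chain_energy)(x, k, r0)
    eig = jnp.linalg.eigvalsh(hess)  # ascending order
    return jnp.sqrt(eig[1:])         # drop the translational zero mode

@jax.jit
def loss(k, x, r0, ref_freqs):
    residuals = frequencies(k, x, r0) - ref_freqs
    return jnp.sum(residuals ** 2)

x = jnp.array([0.0, 1.0, 2.1, 3.0])
r0 = jnp.array([1.0, 1.0, 1.0])
k_true = jnp.array([300.0, 450.0, 350.0])  # hypothetical reference parameters
ref_freqs = frequencies(k_true, x, r0)

# Analytical gradient of the frequency loss with respect to the force constants.
k_guess = jnp.array([280.0, 470.0, 360.0])
value, grad_k = jax.value_and_grad(loss)(k_guess, x, r0, ref_freqs)
print("loss:", value)
print("d(loss)/dk:", grad_k)
```

One caveat worth keeping in mind: gradients through eigendecompositions can become ill-conditioned when eigenvalues are degenerate or nearly so, which this toy system avoids.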
Phase 3: JAX-native optimizer
- Replace scipy.optimize.minimize with a JAX-traced optimizer (e.g., JAXopt L-BFGS)
- The entire optimization loop becomes a single compiled computation
- This is the step that enables true GPU acceleration
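To illustrate what "the entire optimization loop becomes a single compiled computation" means, the sketch below compiles a whole fitting loop with jax.lax.fori_loop. It deliberately swaps in plain fixed-step gradient descent for L-BFGS so the example needs nothing beyond JAX itself; the step size, step count, and the diatomic_frequency toy model (1D diatomic, unit masses, arbitrary units) are all assumptions picked for this example.

```python
from functools import partial

import jax
import jax.numpy as jnp

def diatomic_frequency(k):
    """Stretch frequency of a 1D diatomic with unit masses (arbitrary units)."""
    def energy(x):
        return k * (x[1] - x[0] - 1.0) ** 2
    hess = jax.hessian(energy)(jnp.array([0.0, 1.0]))
    return jnp.sqrt(jnp.linalg.eigvalsh(hess)[-1])  # largest eigenvalue is 4k

def loss(ks, ref_freqs):
    freqs = jax.vmap(diatomic_frequency)(ks)
    return jnp.sum((freqs - ref_freqs) ** 2)

@partial(jax.jit, static_argnames=("steps",))
def optimize(ks0, ref_freqs, lr=100.0, steps=200):
    grad_fn = jax.grad(loss)

    def body(_, ks):
        # Plain gradient descent step; a stand-in for an L-BFGS update.
        return ks - lr * grad_fn(ks, ref_freqs)

    # The loop itself is traced, so no Python runs between iterations.
    return jax.lax.fori_loop(0, steps, body, ks0)

k_true = jnp.array([300.0, 450.0, 350.0])  # hypothetical reference values
ref_freqs = jax.vmap(diatomic_frequency)(k_true)
k_fit = optimize(jnp.array([250.0, 400.0, 400.0]), ref_freqs)
print("fitted force constants:", k_fit)
```

A production version would presumably use a quasi-Newton method with a line search and a convergence test rather than a fixed number of fixed-size steps.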
Additional Considerations
- Molecule padding overhead -- vmap requires uniform array shapes. Padding small molecules (5 atoms) to match large ones (62 atoms) wastes compute. Grouping molecules by size or using JAX's dynamic shapes could mitigate this.
- Backward compatibility -- The existing ObjectiveFunction API should continue to work for non-JAX engines (OpenMM, Tinker). The JAX-native path would be an alternative code path activated when using JaxEngine.
- Datacenter GPUs -- Even with all optimizations, an A100 (FP64:FP32 = 1:2, 9.7 TFLOPS FP64) would be far more suitable than the RTX 5090 (1:64) for float64-heavy workloads.
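The size-grouping idea from the padding bullet above can be sketched in plain Python. The helper bucket_by_size and the bucket edges are hypothetical. Each bucket would become its own vmapped call, which costs one compilation per padded shape, and the cost comparison assumes Hessian work grows roughly with (3N)^2 per molecule.

```python
from collections import defaultdict

def bucket_by_size(atom_counts, bucket_edges):
    """Group molecule indices by the smallest padded size that fits them."""
    edges = sorted(bucket_edges)
    buckets = defaultdict(list)
    for idx, n in enumerate(atom_counts):
        padded = next((edge for edge in edges if edge >= n), None)
        if padded is None:
            raise ValueError(f"molecule {idx} with {n} atoms exceeds all bucket edges")
        buckets[padded].append(idx)
    return dict(buckets)

atom_counts = [5, 5, 8, 62, 60]  # hypothetical training set
buckets = bucket_by_size(atom_counts, bucket_edges=[8, 62])

# Rough padded-Hessian cost: (3 * N_padded)^2 entries per molecule.
single_pad_cost = len(atom_counts) * (3 * max(atom_counts)) ** 2
bucketed_cost = sum(len(idxs) * (3 * size) ** 2 for size, idxs in buckets.items())
print(buckets)
print("single padding:", single_pad_cost, "bucketed:", bucketed_cost)
```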
Related Issues
Success Criteria
- vmap (no Python loop)
- jax.grad for frequency objective