The evaluation of a `scipy.interpolate.CubicSpline` is not traceable by JAX (because of an explicit cast to `np.array` somewhere in there).
This in turn makes it impossible to pass `CorrectionWithGrad.evaluate` to `jax.jit` or `jax.vmap` if a `Binning` correction is involved.
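One possible workaround, sketched here under assumptions: build the `CubicSpline` once with SciPy on concrete NumPy arrays, outside of any traced code, then export its knots (`spline.x`) and piecewise-polynomial coefficients (`spline.c`, shape `(4, n_intervals)`) and evaluate them with `jax.numpy` only. The helper `eval_cubic_spline` and the example `edges`/`values` are hypothetical and not part of this project's API. Points outside the knot range use the first/last polynomial piece, which mirrors `CubicSpline`'s default `extrapolate=True`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.interpolate import CubicSpline

# Hypothetical example data: build the spline once with SciPy on concrete arrays.
edges = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
values = np.array([1.0, 2.0, 0.5, 1.5, 3.0])
spline = CubicSpline(edges, values)

# Export knots and coefficients as JAX arrays (float32 by default in JAX).
# Row m of spline.c multiplies (x - x_i)**(3 - m) on interval i.
knots = jnp.asarray(spline.x)
coeffs = jnp.asarray(spline.c)


def eval_cubic_spline(x, knots, coeffs):
    """Evaluate the piecewise cubic using only jnp ops, so jit/vmap/grad can trace it."""
    # Interval index containing x, clamped so out-of-range points use the end pieces.
    i = jnp.clip(jnp.searchsorted(knots, x, side="right") - 1, 0, knots.shape[0] - 2)
    dx = x - knots[i]
    c = coeffs[:, i]
    # Horner's scheme for c0*dx**3 + c1*dx**2 + c2*dx + c3.
    return ((c[0] * dx + c[1]) * dx + c[2]) * dx + c[3]


# Usage: the evaluation now composes with jax.jit, jax.vmap and jax.grad.
batched = jax.jit(jax.vmap(lambda x: eval_cubic_spline(x, knots, coeffs)))
print(batched(jnp.linspace(0.0, 4.0, 9)))
print(jax.grad(eval_cubic_spline)(1.5, knots, coeffs))
```

The gradient flows only through `dx`; the integer bin index carries no gradient, which is fine because the spline itself is smooth across knots.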
- For simple `Binning` (1D histos with scalar bin contents), see below.
- For compound `Binning` (1D histos with `Formula`s or `FormulaRef`s as bin contents), there is the additional problem that JAX cannot trace through the bin look-up, and I'm not sure how to fix this.
- For `MultiBinning` (ND histos), I don't know of a JAX-friendly implementation of a differentiable relaxation of the bin look-up; we might have to come up with one. `MultiBinning` is not supported at the moment anyway; this is tracked in #15 (Add support for MultiBinning).
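For the compound case, one possible direction (a sketch, not a confirmed fix) is to compute the bin index with `jnp.searchsorted` and dispatch to the per-bin formula with `jax.lax.switch`, which accepts a traced index where Python-level indexing into a list of formulas does not. The three formulas, the `pt`/`eta` inputs and the clamping of under/overflow into the edge bins are all hypothetical. Note that this only makes the look-up traceable: the result is still piecewise in `pt`, so it is not a differentiable relaxation with respect to the binning variable.

```python
import jax
import jax.numpy as jnp

# Hypothetical compound Binning: 3 bins in `pt`, each bin content is a formula of `eta`.
edges = jnp.array([20.0, 30.0, 50.0, 100.0])
bin_formulas = (
    lambda eta: 1.0 + 0.1 * eta,
    lambda eta: 0.9 + 0.05 * eta**2,
    lambda eta: jnp.exp(-0.2 * eta),
)


def eval_compound_binning(pt, eta):
    # Traced bin index; under/overflow is clamped into the edge bins (assumed behaviour).
    idx = jnp.clip(jnp.searchsorted(edges, pt, side="right") - 1, 0, len(bin_formulas) - 1)
    # lax.switch selects the branch using a traced integer index.
    return jax.lax.switch(idx, bin_formulas, eta)


# Usage: jit + vmap over events, and gradients with respect to the formula input.
f = jax.jit(jax.vmap(eval_compound_binning))
print(f(jnp.array([25.0, 40.0, 70.0]), jnp.array([1.0, 2.0, 1.0])))
print(jax.grad(eval_compound_binning, argnums=1)(jnp.float32(40.0), jnp.float32(2.0)))
```

Under `jax.vmap` with a batched index, `lax.switch` is lowered to evaluating every branch and selecting per element, so the cost grows with the number of bins.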