
Make evaluation of Binning corrections JAX-traceable #42

@eguiraud


The evaluation of a scipy.interpolate.CubicSpline is not traceable by JAX, because somewhere in its implementation the input is explicitly cast to a NumPy array (np.array).

This in turn makes it impossible to pass CorrectionWithGrad.evaluate to jax.jit or jax.vmap if a Binning correction is involved.
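A minimal reproduction of the tracing failure might look like the sketch below. It calls jax.jit on a plain SciPy spline as a stand-in for CorrectionWithGrad.evaluate, which is not reproduced here. The spline data and the helper is_jax_traceable are made up for illustration, and the sketch assumes the error surfaces as jax.errors.TracerArrayConversionError, which is what JAX raises when a tracer is converted to a NumPy array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.interpolate import CubicSpline

# Illustrative spline through sin(x) on 11 nodes (made-up data).
x_nodes = np.linspace(0.0, 10.0, 11)
spline = CubicSpline(x_nodes, np.sin(x_nodes))


def evaluate(x):
    # Works when called eagerly with concrete values.
    return spline(x)


def is_jax_traceable(fn, x):
    """Hypothetical helper: True if jax.jit can trace fn at input x."""
    try:
        jax.jit(fn)(x)
    except jax.errors.TracerArrayConversionError:
        # SciPy converts its input to a NumPy array, which a tracer refuses.
        return False
    return True


print(evaluate(2.5))  # eager evaluation works
print(is_jax_traceable(evaluate, jnp.asarray(2.5)))
print(is_jax_traceable(jnp.sin, jnp.asarray(2.5)))
```

With a SciPy version that converts the input this way, the spline should be reported as not traceable, while a pure jax.numpy function such as jnp.sin is.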

  • for simple Binning (1D histos with scalar bin contents), see below.
  • for compound Binning (1D histos with Formulas or FormulaRefs as bin contents), there is the additional problem that JAX cannot trace through the bin look-up, and I'm not sure how to fix this.
  • for MultiBinning (ND histos), I don't know of a JAX-friendly, differentiable relaxation of the bin look-up, so we might have to come up with one. MultiBinning is not supported at the moment anyway; that is tracked in #15 (Add support for MultiBinning).
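For the simple Binning case, one possible workaround, sketched under the assumption that the spline nodes are fixed when the correction is loaded, is to fit the spline with SciPy once, outside any JAX transformation, and then evaluate the resulting piecewise cubic with jax.numpy only. make_traceable_spline is a hypothetical helper, not part of correctionlib; it relies on CubicSpline exposing its breakpoints as .x and its per-interval coefficients as .c (shape (4, n - 1), highest power first).

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.interpolate import CubicSpline

# Double precision so the JAX result matches SciPy's float64 evaluation
# (note: this is a global JAX setting).
jax.config.update("jax_enable_x64", True)


def make_traceable_spline(x_nodes, y_nodes):
    """Hypothetical helper: fit with SciPy once, evaluate with jnp only."""
    spline = CubicSpline(x_nodes, y_nodes)
    breaks = jnp.asarray(spline.x)  # shape (n,)
    coeffs = jnp.asarray(spline.c)  # shape (4, n - 1), highest power first
    n_pieces = breaks.size - 1

    def evaluate(x):
        # Index of the polynomial piece containing x; points outside the
        # node range use the first/last piece, like SciPy's default
        # extrapolation.
        i = jnp.clip(jnp.searchsorted(breaks, x, side="right") - 1, 0, n_pieces - 1)
        dx = x - breaks[i]
        c = coeffs[:, i]
        # Horner's scheme for c0*dx^3 + c1*dx^2 + c2*dx + c3.
        return ((c[0] * dx + c[1]) * dx + c[2]) * dx + c[3]

    return spline, evaluate


nodes = np.linspace(0.0, 10.0, 11)
spline, f = make_traceable_spline(nodes, np.sin(nodes))
xs = jnp.linspace(-1.0, 11.0, 7)
print(jax.jit(jax.vmap(f))(xs))  # traceable under jit and vmap
print(jax.grad(f)(2.5), spline(2.5, 1))  # JAX gradient vs SciPy derivative
```

Only the coefficients come from SciPy, so gradients with respect to the input x flow through JAX; gradients with respect to the node values would still require fitting the spline in JAX.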
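For the compound Binning case, one way to make the look-up traceable, sketched under the assumption that each bin's Formula can be expressed as a jax.numpy callable, is to compute the bin index with jnp.searchsorted and dispatch with jax.lax.switch instead of branching in Python. EDGES, BIN_FORMULAS, bin_index and the clamping of out-of-range inputs are hypothetical choices for illustration, not correctionlib's data model.

```python
import jax
import jax.numpy as jnp

# Hypothetical compound Binning: bin edges plus one formula per bin.
# In correctionlib these would be Formula/FormulaRef nodes; here they are
# plain jnp callables (an assumption for illustration).
EDGES = jnp.array([0.0, 1.0, 2.0, 4.0])
BIN_FORMULAS = (
    lambda x: 1.0 + 0.5 * x,
    lambda x: jnp.exp(-x),
    lambda x: 2.0 * x**2,
)


def bin_index(x):
    # Traceable bin look-up: searchsorted instead of a Python if-chain.
    # Out-of-range inputs are clamped to the first/last bin (an assumption).
    i = jnp.searchsorted(EDGES, x, side="right") - 1
    return jnp.clip(i, 0, len(BIN_FORMULAS) - 1)


def evaluate_compound(x):
    # lax.switch dispatches on a traced integer, so choosing the formula
    # no longer requires a concrete Python value.
    return jax.lax.switch(bin_index(x), BIN_FORMULAS, x)


print(jax.jit(jax.vmap(evaluate_compound))(jnp.array([0.5, 1.5, 3.0])))
print(jax.grad(evaluate_compound)(3.0))  # derivative of 2*x**2 at 3.0
```

Under jax.vmap, lax.switch with a batched index evaluates every branch and selects the results, so the cost grows with the number of bins; for 1D histograms with a modest number of bins that is usually acceptable.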
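For MultiBinning, one candidate relaxation (an assumption, not something the issue settles on) replaces the hard per-axis bin indicator with a difference of sigmoids, so the ND look-up becomes a smooth weighted sum of bin contents. soft_bin_weights, soft_multibinning and the temperature parameter are hypothetical names; as the temperature goes to zero, the result approaches the hard look-up away from bin edges.

```python
import jax
import jax.numpy as jnp


def soft_bin_weights(x, edges, temperature):
    """Smooth stand-in for the one-hot bin-membership vector along one axis.

    weight_i = sigmoid((x - lo_i) / T) - sigmoid((x - hi_i) / T), which tends
    to the indicator of lo_i <= x < hi_i as T -> 0.
    """
    lo, hi = edges[:-1], edges[1:]
    return jax.nn.sigmoid((x - lo) / temperature) - jax.nn.sigmoid((x - hi) / temperature)


def soft_multibinning(point, edges_per_axis, contents, temperature=0.05):
    """Hypothetical differentiable relaxation of an ND bin look-up.

    point: shape (ndim,); edges_per_axis: sequence of 1D edge arrays;
    contents: shape (nbins_0, nbins_1, ...).
    """
    result = contents
    # Contract the leading axis of `result` with each input axis's soft
    # weights in turn; this equals summing contents times the outer
    # product of the per-axis weights.
    for axis, edges in enumerate(edges_per_axis):
        w = soft_bin_weights(point[axis], edges, temperature)
        result = jnp.tensordot(w, result, axes=1)
    return result


# Illustrative 2 x 3 histogram (made-up numbers).
edges = (jnp.array([0.0, 1.0, 2.0]), jnp.array([0.0, 1.0, 2.0, 3.0]))
contents = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

lookup = jax.jit(jax.vmap(lambda p: soft_multibinning(p, edges, contents, 0.01)))
print(lookup(jnp.array([[0.5, 1.5], [1.5, 2.5]])))  # close to [2.0, 6.0]
print(jax.grad(lambda p: soft_multibinning(p, edges, contents))(jnp.array([1.0, 1.5])))
```

Outside the outermost edges all weights tend to zero, so overflow handling would need separate treatment; the temperature trades fidelity to the hard look-up against how smooth the gradients are near bin edges.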
