From d6bf4e8b2fbca599a2114cd760719b7f9a246611 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 01:12:57 +0000 Subject: [PATCH 01/28] fix: jax 0.8 compatibility --- CMakeLists.txt | 21 +- pyproject.toml | 12 +- python/celerite2/jax/ops.py | 42 +- python/celerite2/jax/xla_ops.cpp | 1364 +++++++++++++++---------- python/celerite2/pymc/ops.py | 6 +- python/celerite2/pymc3/ops.py | 6 +- python/spec/templates/jax/xla_ops.cpp | 297 +++--- 7 files changed, 1032 insertions(+), 716 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a4c9f09..91fb842 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,7 @@ project( set(PYBIND11_NEWPYTHON ON) find_package(pybind11 CONFIG REQUIRED) +find_package(Python3 COMPONENTS Interpreter REQUIRED) include_directories( "c++/include" @@ -20,6 +21,20 @@ pybind11_add_module(backprop "python/celerite2/backprop.cpp") target_compile_features(backprop PUBLIC cxx_std_14) install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) -pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp") -target_compile_features(xla_ops PUBLIC cxx_std_14) -install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax") +option(BUILD_JAX "Build JAX extension (requires jaxlib headers)" ON) +if(BUILD_JAX) + execute_process( + COMMAND ${Python3_EXECUTABLE} -c "import sys; import jaxlib; sys.stdout.write(jaxlib.get_include())" + OUTPUT_VARIABLE JAX_INCLUDE + RESULT_VARIABLE JAXLIB_RES + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(JAXLIB_RES EQUAL 0 AND NOT "${JAX_INCLUDE}" STREQUAL "") + message(STATUS "Building JAX extension with jaxlib include: ${JAX_INCLUDE}") + pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp") + target_compile_features(xla_ops PUBLIC cxx_std_17) + target_include_directories(xla_ops PRIVATE "${JAX_INCLUDE}") + install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax") + else() + message(STATUS "Skipping JAX extension (jaxlib not found)") + endif() +endif() diff --git a/pyproject.toml b/pyproject.toml index dc9ea68..648a279 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "celerite2" description = "Fast and scalable Gaussian Processes in 1D" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" license = { text = "MIT License" } classifiers = [ "Development Status :: 4 - Beta", @@ -18,10 +18,8 @@ dependencies = ["numpy"] [project.optional-dependencies] test = ["pytest", "scipy", "celerite"] -pymc3 = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"] -theano = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"] -pymc = ["pymc>=5.9.2"] -jax = ["jax"] +pymc = ["pymc>=5.26.1"] +jax = ["jax>=0.8.0,<0.9.0"] docs = [ "sphinx", "sphinx-material", @@ -31,11 +29,11 @@ docs = [ "matplotlib", "scipy", "emcee", - "pymc>=5", + "pymc>=5.26.1", "tqdm", "numpyro", ] -tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"] +tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] [project.urls] "Homepage" = "https://celerite2.readthedocs.io" diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index ffbc5f8..b869db9 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -23,17 +23,16 @@ from itertools import chain import numpy as np -import pkg_resources -from jax import core, lax +import importlib.resources as resources +from jax import core, lax, ffi from jax import numpy as jnp from jax.core import ShapedArray from jax.interpreters import ad, mlir, xla from jax.lib import xla_client +from jax._src.core import Primitive # Public Primitive was removed in JAX 0.8 from celerite2.jax import xla_ops -xops = xla_client.ops - def factor(t, c, a, U, V): d, W, S = factor_p.bind(t, c, a, U, V) @@ -103,20 +102,7 @@ def _abstract_eval(spec, *args): def _lowering_rule(name, spec, ctx: mlir.LoweringRuleContext, *args): if any(a.dtype != np.float64 for a in chain(ctx.avals_in, ctx.avals_out)): raise ValueError(f"{spec['name']} requires float64 precision") - shapes = [a.shape for a in ctx.avals_in] - dims = OrderedDict( - (s["name"], shapes[s["coords"][0]][s["coords"][1]]) - for s in spec["dimensions"] - ) - return mlir.custom_call( - name, - operands=tuple(mlir.ir_constant(np.int32(v)) for v in dims.values()) - + args, - operand_layouts=[()] * len(dims) - + _default_layouts(aval.shape for aval in ctx.avals_in), - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - result_layouts=_default_layouts(aval.shape for aval in ctx.avals_out), - ).results + return ffi.ffi_lowering(name)(ctx, *args) def _default_layouts(shapes): @@ -184,11 +170,14 @@ def _rev_lowering_rule(name, spec, ctx, *args): def _build_op(name, spec): - xla_client.register_custom_call_target( - name, getattr(xla_ops, spec["name"])(), platform="cpu" + ffi.register_ffi_target( + name, + getattr(xla_ops, spec["name"])(), + platform="cpu", + api_version=1, ) - prim = core.Primitive(f"celerite2_{spec['name']}") + prim = Primitive(f"celerite2_{spec['name']}") prim.multiple_results = True prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(partial(_abstract_eval, spec)) @@ -199,15 +188,16 @@ def _build_op(name, spec): if not spec["has_rev"]: return prim, None - xla_client.register_custom_call_target( + ffi.register_ffi_target( name + "_rev", getattr(xla_ops, f"{spec['name']}_rev")(), platform="cpu", + api_version=1, ) - jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp") + jvp_prim = Primitive(f"celerite2_{spec['name']}_jvp") jvp_prim.multiple_results = True - rev_prim = core.Primitive(f"celerite2_{spec['name']}_rev") + rev_prim = Primitive(f"celerite2_{spec['name']}_rev") rev_prim.multiple_results = True # Setup a dummy JVP rule @@ -227,9 +217,7 @@ def _build_op(name, spec): return prim, rev_prim -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with resources.files("celerite2").joinpath("definitions.json").open("r") as f: definitions = {spec["name"]: spec for spec in json.load(f)} diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index 73a2a55..0b57fd6 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -1,579 +1,881 @@ // NOTE: This file was autogenerated // NOTE: Changes should be made to the template +// Generated JAX FFI bindings for celerite2. +// Regenerate with: python python/spec/generate.py + #include -#include -#include +#include #include -#include + +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/api/api.h" + #include "../driver.hpp" namespace py = pybind11; +namespace ffi = xla::ffi; using namespace celerite2::driver; - -auto factor (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - - const double *t = reinterpret_cast(in[2]); - const double *c = reinterpret_cast(in[3]); - const double *a = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - double *d = reinterpret_cast(out[0]); - double *W = reinterpret_cast(out[1]); - double *S = reinterpret_cast(out[2]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map a_(a, N, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map d_(d, N, 1); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map::value>> S_(S, N, J * J); \ - Eigen::Index flag = celerite2::core::factor(t_, c_, a_, U_, V_, d_, W_, S_); \ - if (flag) d_.setZero(); \ - } - UNWRAP_CASES_MOST -#undef FIXED_SIZE_MAP +// Helpers +template +inline Eigen::Index dim0(const Buffer& buf) { + return static_cast(buf.dimensions()[0]); } -auto factor_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - - const double *t = reinterpret_cast(in[2]); - const double *c = reinterpret_cast(in[3]); - const double *a = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *d = reinterpret_cast(in[7]); - const double *W = reinterpret_cast(in[8]); - const double *S = reinterpret_cast(in[9]); - const double *bd = reinterpret_cast(in[10]); - const double *bW = reinterpret_cast(in[11]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *ba = reinterpret_cast(out[2]); - double *bU = reinterpret_cast(out[3]); - double *bV = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map a_(a, N, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map d_(d, N, 1); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map::value>> S_(S, N, J * J); \ - Eigen::Map bd_(bd, N, 1); \ - Eigen::Map::value>> bW_(bW, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map ba_(ba, N, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - celerite2::core::factor_rev(t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_, bc_, ba_, bU_, bV_); \ - } - UNWRAP_CASES_FEW -#undef FIXED_SIZE_MAP +template +inline Eigen::Index dim1(const Buffer& buf) { + return static_cast(buf.dimensions()[1]); } +template +inline Eigen::Index flat_cols(const Buffer& buf) { + const auto& dims = buf.dimensions(); + Eigen::Index cols = 1; + for (size_t i = 1; i < dims.size(); ++i) + cols *= static_cast(dims[i]); + return cols; +} + +inline ffi::Error shape_error(const char* msg) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, std::string(msg)); +} + +// === AUTO-GENERATED KERNELS === + + +ffi::Error FactorImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer a, + ffi::Buffer U, + ffi::Buffer V, + ffi::ResultBuffer d, + ffi::ResultBuffer W, + ffi::ResultBuffer S +) { -auto solve_lower (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST + if (dim0(t) != N) return shape_error("factor shape mismatch"); + + if (dim0(c) != J) return shape_error("factor shape mismatch"); + + if (dim0(a) != N) return shape_error("factor shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("factor shape mismatch"); + + if (dim0(V) != N || dim1(V) != J) return shape_error("factor shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map a_(a.typed_data(), N, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map d_(d->typed_data(), N, 1); \ + Eigen::Map::value>> W_(W->typed_data(), N, J); \ + Eigen::Map::value>> S_(S->typed_data(), N, J * J); \ + d_.setZero(); \ + W_.setZero(); \ + S_.setZero(); \ + celerite2::core::factor( t_, c_, a_, U_, V_, d_,W_,S_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto solve_lower_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bW = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bW_(bW, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } \ - } - UNWRAP_CASES_FEW + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + factor, FactorImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // a + .Arg>() // U + .Arg>() // V + .Ret>() // d + .Ret>() // W + .Ret>() // S +); + + +ffi::Error factor_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer a, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer d, + ffi::Buffer W, + ffi::Buffer S, + ffi::Buffer bd, + ffi::Buffer bW, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer ba, + ffi::ResultBuffer bU, + ffi::ResultBuffer bV +) { + + + if (dim0(t) != N) return shape_error("factor_rev shape mismatch"); + + if (dim0(c) != J) return shape_error("factor_rev shape mismatch"); + + if (dim0(a) != N) return shape_error("factor_rev shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("factor_rev shape mismatch"); + + if (dim0(V) != N || dim1(V) != J) return shape_error("factor_rev shape mismatch"); + + if (dim0(d) != N) return shape_error("factor_rev shape mismatch"); + + if (dim0(W) != N || dim1(W) != J) return shape_error("factor_rev shape mismatch"); + + + if (dim0(bd) != N) return shape_error("factor_rev shape mismatch"); + + if (dim0(bW) != N || dim1(bW) != J) return shape_error("factor_rev shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map a_(a.typed_data(), N, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map d_(d.typed_data(), N, 1); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> S_(S.typed_data(), N, dim1(S)); \ + Eigen::Map bd_(bd.typed_data(), N, 1); \ + Eigen::Map::value>> bW_(bW.typed_data(), N, J); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map ba_(ba->typed_data(), N, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ + bt_.setZero(); \ + bc_.setZero(); \ + ba_.setZero(); \ + bU_.setZero(); \ + bV_.setZero(); \ + celerite2::core::factor_rev( t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_,bc_,ba_,bU_,bV_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP -} + return ffi::Error::Success(); +} -auto solve_upper (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST +XLA_FFI_DEFINE_HANDLER_SYMBOL( + factor_rev, factor_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // a + .Arg>() // U + .Arg>() // V + .Arg>() // d + .Arg>() // W + .Arg>() // S + .Arg>() // bd + .Arg>() // bW + .Ret>() // bt + .Ret>() // bc + .Ret>() // ba + .Ret>() // bU + .Ret>() // bV +); + + + +ffi::Error Solve_lowerImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + if (dim0(t) != N) return shape_error("solve_lower shape mismatch"); + + if (dim0(c) != J) return shape_error("solve_lower shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("solve_lower shape mismatch"); + + if (dim0(W) != N || dim1(W) != J) return shape_error("solve_lower shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_lower shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Z_.setZero(); \ + F_.setZero(); \ + celerite2::core::solve_lower( t_, c_, U_, W_, Y_, Z_,F_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto solve_upper_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bW = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bW_(bW, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } \ - } - UNWRAP_CASES_FEW + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_lower, Solve_lowerImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error solve_lower_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bW, + ffi::ResultBuffer bY +) { + + + if (dim0(t) != N) return shape_error("solve_lower_rev shape mismatch"); + + if (dim0(c) != J) return shape_error("solve_lower_rev shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("solve_lower_rev shape mismatch"); + + if (dim0(W) != N || dim1(W) != J) return shape_error("solve_lower_rev shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_lower_rev shape mismatch"); + + if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("solve_lower_rev shape mismatch"); + + + if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("solve_lower_rev shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ + Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bW_(bW->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bW_.setZero(); \ + bY_.setZero(); \ + celerite2::core::solve_lower_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP -} + return ffi::Error::Success(); +} -auto matmul_lower (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_lower_rev, solve_lower_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bW + .Ret>() // bY +); + + + +ffi::Error Solve_upperImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + if (dim0(t) != N) return shape_error("solve_upper shape mismatch"); + + if (dim0(c) != J) return shape_error("solve_upper shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("solve_upper shape mismatch"); + + if (dim0(W) != N || dim1(W) != J) return shape_error("solve_upper shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_upper shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Z_.setZero(); \ + F_.setZero(); \ + celerite2::core::solve_upper( t_, c_, U_, W_, Y_, Z_,F_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto matmul_lower_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bV = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } \ - } - UNWRAP_CASES_FEW + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_upper, Solve_upperImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error solve_upper_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bW, + ffi::ResultBuffer bY +) { + + + if (dim0(t) != N) return shape_error("solve_upper_rev shape mismatch"); + + if (dim0(c) != J) return shape_error("solve_upper_rev shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("solve_upper_rev shape mismatch"); + + if (dim0(W) != N || dim1(W) != J) return shape_error("solve_upper_rev shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_upper_rev shape mismatch"); + + if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("solve_upper_rev shape mismatch"); + + + if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("solve_upper_rev shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ + Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bW_(bW->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bW_.setZero(); \ + bY_.setZero(); \ + celerite2::core::solve_upper_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP -} + return ffi::Error::Success(); +} -auto matmul_upper (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_upper_rev, solve_upper_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bW + .Ret>() // bY +); + + + +ffi::Error Matmul_lowerImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + if (dim0(t) != N) return shape_error("matmul_lower shape mismatch"); + + if (dim0(c) != J) return shape_error("matmul_lower shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_lower shape mismatch"); + + if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_lower shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_lower shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Z_.setZero(); \ + F_.setZero(); \ + celerite2::core::matmul_lower( t_, c_, U_, V_, Y_, Z_,F_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto matmul_upper_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bV = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } \ - } - UNWRAP_CASES_FEW + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_lower, Matmul_lowerImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error matmul_lower_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bV, + ffi::ResultBuffer bY +) { + + + if (dim0(t) != N) return shape_error("matmul_lower_rev shape mismatch"); + + if (dim0(c) != J) return shape_error("matmul_lower_rev shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_lower_rev shape mismatch"); + + if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_lower_rev shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_lower_rev shape mismatch"); + + if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("matmul_lower_rev shape mismatch"); + + + if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("matmul_lower_rev shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ + Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bV_.setZero(); \ + bY_.setZero(); \ + celerite2::core::matmul_lower_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP -} + return ffi::Error::Success(); +} -auto general_matmul_lower (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int M = *reinterpret_cast(in[1]); - const int J = *reinterpret_cast(in[2]); - const int nrhs = *reinterpret_cast(in[3]); - - const double *t1 = reinterpret_cast(in[4]); - const double *t2 = reinterpret_cast(in[5]); - const double *c = reinterpret_cast(in[6]); - const double *U = reinterpret_cast(in[7]); - const double *V = reinterpret_cast(in[8]); - const double *Y = reinterpret_cast(in[9]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_(t1, N, 1); \ - Eigen::Map t2_(t2, M, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, M, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, M, J); \ - Z_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, M, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, M, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_lower_rev, matmul_lower_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bV + .Ret>() // bY +); + + + +ffi::Error Matmul_upperImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + if (dim0(t) != N) return shape_error("matmul_upper shape mismatch"); + + if (dim0(c) != J) return shape_error("matmul_upper shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_upper shape mismatch"); + + if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_upper shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_upper shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Z_.setZero(); \ + F_.setZero(); \ + celerite2::core::matmul_upper( t_, c_, U_, V_, Y_, Z_,F_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto general_matmul_upper (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int M = *reinterpret_cast(in[1]); - const int J = *reinterpret_cast(in[2]); - const int nrhs = *reinterpret_cast(in[3]); - - const double *t1 = reinterpret_cast(in[4]); - const double *t2 = reinterpret_cast(in[5]); - const double *c = reinterpret_cast(in[6]); - const double *U = reinterpret_cast(in[7]); - const double *V = reinterpret_cast(in[8]); - const double *Y = reinterpret_cast(in[9]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); - -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_(t1, N, 1); \ - Eigen::Map t2_(t2, M, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, M, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, M, J); \ - Z_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, M, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, M, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } \ - } - UNWRAP_CASES_MOST +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_upper, Matmul_upperImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error matmul_upper_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bV, + ffi::ResultBuffer bY +) { + + + if (dim0(t) != N) return shape_error("matmul_upper_rev shape mismatch"); + + if (dim0(c) != J) return shape_error("matmul_upper_rev shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_upper_rev shape mismatch"); + + if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_upper_rev shape mismatch"); + + if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_upper_rev shape mismatch"); + + if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("matmul_upper_rev shape mismatch"); + + + if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("matmul_upper_rev shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ + Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ + Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bV_.setZero(); \ + bY_.setZero(); \ + celerite2::core::matmul_upper_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_upper_rev, matmul_upper_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bV + .Ret>() // bY +); + + + +ffi::Error General_matmul_lowerImpl( + ffi::Buffer t1, + ffi::Buffer t2, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + if (dim0(t1) != N) return shape_error("general_matmul_lower shape mismatch"); + + if (dim0(t2) != M) return shape_error("general_matmul_lower shape mismatch"); + + if (dim0(c) != J) return shape_error("general_matmul_lower shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("general_matmul_lower shape mismatch"); + + if (dim0(V) != M || dim1(V) != J) return shape_error("general_matmul_lower shape mismatch"); + + if (dim0(Y) != M || dim1(Y) != nrhs) return shape_error("general_matmul_lower shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t1_(t1.typed_data(), N, 1); \ + Eigen::Map t2_(t2.typed_data(), M, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), M, J); \ + Eigen::Map> Y_(Y.typed_data(), M, dim1(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F->typed_data(), M, flat_cols(F)); \ + Z_.setZero(); \ + F_.setZero(); \ + celerite2::core::general_matmul_lower( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + } + UNWRAP_CASES_MOST +#undef FIXED_SIZE_MAP + + return ffi::Error::Success(); +} -// https://en.cppreference.com/w/cpp/numeric/bit_cast -template -typename std::enable_if::value && std::is_trivially_copyable::value, To>::type -bit_cast(const From &src) noexcept { - static_assert(std::is_trivially_constructible::value, - "This implementation additionally requires destination type to be trivially constructible"); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + general_matmul_lower, General_matmul_lowerImpl, + ffi::Ffi::Bind() + .Arg>() // t1 + .Arg>() // t2 + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + + +ffi::Error General_matmul_upperImpl( + ffi::Buffer t1, + ffi::Buffer t2, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + if (dim0(t1) != N) return shape_error("general_matmul_upper shape mismatch"); + + if (dim0(t2) != M) return shape_error("general_matmul_upper shape mismatch"); + + if (dim0(c) != J) return shape_error("general_matmul_upper shape mismatch"); + + if (dim0(U) != N || dim1(U) != J) return shape_error("general_matmul_upper shape mismatch"); + + if (dim0(V) != M || dim1(V) != J) return shape_error("general_matmul_upper shape mismatch"); + + if (dim0(Y) != M || dim1(Y) != nrhs) return shape_error("general_matmul_upper shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t1_(t1.typed_data(), N, 1); \ + Eigen::Map t2_(t2.typed_data(), M, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), M, J); \ + Eigen::Map> Y_(Y.typed_data(), M, dim1(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ + Eigen::Map> F_(F->typed_data(), M, flat_cols(F)); \ + Z_.setZero(); \ + F_.setZero(); \ + celerite2::core::general_matmul_upper( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + } + UNWRAP_CASES_MOST +#undef FIXED_SIZE_MAP - To dst; - memcpy(&dst, &src, sizeof(To)); - return dst; + return ffi::Error::Success(); } -template -py::capsule encapsulate_function(T* fn) { - return py::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + general_matmul_upper, General_matmul_upperImpl, + ffi::Ffi::Bind() + .Arg>() // t1 + .Arg>() // t2 + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + + +// Pybind -------------------------------------------------------------------- +template +py::capsule Encapsulate() { + return py::capsule(reinterpret_cast(Fn), "xla._CUSTOM_CALL_TARGET"); } PYBIND11_MODULE(xla_ops, m) { - m.def("factor", []() { - return encapsulate_function(factor); - }); - m.def("factor_rev", []() { - return encapsulate_function(factor_rev); - }); - m.def("solve_lower", []() { - return encapsulate_function(solve_lower); - }); - m.def("solve_lower_rev", []() { - return encapsulate_function(solve_lower_rev); - }); - m.def("solve_upper", []() { - return encapsulate_function(solve_upper); - }); - m.def("solve_upper_rev", []() { - return encapsulate_function(solve_upper_rev); - }); - m.def("matmul_lower", []() { - return encapsulate_function(matmul_lower); - }); - m.def("matmul_lower_rev", []() { - return encapsulate_function(matmul_lower_rev); - }); - m.def("matmul_upper", []() { - return encapsulate_function(matmul_upper); - }); - m.def("matmul_upper_rev", []() { - return encapsulate_function(matmul_upper_rev); - }); - m.def("general_matmul_lower", []() { - return encapsulate_function(general_matmul_lower); - }); - m.def("general_matmul_upper", []() { - return encapsulate_function(general_matmul_upper); - }); + + m.def("factor", &Encapsulate); + m.def("factor_rev", &Encapsulate); + + m.def("solve_lower", &Encapsulate); + m.def("solve_lower_rev", &Encapsulate); + + m.def("solve_upper", &Encapsulate); + m.def("solve_upper_rev", &Encapsulate); + + m.def("matmul_lower", &Encapsulate); + m.def("matmul_lower_rev", &Encapsulate); + + m.def("matmul_upper", &Encapsulate); + m.def("matmul_upper_rev", &Encapsulate); + + m.def("general_matmul_lower", &Encapsulate); + + m.def("general_matmul_upper", &Encapsulate); + } diff --git a/python/celerite2/pymc/ops.py b/python/celerite2/pymc/ops.py index 2cceef5..de1300a 100644 --- a/python/celerite2/pymc/ops.py +++ b/python/celerite2/pymc/ops.py @@ -15,7 +15,7 @@ from itertools import chain import numpy as np -import pkg_resources +import importlib.resources as resources import pytensor import pytensor.tensor as pt from pytensor.graph import basic, op @@ -140,9 +140,7 @@ def grad(self, inputs, gradients): return self.rev_op(*chain(inputs, outputs, grads)) -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with resources.files("celerite2").joinpath("definitions.json").open("r") as f: definitions = {spec["name"]: spec for spec in json.load(f)} diff --git a/python/celerite2/pymc3/ops.py b/python/celerite2/pymc3/ops.py index c945858..d1108e4 100644 --- a/python/celerite2/pymc3/ops.py +++ b/python/celerite2/pymc3/ops.py @@ -15,7 +15,7 @@ from itertools import chain import numpy as np -import pkg_resources +import importlib.resources as resources import theano import theano.tensor as tt from theano.graph import basic, op @@ -140,9 +140,7 @@ def grad(self, inputs, gradients): return self.rev_op(*chain(inputs, outputs, grads)) -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with resources.files("celerite2").joinpath("definitions.json").open("r") as f: definitions = {spec["name"]: spec for spec in json.load(f)} diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index bd22248..3d153e1 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -1,169 +1,186 @@ +// Generated JAX FFI bindings for celerite2. +// Regenerate with: python python/spec/generate.py + #include -#include -#include +#include #include -#include + +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/api/api.h" + #include "../driver.hpp" namespace py = pybind11; +namespace ffi = xla::ffi; using namespace celerite2::driver; +// Helpers +template +inline Eigen::Index dim0(const Buffer& buf) { + return static_cast(buf.dimensions()[0]); +} +template +inline Eigen::Index dim1(const Buffer& buf) { + return static_cast(buf.dimensions()[1]); +} +template +inline Eigen::Index flat_cols(const Buffer& buf) { + const auto& dims = buf.dimensions(); + Eigen::Index cols = 1; + for (size_t i = 1; i < dims.size(); ++i) + cols *= static_cast(dims[i]); + return cols; +} + +inline ffi::Error shape_error(const char* msg) { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, std::string(msg)); +} + +// === AUTO-GENERATED KERNELS === {% for mod in spec %} -auto {{mod.name}} (void *out_tuple, const void **in) { - {% set input_index = mod.dimensions|length - 1 -%} - void **out = reinterpret_cast(out_tuple); - {%- for dim in mod.dimensions %} - const int {{dim.name}} = *reinterpret_cast(in[{{loop.index - 1}}]); - {%- endfor %} - {% for arg in mod.inputs %} - const double *{{arg.name}} = reinterpret_cast(in[{{ input_index + loop.index }}]); - {%- endfor %} - {%- for arg in mod.outputs + mod.extra_outputs %} - double *{{arg.name}} = reinterpret_cast(out[{{ loop.index - 1 }}]); - {%- endfor %} -#define FIXED_SIZE_MAP(SIZE) \ - { \ - {%- for arg in mod.inputs + mod.outputs + mod.extra_outputs %} - {%- if arg.shape|length == 1 -%} - {%- if arg.shape[0] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, J, 1); \ +ffi::Error {{mod.name|capitalize}}Impl( +{%- for arg in mod.inputs %} + ffi::Buffer {{arg.name}}{% if not loop.last or mod.outputs or mod.extra_outputs %},{% endif %} +{%- endfor -%} +{%- for arg in mod.outputs %} + ffi::ResultBuffer {{arg.name}}{% if not loop.last or mod.extra_outputs %},{% endif %} +{%- endfor -%} +{%- for arg in mod.extra_outputs %} + ffi::ResultBuffer {{arg.name}}{% if not loop.last %},{% endif %} +{%- endfor %} +) { + {# Minimal shape checks - rely on driver.hpp order helper #} + {% for arg in mod.inputs %} + {%- if arg.shape|length == 1 %} + if (dim0({{arg.name}}) != {{arg.shape[0]}}) return shape_error("{{mod.name}} shape mismatch"); + {%- elif arg.shape|length == 2 %} + if (dim0({{arg.name}}) != {{arg.shape[0]}} || dim1({{arg.name}}) != {{arg.shape[1]}}) return shape_error("{{mod.name}} shape mismatch"); + {%- endif %} + {% endfor %} + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + {%- for arg in mod.inputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J); \ {%- else %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {%- endif -%} - {%- elif arg.shape|length == 2 -%} - {%- if arg.shape[1] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {%- endif -%} - {%- else -%} - {%- if arg.shape[2] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J * J); \ - {%- endif -%} - {%- endif -%} - {% endfor %} - {%- if mod.name == "factor" %} - Eigen::Index flag = celerite2::core::{{mod.name}}({% for val in mod.inputs + mod.outputs + mod.extra_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - if (flag) d_.setZero(); \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ + {%- endif %} + {%- endfor %} + {%- for arg in mod.outputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ {%- else %} - if (nrhs == 1) { \ - {% for arg in mod.inputs + mod.outputs + mod.extra_outputs %} - {%- if arg.shape|length == 2 and arg.shape[1] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {% elif arg.shape|length == 3 and arg.shape[2] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {% endif -%} - {% endfor -%} - {% for arg in mod.outputs -%} - {{arg.name}}_.setZero(); \ - {% endfor -%} - celerite2::core::{{mod.name}}({% for val in mod.inputs + mod.outputs + mod.extra_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } else { \ - {% for arg in mod.inputs + mod.outputs + mod.extra_outputs %} - {%- if (arg.shape|length == 2 and arg.shape[1] == "nrhs") or (arg.shape|length == 3 and arg.shape[2] == "nrhs") -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, {{arg.shape[1:]|join(" * ")}}); \ - {% endif -%} - {% endfor -%} - {% for arg in mod.outputs -%} - {{arg.name}}_.setZero(); \ - {% endfor -%} - celerite2::core::{{mod.name}}({% for val in mod.inputs + mod.outputs + mod.extra_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ {%- endif %} - } - UNWRAP_CASES_MOST -#undef FIXED_SIZE_MAP -} -{%- if mod.has_rev %} -auto {{mod.name}}_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - {%- if mod.name == "factor" -%} - {%- set input_index = 1 -%} + {%- endfor %} + {%- for arg in mod.extra_outputs %} + {%- if arg.shape|length == 3 and arg.shape[1] == "J" and arg.shape[2] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J * J); \ {%- else %} - const int nrhs = *reinterpret_cast(in[2]); - {%- set input_index = 2 -%} + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, flat_cols({{arg.name}})); \ {%- endif %} - {% for arg in mod.rev_inputs %} - const double *{{arg.name}} = reinterpret_cast(in[{{ input_index + loop.index }}]); {%- endfor %} - {%- for arg in mod.rev_outputs %} - double *{{arg.name}} = reinterpret_cast(out[{{ loop.index - 1 }}]); + {%- for arg in mod.outputs + mod.extra_outputs %} + {{arg.name}}_.setZero(); \ {%- endfor %} + celerite2::core::{{mod.name}}({%- for arg in mod.inputs %} {{arg.name}}_{% if not loop.last or mod.outputs or mod.extra_outputs %},{% endif %}{%- endfor %}{%- if mod.outputs or mod.extra_outputs %} {% endif %}{%- for arg in mod.outputs + mod.extra_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + } + UNWRAP_CASES_{{ "FEW" if mod.has_rev else "MOST" }} +#undef FIXED_SIZE_MAP + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + {{mod.name}}, {{mod.name|capitalize}}Impl, + ffi::Ffi::Bind() +{%- for arg in mod.inputs %} + .Arg>() // {{arg.name}} +{%- endfor %} +{%- for arg in mod.outputs + mod.extra_outputs %} + .Ret>() // {{arg.name}} +{%- endfor %} +); +{% if mod.has_rev %} -#define FIXED_SIZE_MAP(SIZE) \ - { \ - {%- for arg in mod.rev_inputs + mod.rev_outputs %} - {%- if arg.shape|length == 1 -%} - {%- if arg.shape[0] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, J, 1); \ +ffi::Error {{mod.name}}_revImpl( +{%- for arg in mod.rev_inputs %} + ffi::Buffer {{arg.name}}{% if not loop.last or mod.rev_outputs %},{% endif %} +{%- endfor -%} +{%- for arg in mod.rev_outputs %} + ffi::ResultBuffer {{arg.name}}{% if not loop.last %},{% endif %} +{%- endfor %} +) { + {# Minimal shape checks #} + {% for arg in mod.rev_inputs %} + {%- if arg.shape|length == 1 %} + if (dim0({{arg.name}}) != {{arg.shape[0]}}) return shape_error("{{mod.name}}_rev shape mismatch"); + {%- elif arg.shape|length == 2 %} + if (dim0({{arg.name}}) != {{arg.shape[0]}} || dim1({{arg.name}}) != {{arg.shape[1]}}) return shape_error("{{mod.name}}_rev shape mismatch"); + {%- endif %} + {% endfor %} + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + {%- for arg in mod.rev_inputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J); \ {%- else %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {%- endif -%} - {%- elif arg.shape|length == 2 -%} - {%- if arg.shape[1] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {%- endif -%} - {%- else -%} - {%- if arg.shape[2] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J * J); \ - {%- endif -%} - {%- endif -%} - {% endfor %} - {%- if mod.name == "factor" %} - celerite2::core::{{mod.name}}_rev({% for val in mod.rev_inputs + mod.rev_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ + {%- endif %} + {%- endfor %} + {%- for arg in mod.rev_outputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ {%- else %} - if (nrhs == 1) { \ - {% for arg in mod.rev_inputs + mod.rev_outputs %} - {%- if arg.shape|length == 2 and arg.shape[1] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {% elif arg.shape|length == 3 and arg.shape[2] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {% endif -%} - {% endfor -%} - celerite2::core::{{mod.name}}_rev({% for val in mod.rev_inputs + mod.rev_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } else { \ - {% for arg in mod.rev_inputs + mod.rev_outputs %} - {%- if (arg.shape|length == 2 and arg.shape[1] == "nrhs") or (arg.shape|length == 3 and arg.shape[2] == "nrhs") -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, {{arg.shape[1:]|join(" * ")}}); \ - {% endif -%} - {% endfor -%} - celerite2::core::{{mod.name}}_rev({% for val in mod.rev_inputs + mod.rev_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ {%- endif %} - } - UNWRAP_CASES_FEW + {%- endfor %} + {%- for arg in mod.rev_outputs %} + {{arg.name}}_.setZero(); \ + {%- endfor %} + celerite2::core::{{mod.name}}_rev({%- for arg in mod.rev_inputs %} {{arg.name}}_{% if not loop.last or mod.rev_outputs %},{% endif %}{%- endfor %}{%- if mod.rev_outputs %} {% endif %}{%- for arg in mod.rev_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + {{mod.name}}_rev, {{mod.name}}_revImpl, + ffi::Ffi::Bind() +{%- for arg in mod.rev_inputs %} + .Arg>() // {{arg.name}} +{%- endfor %} +{%- for arg in mod.rev_outputs %} + .Ret>() // {{arg.name}} +{%- endfor %} +); {% endif %} {% endfor %} -// https://en.cppreference.com/w/cpp/numeric/bit_cast -template -typename std::enable_if::value && std::is_trivially_copyable::value, To>::type -bit_cast(const From &src) noexcept { - static_assert(std::is_trivially_constructible::value, - "This implementation additionally requires destination type to be trivially constructible"); - - To dst; - memcpy(&dst, &src, sizeof(To)); - return dst; -} - -template -py::capsule encapsulate_function(T* fn) { - return py::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +// Pybind -------------------------------------------------------------------- +template +py::capsule Encapsulate() { + return py::capsule(reinterpret_cast(Fn), "xla._CUSTOM_CALL_TARGET"); } PYBIND11_MODULE(xla_ops, m) { - {%- for mod in spec %} - m.def("{{mod.name}}", []() { - return encapsulate_function({{mod.name}}); - }); - {%- if mod.has_rev %} - m.def("{{mod.name}}_rev", []() { - return encapsulate_function({{mod.name}}_rev); - }); - {%- endif %} - {%- endfor %} + {% for mod in spec %} + m.def("{{mod.name}}", &Encapsulate<{{mod.name}}>); + {%- if mod.has_rev %} + m.def("{{mod.name}}_rev", &Encapsulate<{{mod.name}}_rev>); + {%- endif %} + {% endfor %} } From 314ba3e91cc503c51ca20ca2cf3cea7b9111221a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 01:40:19 +0000 Subject: [PATCH 02/28] fix: get pymc working on v5 --- CMakeLists.txt | 2 +- c++/include/celerite2/forward.hpp | 4 +- pyproject.toml | 4 +- python/celerite2/jax/celerite2.py | 4 + python/celerite2/jax/ops.py | 3 +- python/celerite2/jax/xla_ops.cpp | 157 ++++++++++++++++++++------ python/celerite2/pymc/ops.py | 66 +++++++++++ python/spec/generate.py | 5 +- python/spec/templates/jax/xla_ops.cpp | 19 +++- 9 files changed, 217 insertions(+), 47 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 91fb842..90bfb9f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,7 @@ install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) option(BUILD_JAX "Build JAX extension (requires jaxlib headers)" ON) if(BUILD_JAX) execute_process( - COMMAND ${Python3_EXECUTABLE} -c "import sys; import jaxlib; sys.stdout.write(jaxlib.get_include())" + COMMAND ${Python3_EXECUTABLE} -c "import sys, pathlib; import jaxlib; path = pathlib.Path(jaxlib.__file__).parent / 'include'; sys.stdout.write(str(path) if path.exists() else '')" OUTPUT_VARIABLE JAX_INCLUDE RESULT_VARIABLE JAXLIB_RES OUTPUT_STRIP_TRAILING_WHITESPACE) diff --git a/c++/include/celerite2/forward.hpp b/c++/include/celerite2/forward.hpp index b5b9a39..74e5f4e 100644 --- a/c++/include/celerite2/forward.hpp +++ b/c++/include/celerite2/forward.hpp @@ -297,7 +297,7 @@ void general_matmul_lower(const Eigen::MatrixBase &t1, // ( typedef typename LowRank::Scalar Scalar; typedef typename Eigen::internal::plain_col_type::type CoeffVector; - typedef typename Eigen::Matrix Inner; + typedef typename Eigen::Matrix Inner; Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols(); @@ -358,7 +358,7 @@ void general_matmul_upper(const Eigen::MatrixBase &t1, // ( typedef typename LowRank::Scalar Scalar; typedef typename Eigen::internal::plain_col_type::type CoeffVector; - typedef typename Eigen::Matrix Inner; + typedef typename Eigen::Matrix Inner; Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols(); diff --git a/pyproject.toml b/pyproject.toml index 648a279..666adfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = ["numpy"] [project.optional-dependencies] test = ["pytest", "scipy", "celerite"] pymc = ["pymc>=5.26.1"] -jax = ["jax>=0.8.0,<0.9.0"] +jax = ["jax>=0.8.0,<0.9.0", "jaxlib>=0.8.0,<0.9.0"] docs = [ "sphinx", "sphinx-material", @@ -41,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] "Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" [build-system] -requires = ["scikit-build-core", "numpy", "pybind11"] +requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0,<0.9.0"] build-backend = "scikit_build_core.build" [tool.scikit-build] diff --git a/python/celerite2/jax/celerite2.py b/python/celerite2/jax/celerite2.py index b3cb9fc..b652825 100644 --- a/python/celerite2/jax/celerite2.py +++ b/python/celerite2/jax/celerite2.py @@ -35,6 +35,10 @@ def _do_compute(self, quiet): self._t, self._c, self._a, self._U, self._V ) self._log_det = np.sum(np.log(self._d)) + if np.any(self._d <= 0) or not np.isfinite(self._log_det): + self._log_det = -np.inf + self._norm = np.inf + return self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi)) def _check_sorted(self, t): diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index b869db9..f1b04f0 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -21,6 +21,7 @@ from collections import OrderedDict from functools import partial from itertools import chain +import importlib import numpy as np import importlib.resources as resources @@ -31,7 +32,7 @@ from jax.lib import xla_client from jax._src.core import Primitive # Public Primitive was removed in JAX 0.8 -from celerite2.jax import xla_ops +xla_ops = importlib.import_module("celerite2.jax.xla_ops") def factor(t, c, a, U, V): diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index 0b57fd6..18d1e9e 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -54,6 +54,12 @@ ffi::Error FactorImpl( ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + + if (dim0(t) != N) return shape_error("factor shape mismatch"); if (dim0(c) != J) return shape_error("factor shape mismatch"); @@ -118,6 +124,11 @@ ffi::Error factor_revImpl( ffi::ResultBuffer bV ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + if (dim0(t) != N) return shape_error("factor_rev shape mismatch"); @@ -148,7 +159,7 @@ ffi::Error factor_revImpl( Eigen::Map::value>> V_(V.typed_data(), N, J); \ Eigen::Map d_(d.typed_data(), N, 1); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ - Eigen::Map> S_(S.typed_data(), N, dim1(S)); \ + Eigen::Map::value>> S_(S.typed_data(), N, J * J); \ Eigen::Map bd_(bd.typed_data(), N, 1); \ Eigen::Map::value>> bW_(bW.typed_data(), N, J); \ Eigen::Map bt_(bt->typed_data(), N, 1); \ @@ -202,6 +213,14 @@ ffi::Error Solve_lowerImpl( ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + + if (dim0(t) != N) return shape_error("solve_lower shape mismatch"); if (dim0(c) != J) return shape_error("solve_lower shape mismatch"); @@ -220,8 +239,8 @@ ffi::Error Solve_lowerImpl( Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ celerite2::core::solve_lower( t_, c_, U_, W_, Y_, Z_,F_); \ @@ -261,6 +280,13 @@ ffi::Error solve_lower_revImpl( ffi::ResultBuffer bY ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + if (dim0(t) != N) return shape_error("solve_lower_rev shape mismatch"); @@ -284,15 +310,15 @@ ffi::Error solve_lower_revImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ - Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ Eigen::Map bt_(bt->typed_data(), N, 1); \ Eigen::Map bc_(bc->typed_data(), J, 1); \ Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ Eigen::Map::value>> bW_(bW->typed_data(), N, J); \ - Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ bt_.setZero(); \ bc_.setZero(); \ bU_.setZero(); \ @@ -337,6 +363,14 @@ ffi::Error Solve_upperImpl( ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + + if (dim0(t) != N) return shape_error("solve_upper shape mismatch"); if (dim0(c) != J) return shape_error("solve_upper shape mismatch"); @@ -355,8 +389,8 @@ ffi::Error Solve_upperImpl( Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ celerite2::core::solve_upper( t_, c_, U_, W_, Y_, Z_,F_); \ @@ -396,6 +430,13 @@ ffi::Error solve_upper_revImpl( ffi::ResultBuffer bY ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + if (dim0(t) != N) return shape_error("solve_upper_rev shape mismatch"); @@ -419,15 +460,15 @@ ffi::Error solve_upper_revImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ - Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ Eigen::Map bt_(bt->typed_data(), N, 1); \ Eigen::Map bc_(bc->typed_data(), J, 1); \ Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ Eigen::Map::value>> bW_(bW->typed_data(), N, J); \ - Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ bt_.setZero(); \ bc_.setZero(); \ bU_.setZero(); \ @@ -472,6 +513,14 @@ ffi::Error Matmul_lowerImpl( ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + + if (dim0(t) != N) return shape_error("matmul_lower shape mismatch"); if (dim0(c) != J) return shape_error("matmul_lower shape mismatch"); @@ -490,8 +539,8 @@ ffi::Error Matmul_lowerImpl( Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ celerite2::core::matmul_lower( t_, c_, U_, V_, Y_, Z_,F_); \ @@ -531,6 +580,13 @@ ffi::Error matmul_lower_revImpl( ffi::ResultBuffer bY ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + if (dim0(t) != N) return shape_error("matmul_lower_rev shape mismatch"); @@ -554,15 +610,15 @@ ffi::Error matmul_lower_revImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ - Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ Eigen::Map bt_(bt->typed_data(), N, 1); \ Eigen::Map bc_(bc->typed_data(), J, 1); \ Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ - Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ bt_.setZero(); \ bc_.setZero(); \ bU_.setZero(); \ @@ -607,6 +663,14 @@ ffi::Error Matmul_upperImpl( ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + + if (dim0(t) != N) return shape_error("matmul_upper shape mismatch"); if (dim0(c) != J) return shape_error("matmul_upper shape mismatch"); @@ -625,8 +689,8 @@ ffi::Error Matmul_upperImpl( Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F->typed_data(), N, flat_cols(F)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ celerite2::core::matmul_upper( t_, c_, U_, V_, Y_, Z_,F_); \ @@ -666,6 +730,13 @@ ffi::Error matmul_upper_revImpl( ffi::ResultBuffer bY ) { + const auto N = dim0(t); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + if (dim0(t) != N) return shape_error("matmul_upper_rev shape mismatch"); @@ -689,15 +760,15 @@ ffi::Error matmul_upper_revImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z.typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F.typed_data(), N, dim1(F)); \ - Eigen::Map> bZ_(bZ.typed_data(), N, dim1(bZ)); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ Eigen::Map bt_(bt->typed_data(), N, 1); \ Eigen::Map bc_(bc->typed_data(), J, 1); \ Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ - Eigen::Map> bY_(bY->typed_data(), N, dim1(bY)); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ bt_.setZero(); \ bc_.setZero(); \ bU_.setZero(); \ @@ -743,6 +814,16 @@ ffi::Error General_matmul_lowerImpl( ) { + const auto N = dim0(t1); + + const auto M = dim0(t2); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + + if (dim0(t1) != N) return shape_error("general_matmul_lower shape mismatch"); if (dim0(t2) != M) return shape_error("general_matmul_lower shape mismatch"); @@ -764,8 +845,8 @@ ffi::Error General_matmul_lowerImpl( Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), M, J); \ Eigen::Map> Y_(Y.typed_data(), M, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F->typed_data(), M, flat_cols(F)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ celerite2::core::general_matmul_lower( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ @@ -803,6 +884,16 @@ ffi::Error General_matmul_upperImpl( ) { + const auto N = dim0(t1); + + const auto M = dim0(t2); + + const auto J = dim0(c); + + const auto nrhs = dim1(Y); + + + if (dim0(t1) != N) return shape_error("general_matmul_upper shape mismatch"); if (dim0(t2) != M) return shape_error("general_matmul_upper shape mismatch"); @@ -824,8 +915,8 @@ ffi::Error General_matmul_upperImpl( Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), M, J); \ Eigen::Map> Y_(Y.typed_data(), M, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, dim1(Z)); \ - Eigen::Map> F_(F->typed_data(), M, flat_cols(F)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ celerite2::core::general_matmul_upper( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ diff --git a/python/celerite2/pymc/ops.py b/python/celerite2/pymc/ops.py index de1300a..094a191 100644 --- a/python/celerite2/pymc/ops.py +++ b/python/celerite2/pymc/ops.py @@ -19,6 +19,7 @@ import pytensor import pytensor.tensor as pt from pytensor.graph import basic, op +from pytensor.link.jax.dispatch import jax_funcify import celerite2.backprop as backprop import celerite2.driver as driver @@ -156,3 +157,68 @@ def grad(self, inputs, gradients): general_matmul_upper = _CeleriteOp( "general_matmul_upper_fwd", definitions["general_matmul_upper"] ) + + +# JAX conversion for PyTensor JAX linker ------------------------------------- +@jax_funcify.register(_CeleriteOp) +def _jax_funcify_celerite(op, node, **kwargs): + """Map celerite2 PyTensor ops to their JAX counterparts.""" + + # Lazy import to avoid circular import during module import + import celerite2.jax.ops as jax_ops + + def factor_fwd(t, c, a, U, V): + return jax_ops.factor_p.bind(t, c, a, U, V) + + def factor_rev(t, c, a, U, V, d, W, S, bd, bW): + return jax_ops.factor_rev_p.bind(t, c, a, U, V, d, W, S, bd, bW) + + def solve_lower_fwd(t, c, U, W, Y): + return jax_ops.solve_lower_p.bind(t, c, U, W, Y) + + def solve_lower_rev(t, c, U, W, Y, Z, F, bZ): + return jax_ops.solve_lower_rev_p.bind(t, c, U, W, Y, Z, F, bZ) + + def solve_upper_fwd(t, c, U, W, Y): + return jax_ops.solve_upper_p.bind(t, c, U, W, Y) + + def solve_upper_rev(t, c, U, W, Y, Z, F, bZ): + return jax_ops.solve_upper_rev_p.bind(t, c, U, W, Y, Z, F, bZ) + + def matmul_lower_fwd(t, c, U, V, Y): + return jax_ops.matmul_lower_p.bind(t, c, U, V, Y) + + def matmul_lower_rev(t, c, U, V, Y, Z, F, bZ): + return jax_ops.matmul_lower_rev_p.bind(t, c, U, V, Y, Z, F, bZ) + + def matmul_upper_fwd(t, c, U, V, Y): + return jax_ops.matmul_upper_p.bind(t, c, U, V, Y) + + def matmul_upper_rev(t, c, U, V, Y, Z, F, bZ): + return jax_ops.matmul_upper_rev_p.bind(t, c, U, V, Y, Z, F, bZ) + + def general_matmul_lower_fwd(t1, t2, c, U, V, Y): + return jax_ops.general_matmul_lower_p.bind(t1, t2, c, U, V, Y) + + def general_matmul_upper_fwd(t1, t2, c, U, V, Y): + return jax_ops.general_matmul_upper_p.bind(t1, t2, c, U, V, Y) + + mapping = { + "factor_fwd": factor_fwd, + "factor_rev": factor_rev, + "solve_lower_fwd": solve_lower_fwd, + "solve_lower_rev": solve_lower_rev, + "solve_upper_fwd": solve_upper_fwd, + "solve_upper_rev": solve_upper_rev, + "matmul_lower_fwd": matmul_lower_fwd, + "matmul_lower_rev": matmul_lower_rev, + "matmul_upper_fwd": matmul_upper_fwd, + "matmul_upper_rev": matmul_upper_rev, + "general_matmul_lower_fwd": general_matmul_lower_fwd, + "general_matmul_upper_fwd": general_matmul_upper_fwd, + } + + try: + return mapping[op.name] + except KeyError: + raise NotImplementedError(f"No JAX conversion registered for {op.name}") diff --git a/python/spec/generate.py b/python/spec/generate.py index 14323ad..1e5ad2a 100644 --- a/python/spec/generate.py +++ b/python/spec/generate.py @@ -5,7 +5,6 @@ import os from pathlib import Path -import pkg_resources from jinja2 import Environment, FileSystemLoader, select_autoescape base = Path(os.path.dirname(os.path.abspath(__file__))) @@ -15,9 +14,7 @@ autoescape=select_autoescape(["cpp"]), ) -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with open(base.parent / "celerite2" / "definitions.json", "r") as f: data = json.load(f) for n in range(len(data)): diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index 3d153e1..2173144 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -50,6 +50,10 @@ ffi::Error {{mod.name|capitalize}}Impl( ffi::ResultBuffer {{arg.name}}{% if not loop.last %},{% endif %} {%- endfor %} ) { + {# Dimension aliases for readability #} + {% for dim in mod.dimensions %} + const auto {{dim.name}} = dim{{dim.coords[1]}}({{mod.inputs[dim.coords[0]].name}}); + {% endfor %} {# Minimal shape checks - rely on driver.hpp order helper #} {% for arg in mod.inputs %} {%- if arg.shape|length == 1 %} @@ -76,14 +80,14 @@ ffi::Error {{mod.name|capitalize}}Impl( {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ {%- else %} - Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- endif %} {%- endfor %} {%- for arg in mod.extra_outputs %} {%- if arg.shape|length == 3 and arg.shape[1] == "J" and arg.shape[2] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J * J); \ {%- else %} - Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, flat_cols({{arg.name}})); \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{ '*'.join(arg.shape[1:]) }}); \ {%- endif %} {%- endfor %} {%- for arg in mod.outputs + mod.extra_outputs %} @@ -117,6 +121,9 @@ ffi::Error {{mod.name}}_revImpl( ffi::ResultBuffer {{arg.name}}{% if not loop.last %},{% endif %} {%- endfor %} ) { + {% for dim in mod.dimensions %} + const auto {{dim.name}} = dim{{dim.coords[1]}}({{mod.inputs[dim.coords[0]].name}}); + {% endfor %} {# Minimal shape checks #} {% for arg in mod.rev_inputs %} {%- if arg.shape|length == 1 %} @@ -133,8 +140,12 @@ ffi::Error {{mod.name}}_revImpl( Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J); \ + {%- elif arg.shape|length == 3 and arg.shape[1] == "J" and arg.shape[2] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J * J); \ + {%- elif arg.shape|length == 3 %} + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, {{ '*'.join(arg.shape[1:]) }}); \ {%- else %} - Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- endif %} {%- endfor %} {%- for arg in mod.rev_outputs %} @@ -143,7 +154,7 @@ ffi::Error {{mod.name}}_revImpl( {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ {%- else %} - Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- endif %} {%- endfor %} {%- for arg in mod.rev_outputs %} From f13d3a40722c12b780e58da8a254635b21e04c95 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 01:44:26 +0000 Subject: [PATCH 03/28] deps: lower to python 3.10 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 666adfa..a6f19d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "celerite2" description = "Fast and scalable Gaussian Processes in 1D" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.10" license = { text = "MIT License" } classifiers = [ "Development Status :: 4 - Beta", From 5790e0989691c68a3851f774f6900ac17aee2132 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 01:51:00 +0000 Subject: [PATCH 04/28] fix: JIT compat of `_do_compute` --- python/celerite2/jax/celerite2.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/celerite2/jax/celerite2.py b/python/celerite2/jax/celerite2.py index b652825..9814232 100644 --- a/python/celerite2/jax/celerite2.py +++ b/python/celerite2/jax/celerite2.py @@ -3,6 +3,7 @@ __all__ = ["GaussianProcess", "ConditionalDistribution"] from jax import numpy as np +from jax import lax from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess from celerite2.jax import ops @@ -35,11 +36,17 @@ def _do_compute(self, quiet): self._t, self._c, self._a, self._U, self._V ) self._log_det = np.sum(np.log(self._d)) - if np.any(self._d <= 0) or not np.isfinite(self._log_det): - self._log_det = -np.inf - self._norm = np.inf - return - self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi)) + + def _bad(_): + return -np.inf, np.inf + + def _good(_): + return self._log_det, -0.5 * ( + self._log_det + self._size * np.log(2 * np.pi) + ) + + bad = np.any(self._d <= 0) | (~np.isfinite(self._log_det)) + self._log_det, self._norm = lax.cond(bad, _bad, _good, operand=None) def _check_sorted(self, t): return t From b14435b62f14bf08e224f4d496478d4dd58ba91d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Nov 2025 01:56:41 +0000 Subject: [PATCH 05/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/celerite2/jax/celerite2.py | 2 +- python/celerite2/jax/ops.py | 9 ++++----- python/celerite2/pymc/ops.py | 6 ++++-- python/celerite2/pymc3/ops.py | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/celerite2/jax/celerite2.py b/python/celerite2/jax/celerite2.py index 9814232..a3e8122 100644 --- a/python/celerite2/jax/celerite2.py +++ b/python/celerite2/jax/celerite2.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- __all__ = ["GaussianProcess", "ConditionalDistribution"] +from jax import lax from jax import numpy as np -from jax import lax from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess from celerite2.jax import ops diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index f1b04f0..faba5ae 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -17,20 +17,19 @@ "general_matmul_lower", "general_matmul_upper", ] +import importlib +import importlib.resources as resources import json -from collections import OrderedDict from functools import partial from itertools import chain -import importlib import numpy as np -import importlib.resources as resources -from jax import core, lax, ffi +from jax import core, ffi, lax from jax import numpy as jnp +from jax._src.core import Primitive # Public Primitive was removed in JAX 0.8 from jax.core import ShapedArray from jax.interpreters import ad, mlir, xla from jax.lib import xla_client -from jax._src.core import Primitive # Public Primitive was removed in JAX 0.8 xla_ops = importlib.import_module("celerite2.jax.xla_ops") diff --git a/python/celerite2/pymc/ops.py b/python/celerite2/pymc/ops.py index 094a191..cce6928 100644 --- a/python/celerite2/pymc/ops.py +++ b/python/celerite2/pymc/ops.py @@ -11,11 +11,11 @@ "general_matmul_upper", ] +import importlib.resources as resources import json from itertools import chain import numpy as np -import importlib.resources as resources import pytensor import pytensor.tensor as pt from pytensor.graph import basic, op @@ -221,4 +221,6 @@ def general_matmul_upper_fwd(t1, t2, c, U, V, Y): try: return mapping[op.name] except KeyError: - raise NotImplementedError(f"No JAX conversion registered for {op.name}") + raise NotImplementedError( + f"No JAX conversion registered for {op.name}" + ) diff --git a/python/celerite2/pymc3/ops.py b/python/celerite2/pymc3/ops.py index d1108e4..6d42ece 100644 --- a/python/celerite2/pymc3/ops.py +++ b/python/celerite2/pymc3/ops.py @@ -11,11 +11,11 @@ "general_matmul_upper", ] +import importlib.resources as resources import json from itertools import chain import numpy as np -import importlib.resources as resources import theano import theano.tensor as tt from theano.graph import basic, op From 8b36a42f291c0582a12207712694885ee41ddd96 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 02:12:49 +0000 Subject: [PATCH 06/28] deps: require python 3.11 for jaxlib compat --- .github/workflows/python.yml | 6 +++--- .github/workflows/tutorials.yml.off | 2 +- .github/workflows/wheels.yml | 2 +- pyproject.toml | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3bfda5c..4dfb18e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -19,7 +19,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.11"] os: ["ubuntu-latest"] session: - "core" @@ -68,7 +68,7 @@ jobs: environment-name: test-env create-args: >- mamba - python=3.10 + python=3.11 - name: Install nox run: python -m pip install -U nox @@ -76,4 +76,4 @@ jobs: - name: Run tests run: | python -m nox --non-interactive --error-on-missing-interpreter \ - --session pymc_mamba-3.10 + --session pymc_mamba-3.11 diff --git a/.github/workflows/tutorials.yml.off b/.github/workflows/tutorials.yml.off index 74a3603..cd0278c 100644 --- a/.github/workflows/tutorials.yml.off +++ b/.github/workflows/tutorials.yml.off @@ -26,7 +26,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.11 - name: Install dependencies run: | diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index fa3b02e..49ed81a 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -44,7 +44,7 @@ jobs: - uses: actions/setup-python@v5 name: Install Python with: - python-version: "3.10" + python-version: "3.11" - name: Install dependencies run: | python -m pip install -U pip diff --git a/pyproject.toml b/pyproject.toml index a6f19d3..6e33774 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "celerite2" description = "Fast and scalable Gaussian Processes in 1D" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license = { text = "MIT License" } classifiers = [ "Development Status :: 4 - Beta", @@ -41,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] "Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" [build-system] -requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0,<0.9.0"] +requires = ["scikit-build-core", "numpy", "pybind11"] build-backend = "scikit_build_core.build" [tool.scikit-build] From 461fb35bb9dedec5df776b5b541b4451b45769ba Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 02:15:03 +0000 Subject: [PATCH 07/28] ci: skip bad nox arg --- .github/workflows/python.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 4dfb18e..135d19e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -47,7 +47,7 @@ jobs: - name: Run tests run: | - python -m nox --non-interactive --error-on-missing-interpreter \ + python -m nox --non-interactive \ --session ${{ matrix.session }}-${{ matrix.python-version }} tests-pymc: @@ -75,5 +75,5 @@ jobs: - name: Run tests run: | - python -m nox --non-interactive --error-on-missing-interpreter \ + python -m nox --non-interactive \ --session pymc_mamba-3.11 From 9957276e366e96a5bce29c086db14e2c8b04c6c7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 02:17:10 +0000 Subject: [PATCH 08/28] ci: fix nox target version --- .github/workflows/python.yml | 3 +-- noxfile.py | 21 +-------------------- 2 files changed, 2 insertions(+), 22 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 135d19e..0c0ebdb 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -24,8 +24,7 @@ jobs: session: - "core" - "jax" - - "pymc3" - # - "pymc" + - "pymc" - "pymc_jax" steps: diff --git a/noxfile.py b/noxfile.py index 7d9bfe8..725b4a5 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,6 +1,6 @@ import nox -ALL_PYTHON_VS = ["3.8", "3.9", "3.10"] +ALL_PYTHON_VS = ["3.11"] TEST_CMD = ["python", "-m", "pytest", "-v"] @@ -23,12 +23,6 @@ def jax(session): _session_run(session, "python/test/jax") -@nox.session(python=ALL_PYTHON_VS) -def pymc3(session): - session.install(".[test,pymc3]") - _session_run(session, "python/test/pymc3") - - @nox.session(python=ALL_PYTHON_VS) def pymc(session): session.install(".[test,pymc]") @@ -48,19 +42,6 @@ def pymc_jax(session): _session_run(session, "python/test/pymc/test_pymc_ops.py") -@nox.session(python=ALL_PYTHON_VS) -def full(session): - session.install(".[test,jax,pymc3,pymc]") - _session_run(session, "python/test") - - -@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba") -def full_mamba(session): - session.conda_install("jax", "pymc3", "pymc", channel="conda-forge") - session.install(".[test]") - _session_run(session, "python/test") - - @nox.session def lint(session): session.install("pre-commit") From 7d5d97d5490d70ea52d5f8cce0bb9476aee137b8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 02:22:31 +0000 Subject: [PATCH 09/28] deps: fix missing jaxlib install --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6e33774..666adfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] "Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" [build-system] -requires = ["scikit-build-core", "numpy", "pybind11"] +requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0,<0.9.0"] build-backend = "scikit_build_core.build" [tool.scikit-build] From 272344d48499fed9eeda597cd7f5aa718386e806 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 23 Nov 2025 02:48:53 +0000 Subject: [PATCH 10/28] ci: fix missing jax install for pymc tests --- noxfile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 725b4a5..f075c02 100644 --- a/noxfile.py +++ b/noxfile.py @@ -25,14 +25,14 @@ def jax(session): @nox.session(python=ALL_PYTHON_VS) def pymc(session): - session.install(".[test,pymc]") + session.install(".[test,pymc,jax]") _session_run(session, "python/test/pymc") @nox.session(python=ALL_PYTHON_VS, venv_backend="mamba") def pymc_mamba(session): session.conda_install("pymc", channel="conda-forge") - session.install(".[test,pymc]") + session.install(".[test,pymc,jax]") _session_run(session, "python/test/pymc") From 341c2f5e7542d18b016826b65893f0ece0009a5c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 17:19:39 +0000 Subject: [PATCH 11/28] fix: generator will no longer add lines of spaces Co-authored-by: Dan Foreman-Mackey --- python/spec/generate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/spec/generate.py b/python/spec/generate.py index 1e5ad2a..164f994 100644 --- a/python/spec/generate.py +++ b/python/spec/generate.py @@ -12,6 +12,7 @@ env = Environment( loader=FileSystemLoader(base / "templates"), autoescape=select_autoescape(["cpp"]), + keep_trailing_newline=True, ) with open(base.parent / "celerite2" / "definitions.json", "r") as f: @@ -42,7 +43,8 @@ for name in ["driver.cpp", "backprop.cpp", "jax/xla_ops.cpp"]: template = env.get_template(name) result = template.render(spec=data) + clean_result = "\n".join(line.rstrip() for line in result.splitlines()) + "\n" with open(base.parent / "celerite2" / name, "w") as f: f.write("// NOTE: This file was autogenerated\n") f.write("// NOTE: Changes should be made to the template\n\n") - f.write(result) + f.write(clean_result) From ffaff432548fab2ff63a2c7537f7a5e521eb0beb Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 17:21:47 +0000 Subject: [PATCH 12/28] fix: re-generated xla_ops without api.h Co-authored-by: Dan Foreman-Mackey --- python/celerite2/jax/xla_ops.cpp | 1 - python/spec/templates/jax/xla_ops.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index 18d1e9e..a6d0929 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -9,7 +9,6 @@ #include #include "xla/ffi/api/ffi.h" -#include "xla/ffi/api/api.h" #include "../driver.hpp" diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index 2173144..5b28e17 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -6,7 +6,6 @@ #include #include "xla/ffi/api/ffi.h" -#include "xla/ffi/api/api.h" #include "../driver.hpp" From f6fc718ca5bc9072fd75c0fdb70758e63738d0b8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 17:24:26 +0000 Subject: [PATCH 13/28] refactor: clean up error handling Co-authored-by: Dan Foreman-Mackey --- python/celerite2/jax/xla_ops.cpp | 150 +++++++++++++++++--------- python/spec/templates/jax/xla_ops.cpp | 20 ++-- 2 files changed, 109 insertions(+), 61 deletions(-) diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index a6d0929..cf3bbb5 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -34,10 +34,6 @@ inline Eigen::Index flat_cols(const Buffer& buf) { return cols; } -inline ffi::Error shape_error(const char* msg) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, std::string(msg)); -} - // === AUTO-GENERATED KERNELS === @@ -59,15 +55,15 @@ ffi::Error FactorImpl( - if (dim0(t) != N) return shape_error("factor shape mismatch"); + if (dim0(t) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(c) != J) return shape_error("factor shape mismatch"); + if (dim0(c) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(a) != N) return shape_error("factor shape mismatch"); + if (dim0(a) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("factor shape mismatch"); + if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return shape_error("factor shape mismatch"); + if (dim0(V) != N || dim1(V) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -83,7 +79,11 @@ ffi::Error FactorImpl( d_.setZero(); \ W_.setZero(); \ S_.setZero(); \ - celerite2::core::factor( t_, c_, a_, U_, V_, d_,W_,S_); \ + try { \ + celerite2::core::factor( t_, c_, a_, U_, V_, d_,W_,S_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -171,7 +171,11 @@ ffi::Error factor_revImpl( ba_.setZero(); \ bU_.setZero(); \ bV_.setZero(); \ - celerite2::core::factor_rev( t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_,bc_,ba_,bU_,bV_); \ + try { \ + celerite2::core::factor_rev( t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_,bc_,ba_,bU_,bV_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -220,15 +224,15 @@ ffi::Error Solve_lowerImpl( - if (dim0(t) != N) return shape_error("solve_lower shape mismatch"); + if (dim0(t) != N) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(c) != J) return shape_error("solve_lower shape mismatch"); + if (dim0(c) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("solve_lower shape mismatch"); + if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(W) != N || dim1(W) != J) return shape_error("solve_lower shape mismatch"); + if (dim0(W) != N || dim1(W) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_lower shape mismatch"); + if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -242,7 +246,11 @@ ffi::Error Solve_lowerImpl( Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ - celerite2::core::solve_lower( t_, c_, U_, W_, Y_, Z_,F_); \ + try { \ + celerite2::core::solve_lower( t_, c_, U_, W_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -323,7 +331,11 @@ ffi::Error solve_lower_revImpl( bU_.setZero(); \ bW_.setZero(); \ bY_.setZero(); \ - celerite2::core::solve_lower_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + try { \ + celerite2::core::solve_lower_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -370,15 +382,15 @@ ffi::Error Solve_upperImpl( - if (dim0(t) != N) return shape_error("solve_upper shape mismatch"); + if (dim0(t) != N) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(c) != J) return shape_error("solve_upper shape mismatch"); + if (dim0(c) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("solve_upper shape mismatch"); + if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(W) != N || dim1(W) != J) return shape_error("solve_upper shape mismatch"); + if (dim0(W) != N || dim1(W) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_upper shape mismatch"); + if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -392,7 +404,11 @@ ffi::Error Solve_upperImpl( Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ - celerite2::core::solve_upper( t_, c_, U_, W_, Y_, Z_,F_); \ + try { \ + celerite2::core::solve_upper( t_, c_, U_, W_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -473,7 +489,11 @@ ffi::Error solve_upper_revImpl( bU_.setZero(); \ bW_.setZero(); \ bY_.setZero(); \ - celerite2::core::solve_upper_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + try { \ + celerite2::core::solve_upper_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -520,15 +540,15 @@ ffi::Error Matmul_lowerImpl( - if (dim0(t) != N) return shape_error("matmul_lower shape mismatch"); + if (dim0(t) != N) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(c) != J) return shape_error("matmul_lower shape mismatch"); + if (dim0(c) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_lower shape mismatch"); + if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_lower shape mismatch"); + if (dim0(V) != N || dim1(V) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_lower shape mismatch"); + if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -542,7 +562,11 @@ ffi::Error Matmul_lowerImpl( Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ - celerite2::core::matmul_lower( t_, c_, U_, V_, Y_, Z_,F_); \ + try { \ + celerite2::core::matmul_lower( t_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -623,7 +647,11 @@ ffi::Error matmul_lower_revImpl( bU_.setZero(); \ bV_.setZero(); \ bY_.setZero(); \ - celerite2::core::matmul_lower_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + try { \ + celerite2::core::matmul_lower_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -670,15 +698,15 @@ ffi::Error Matmul_upperImpl( - if (dim0(t) != N) return shape_error("matmul_upper shape mismatch"); + if (dim0(t) != N) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(c) != J) return shape_error("matmul_upper shape mismatch"); + if (dim0(c) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_upper shape mismatch"); + if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_upper shape mismatch"); + if (dim0(V) != N || dim1(V) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_upper shape mismatch"); + if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -692,7 +720,11 @@ ffi::Error Matmul_upperImpl( Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ - celerite2::core::matmul_upper( t_, c_, U_, V_, Y_, Z_,F_); \ + try { \ + celerite2::core::matmul_upper( t_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -773,7 +805,11 @@ ffi::Error matmul_upper_revImpl( bU_.setZero(); \ bV_.setZero(); \ bY_.setZero(); \ - celerite2::core::matmul_upper_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + try { \ + celerite2::core::matmul_upper_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP @@ -823,17 +859,17 @@ ffi::Error General_matmul_lowerImpl( - if (dim0(t1) != N) return shape_error("general_matmul_lower shape mismatch"); + if (dim0(t1) != N) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(t2) != M) return shape_error("general_matmul_lower shape mismatch"); + if (dim0(t2) != M) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(c) != J) return shape_error("general_matmul_lower shape mismatch"); + if (dim0(c) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("general_matmul_lower shape mismatch"); + if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(V) != M || dim1(V) != J) return shape_error("general_matmul_lower shape mismatch"); + if (dim0(V) != M || dim1(V) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(Y) != M || dim1(Y) != nrhs) return shape_error("general_matmul_lower shape mismatch"); + if (dim0(Y) != M || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -848,7 +884,11 @@ ffi::Error General_matmul_lowerImpl( Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ - celerite2::core::general_matmul_lower( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + try { \ + celerite2::core::general_matmul_lower( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP @@ -893,17 +933,17 @@ ffi::Error General_matmul_upperImpl( - if (dim0(t1) != N) return shape_error("general_matmul_upper shape mismatch"); + if (dim0(t1) != N) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(t2) != M) return shape_error("general_matmul_upper shape mismatch"); + if (dim0(t2) != M) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(c) != J) return shape_error("general_matmul_upper shape mismatch"); + if (dim0(c) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("general_matmul_upper shape mismatch"); + if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(V) != M || dim1(V) != J) return shape_error("general_matmul_upper shape mismatch"); + if (dim0(V) != M || dim1(V) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(Y) != M || dim1(Y) != nrhs) return shape_error("general_matmul_upper shape mismatch"); + if (dim0(Y) != M || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -918,7 +958,11 @@ ffi::Error General_matmul_upperImpl( Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ - celerite2::core::general_matmul_upper( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + try { \ + celerite2::core::general_matmul_upper( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index 5b28e17..b0a369b 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -31,10 +31,6 @@ inline Eigen::Index flat_cols(const Buffer& buf) { return cols; } -inline ffi::Error shape_error(const char* msg) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, std::string(msg)); -} - // === AUTO-GENERATED KERNELS === {% for mod in spec %} @@ -56,9 +52,9 @@ ffi::Error {{mod.name|capitalize}}Impl( {# Minimal shape checks - rely on driver.hpp order helper #} {% for arg in mod.inputs %} {%- if arg.shape|length == 1 %} - if (dim0({{arg.name}}) != {{arg.shape[0]}}) return shape_error("{{mod.name}} shape mismatch"); + if (dim0({{arg.name}}) != {{arg.shape[0]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); {%- elif arg.shape|length == 2 %} - if (dim0({{arg.name}}) != {{arg.shape[0]}} || dim1({{arg.name}}) != {{arg.shape[1]}}) return shape_error("{{mod.name}} shape mismatch"); + if (dim0({{arg.name}}) != {{arg.shape[0]}} || dim1({{arg.name}}) != {{arg.shape[1]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); {%- endif %} {% endfor %} @@ -92,7 +88,11 @@ ffi::Error {{mod.name|capitalize}}Impl( {%- for arg in mod.outputs + mod.extra_outputs %} {{arg.name}}_.setZero(); \ {%- endfor %} - celerite2::core::{{mod.name}}({%- for arg in mod.inputs %} {{arg.name}}_{% if not loop.last or mod.outputs or mod.extra_outputs %},{% endif %}{%- endfor %}{%- if mod.outputs or mod.extra_outputs %} {% endif %}{%- for arg in mod.outputs + mod.extra_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + try { \ + celerite2::core::{{mod.name}}({%- for arg in mod.inputs %} {{arg.name}}_{% if not loop.last or mod.outputs or mod.extra_outputs %},{% endif %}{%- endfor %}{%- if mod.outputs or mod.extra_outputs %} {% endif %}{%- for arg in mod.outputs + mod.extra_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_{{ "FEW" if mod.has_rev else "MOST" }} #undef FIXED_SIZE_MAP @@ -159,7 +159,11 @@ ffi::Error {{mod.name}}_revImpl( {%- for arg in mod.rev_outputs %} {{arg.name}}_.setZero(); \ {%- endfor %} - celerite2::core::{{mod.name}}_rev({%- for arg in mod.rev_inputs %} {{arg.name}}_{% if not loop.last or mod.rev_outputs %},{% endif %}{%- endfor %}{%- if mod.rev_outputs %} {% endif %}{%- for arg in mod.rev_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + try { \ + celerite2::core::{{mod.name}}_rev({%- for arg in mod.rev_inputs %} {{arg.name}}_{% if not loop.last or mod.rev_outputs %},{% endif %}{%- endfor %}{%- if mod.rev_outputs %} {% endif %}{%- for arg in mod.rev_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ } UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP From 2644c1f17962b39e52d2f862eab1a4c1e7635793 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 17:24:46 +0000 Subject: [PATCH 14/28] fix: `Primitive` import Co-authored-by: Dan Foreman-Mackey --- python/celerite2/jax/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index faba5ae..9c82a2b 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -26,7 +26,7 @@ import numpy as np from jax import core, ffi, lax from jax import numpy as jnp -from jax._src.core import Primitive # Public Primitive was removed in JAX 0.8 +from jax.extend.core import Primitive from jax.core import ShapedArray from jax.interpreters import ad, mlir, xla from jax.lib import xla_client From 4fa899df81e6d714ff5646a9ab724b268ca238ec Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 17:27:54 +0000 Subject: [PATCH 15/28] build: update cmake per review Co-authored-by: Dan Foreman-Mackey --- CMakeLists.txt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 90bfb9f..dedab48 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ project( set(PYBIND11_NEWPYTHON ON) find_package(pybind11 CONFIG REQUIRED) -find_package(Python3 COMPONENTS Interpreter REQUIRED) +find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) include_directories( "c++/include" @@ -24,17 +24,17 @@ install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) option(BUILD_JAX "Build JAX extension (requires jaxlib headers)" ON) if(BUILD_JAX) execute_process( - COMMAND ${Python3_EXECUTABLE} -c "import sys, pathlib; import jaxlib; path = pathlib.Path(jaxlib.__file__).parent / 'include'; sys.stdout.write(str(path) if path.exists() else '')" - OUTPUT_VARIABLE JAX_INCLUDE - RESULT_VARIABLE JAXLIB_RES - OUTPUT_STRIP_TRAILING_WHITESPACE) - if(JAXLIB_RES EQUAL 0 AND NOT "${JAX_INCLUDE}" STREQUAL "") - message(STATUS "Building JAX extension with jaxlib include: ${JAX_INCLUDE}") + COMMAND "${Python_EXECUTABLE}" "-c" "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE XLA_DIR + RESULT_VARIABLE JAXLIB_RES) + if(JAXLIB_RES EQUAL 0 AND NOT "${XLA_DIR}" STREQUAL "") + message(STATUS "Building JAX extension with XLA include: ${XLA_DIR}") pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp") target_compile_features(xla_ops PUBLIC cxx_std_17) - target_include_directories(xla_ops PRIVATE "${JAX_INCLUDE}") + target_include_directories(xla_ops PUBLIC "${XLA_DIR}") install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax") else() - message(STATUS "Skipping JAX extension (jaxlib not found)") + message(STATUS "Skipping JAX extension (jax.ffi include_dir not found)") endif() endif() From f527c78a9a4494406f4132f373be6d47fa901912 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 29 Nov 2025 17:31:26 +0000 Subject: [PATCH 16/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/celerite2/jax/ops.py | 2 +- python/spec/generate.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index 9c82a2b..cc8e7b5 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -26,8 +26,8 @@ import numpy as np from jax import core, ffi, lax from jax import numpy as jnp -from jax.extend.core import Primitive from jax.core import ShapedArray +from jax.extend.core import Primitive from jax.interpreters import ad, mlir, xla from jax.lib import xla_client diff --git a/python/spec/generate.py b/python/spec/generate.py index 164f994..341e588 100644 --- a/python/spec/generate.py +++ b/python/spec/generate.py @@ -43,7 +43,9 @@ for name in ["driver.cpp", "backprop.cpp", "jax/xla_ops.cpp"]: template = env.get_template(name) result = template.render(spec=data) - clean_result = "\n".join(line.rstrip() for line in result.splitlines()) + "\n" + clean_result = ( + "\n".join(line.rstrip() for line in result.splitlines()) + "\n" + ) with open(base.parent / "celerite2" / name, "w") as f: f.write("// NOTE: This file was autogenerated\n") f.write("// NOTE: Changes should be made to the template\n\n") From 7679b851d02b1617d01bfc42250b3ec9bbcf5ac8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 17:36:18 +0000 Subject: [PATCH 17/28] refactor: put axis into template Co-authored-by: Dan Foreman-Mackey --- python/celerite2/jax/xla_ops.cpp | 84 +++++++++++++-------------- python/spec/templates/jax/xla_ops.cpp | 14 ++--- 2 files changed, 45 insertions(+), 53 deletions(-) diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index cf3bbb5..d481f0c 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -17,13 +17,9 @@ namespace ffi = xla::ffi; using namespace celerite2::driver; // Helpers -template -inline Eigen::Index dim0(const Buffer& buf) { - return static_cast(buf.dimensions()[0]); -} -template -inline Eigen::Index dim1(const Buffer& buf) { - return static_cast(buf.dimensions()[1]); +template +inline Eigen::Index dim(const Buffer& buf) { + return static_cast(buf.dimensions()[Axis]); } template inline Eigen::Index flat_cols(const Buffer& buf) { @@ -55,15 +51,15 @@ ffi::Error FactorImpl( - if (dim0(t) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(c) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(a) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); + if (dim<0>(a) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -224,15 +220,15 @@ ffi::Error Solve_lowerImpl( - if (dim0(t) != N) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(c) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(W) != N || dim1(W) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -382,15 +378,15 @@ ffi::Error Solve_upperImpl( - if (dim0(t) != N) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(c) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(W) != N || dim1(W) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -540,15 +536,15 @@ ffi::Error Matmul_lowerImpl( - if (dim0(t) != N) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(c) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -698,15 +694,15 @@ ffi::Error Matmul_upperImpl( - if (dim0(t) != N) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(c) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -859,17 +855,17 @@ ffi::Error General_matmul_lowerImpl( - if (dim0(t1) != N) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + if (dim<0>(t1) != N) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(t2) != M) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + if (dim<0>(t2) != M) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(c) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(V) != M || dim1(V) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + if (dim<0>(V) != M || dim<1>(V) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); - if (dim0(Y) != M || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + if (dim<0>(Y) != M || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -933,17 +929,17 @@ ffi::Error General_matmul_upperImpl( - if (dim0(t1) != N) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + if (dim<0>(t1) != N) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(t2) != M) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + if (dim<0>(t2) != M) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(c) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(V) != M || dim1(V) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + if (dim<0>(V) != M || dim<1>(V) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); - if (dim0(Y) != M || dim1(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + if (dim<0>(Y) != M || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index b0a369b..aa52ccd 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -14,13 +14,9 @@ namespace ffi = xla::ffi; using namespace celerite2::driver; // Helpers -template -inline Eigen::Index dim0(const Buffer& buf) { - return static_cast(buf.dimensions()[0]); -} -template -inline Eigen::Index dim1(const Buffer& buf) { - return static_cast(buf.dimensions()[1]); +template +inline Eigen::Index dim(const Buffer& buf) { + return static_cast(buf.dimensions()[Axis]); } template inline Eigen::Index flat_cols(const Buffer& buf) { @@ -52,9 +48,9 @@ ffi::Error {{mod.name|capitalize}}Impl( {# Minimal shape checks - rely on driver.hpp order helper #} {% for arg in mod.inputs %} {%- if arg.shape|length == 1 %} - if (dim0({{arg.name}}) != {{arg.shape[0]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); + if (dim<0>({{arg.name}}) != {{arg.shape[0]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); {%- elif arg.shape|length == 2 %} - if (dim0({{arg.name}}) != {{arg.shape[0]}} || dim1({{arg.name}}) != {{arg.shape[1]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); + if (dim<0>({{arg.name}}) != {{arg.shape[0]}} || dim<1>({{arg.name}}) != {{arg.shape[1]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); {%- endif %} {% endfor %} From faa9b951b98f5099e2c9af796559814463e54462 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 18:02:00 +0000 Subject: [PATCH 18/28] fix: add back nrhs==1 branch --- python/celerite2/jax/xla_ops.cpp | 72 ++++++++++++++++++++++----- python/spec/templates/jax/xla_ops.cpp | 21 ++++++++ 2 files changed, 81 insertions(+), 12 deletions(-) diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index d481f0c..6b80caf 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -237,8 +237,16 @@ ffi::Error Solve_lowerImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y.typed_data(), N, 1); \ + } else { \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + } \ + if (nrhs == 1) { \ + Eigen::Map Z_(Z->typed_data(), N, 1); \ + } else { \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + } \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -395,8 +403,16 @@ ffi::Error Solve_upperImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y.typed_data(), N, 1); \ + } else { \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + } \ + if (nrhs == 1) { \ + Eigen::Map Z_(Z->typed_data(), N, 1); \ + } else { \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + } \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -553,8 +569,16 @@ ffi::Error Matmul_lowerImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y.typed_data(), N, 1); \ + } else { \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + } \ + if (nrhs == 1) { \ + Eigen::Map Z_(Z->typed_data(), N, 1); \ + } else { \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + } \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -711,8 +735,16 @@ ffi::Error Matmul_upperImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ - Eigen::Map> Y_(Y.typed_data(), N, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y.typed_data(), N, 1); \ + } else { \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + } \ + if (nrhs == 1) { \ + Eigen::Map Z_(Z->typed_data(), N, 1); \ + } else { \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + } \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -875,8 +907,16 @@ ffi::Error General_matmul_lowerImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), M, J); \ - Eigen::Map> Y_(Y.typed_data(), M, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y.typed_data(), M, 1); \ + } else { \ + Eigen::Map> Y_(Y.typed_data(), M, nrhs); \ + } \ + if (nrhs == 1) { \ + Eigen::Map Z_(Z->typed_data(), N, 1); \ + } else { \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + } \ Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -949,8 +989,16 @@ ffi::Error General_matmul_upperImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), M, J); \ - Eigen::Map> Y_(Y.typed_data(), M, dim1(Y)); \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + if (nrhs == 1) { \ + Eigen::Map Y_(Y.typed_data(), M, 1); \ + } else { \ + Eigen::Map> Y_(Y.typed_data(), M, nrhs); \ + } \ + if (nrhs == 1) { \ + Eigen::Map Z_(Z->typed_data(), N, 1); \ + } else { \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + } \ Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index aa52ccd..05febbf 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -45,6 +45,15 @@ ffi::Error {{mod.name|capitalize}}Impl( {% for dim in mod.dimensions %} const auto {{dim.name}} = dim{{dim.coords[1]}}({{mod.inputs[dim.coords[0]].name}}); {% endfor %} + {%- set nrhs_arg = None -%} + {%- for a in mod.inputs + mod.outputs + mod.extra_outputs -%} + {%- if a.shape|length >= 2 and a.shape[-1] == "nrhs" and nrhs_arg is none -%} + {%- set nrhs_arg = a -%} + {%- endif -%} + {%- endfor -%} + {%- if nrhs_arg is not none %} + const auto nrhs = dim<{{ nrhs_arg.shape|length - 1 }}>({{ nrhs_arg.name }}); + {%- endif %} {# Minimal shape checks - rely on driver.hpp order helper #} {% for arg in mod.inputs %} {%- if arg.shape|length == 1 %} @@ -61,6 +70,12 @@ ffi::Error {{mod.name|capitalize}}Impl( Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "nrhs" %} + if (nrhs == 1) { \ + Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ + } else { \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, nrhs); \ + } \ {%- else %} Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ {%- endif %} @@ -70,6 +85,12 @@ ffi::Error {{mod.name|capitalize}}Impl( Eigen::Map {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, 1); \ {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "nrhs" %} + if (nrhs == 1) { \ + Eigen::Map {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, 1); \ + } else { \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, nrhs); \ + } \ {%- else %} Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- endif %} From 63dd78264806640a2d56ff84a849caeacab28315 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 18:11:32 +0000 Subject: [PATCH 19/28] style: remove extra padding --- python/celerite2/jax/xla_ops.cpp | 48 +++++++++++++-------------- python/spec/templates/jax/xla_ops.cpp | 8 ++--- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index 6b80caf..cb32d32 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -62,8 +62,8 @@ ffi::Error FactorImpl( if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map a_(a.typed_data(), N, 1); \ @@ -145,8 +145,8 @@ ffi::Error factor_revImpl( if (dim0(bW) != N || dim1(bW) != J) return shape_error("factor_rev shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map a_(a.typed_data(), N, 1); \ @@ -231,8 +231,8 @@ ffi::Error Solve_lowerImpl( if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -315,8 +315,8 @@ ffi::Error solve_lower_revImpl( if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("solve_lower_rev shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -397,8 +397,8 @@ ffi::Error Solve_upperImpl( if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -481,8 +481,8 @@ ffi::Error solve_upper_revImpl( if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("solve_upper_rev shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -563,8 +563,8 @@ ffi::Error Matmul_lowerImpl( if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -647,8 +647,8 @@ ffi::Error matmul_lower_revImpl( if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("matmul_lower_rev shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -729,8 +729,8 @@ ffi::Error Matmul_upperImpl( if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -813,8 +813,8 @@ ffi::Error matmul_upper_revImpl( if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("matmul_upper_rev shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t_(t.typed_data(), N, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ @@ -900,8 +900,8 @@ ffi::Error General_matmul_lowerImpl( if (dim<0>(Y) != M || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t1_(t1.typed_data(), N, 1); \ Eigen::Map t2_(t2.typed_data(), M, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ @@ -982,8 +982,8 @@ ffi::Error General_matmul_upperImpl( if (dim<0>(Y) != M || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ Eigen::Map t1_(t1.typed_data(), N, 1); \ Eigen::Map t2_(t2.typed_data(), M, 1); \ Eigen::Map c_(c.typed_data(), J, 1); \ diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index 05febbf..b0fb06d 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -63,8 +63,8 @@ ffi::Error {{mod.name|capitalize}}Impl( {%- endif %} {% endfor %} -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ {%- for arg in mod.inputs %} {%- if arg.shape|length == 1 %} Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ @@ -149,8 +149,8 @@ ffi::Error {{mod.name}}_revImpl( {%- endif %} {% endfor %} -#define FIXED_SIZE_MAP(SIZE) \ - { \ +#define FIXED_SIZE_MAP(SIZE) \ + { \ {%- for arg in mod.rev_inputs %} {%- if arg.shape|length == 1 %} Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ From 09d3ac4ca19e9850fadf8642f3349a2a04c23ebf Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 29 Nov 2025 20:29:40 +0000 Subject: [PATCH 20/28] deps: remove upper bound --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 666adfa..337e10d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = ["numpy"] [project.optional-dependencies] test = ["pytest", "scipy", "celerite"] pymc = ["pymc>=5.26.1"] -jax = ["jax>=0.8.0,<0.9.0", "jaxlib>=0.8.0,<0.9.0"] +jax = ["jax>=0.8.0"] docs = [ "sphinx", "sphinx-material", @@ -41,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] "Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" [build-system] -requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0,<0.9.0"] +requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0"] build-backend = "scikit_build_core.build" [tool.scikit-build] From 7d71090c7ca26947720679f97f2941de01cc2627 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 30 Nov 2025 00:49:35 +0000 Subject: [PATCH 21/28] chore: re-run generator --- python/celerite2/jax/xla_ops.cpp | 218 ++++++++++---------------- python/spec/templates/jax/xla_ops.cpp | 22 +-- 2 files changed, 92 insertions(+), 148 deletions(-) diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index cb32d32..5d44178 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -45,9 +45,9 @@ ffi::Error FactorImpl( ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); @@ -119,30 +119,30 @@ ffi::Error factor_revImpl( ffi::ResultBuffer bV ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - if (dim0(t) != N) return shape_error("factor_rev shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(c) != J) return shape_error("factor_rev shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(a) != N) return shape_error("factor_rev shape mismatch"); + if (dim<0>(a) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("factor_rev shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return shape_error("factor_rev shape mismatch"); + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(d) != N) return shape_error("factor_rev shape mismatch"); + if (dim<0>(d) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(W) != N || dim1(W) != J) return shape_error("factor_rev shape mismatch"); + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(bd) != N) return shape_error("factor_rev shape mismatch"); + if (dim<0>(bd) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - if (dim0(bW) != N || dim1(bW) != J) return shape_error("factor_rev shape mismatch"); + if (dim<0>(bW) != N || dim<1>(bW) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -212,11 +212,11 @@ ffi::Error Solve_lowerImpl( ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); @@ -237,16 +237,8 @@ ffi::Error Solve_lowerImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y.typed_data(), N, 1); \ - } else { \ - Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ - } \ - if (nrhs == 1) { \ - Eigen::Map Z_(Z->typed_data(), N, 1); \ - } else { \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ - } \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -291,28 +283,28 @@ ffi::Error solve_lower_revImpl( ffi::ResultBuffer bY ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); - if (dim0(t) != N) return shape_error("solve_lower_rev shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); - if (dim0(c) != J) return shape_error("solve_lower_rev shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("solve_lower_rev shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); - if (dim0(W) != N || dim1(W) != J) return shape_error("solve_lower_rev shape mismatch"); + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_lower_rev shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); - if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("solve_lower_rev shape mismatch"); + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); - if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("solve_lower_rev shape mismatch"); + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -378,11 +370,11 @@ ffi::Error Solve_upperImpl( ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); @@ -403,16 +395,8 @@ ffi::Error Solve_upperImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> W_(W.typed_data(), N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y.typed_data(), N, 1); \ - } else { \ - Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ - } \ - if (nrhs == 1) { \ - Eigen::Map Z_(Z->typed_data(), N, 1); \ - } else { \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ - } \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -457,28 +441,28 @@ ffi::Error solve_upper_revImpl( ffi::ResultBuffer bY ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); - if (dim0(t) != N) return shape_error("solve_upper_rev shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); - if (dim0(c) != J) return shape_error("solve_upper_rev shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("solve_upper_rev shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); - if (dim0(W) != N || dim1(W) != J) return shape_error("solve_upper_rev shape mismatch"); + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("solve_upper_rev shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); - if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("solve_upper_rev shape mismatch"); + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); - if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("solve_upper_rev shape mismatch"); + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -544,11 +528,11 @@ ffi::Error Matmul_lowerImpl( ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); @@ -569,16 +553,8 @@ ffi::Error Matmul_lowerImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y.typed_data(), N, 1); \ - } else { \ - Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ - } \ - if (nrhs == 1) { \ - Eigen::Map Z_(Z->typed_data(), N, 1); \ - } else { \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ - } \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -623,28 +599,28 @@ ffi::Error matmul_lower_revImpl( ffi::ResultBuffer bY ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); - if (dim0(t) != N) return shape_error("matmul_lower_rev shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); - if (dim0(c) != J) return shape_error("matmul_lower_rev shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_lower_rev shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_lower_rev shape mismatch"); + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_lower_rev shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); - if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("matmul_lower_rev shape mismatch"); + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); - if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("matmul_lower_rev shape mismatch"); + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -710,11 +686,11 @@ ffi::Error Matmul_upperImpl( ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); @@ -735,16 +711,8 @@ ffi::Error Matmul_upperImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y.typed_data(), N, 1); \ - } else { \ - Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ - } \ - if (nrhs == 1) { \ - Eigen::Map Z_(Z->typed_data(), N, 1); \ - } else { \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ - } \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -789,28 +757,28 @@ ffi::Error matmul_upper_revImpl( ffi::ResultBuffer bY ) { - const auto N = dim0(t); + const auto N = dim<0>(t); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); - if (dim0(t) != N) return shape_error("matmul_upper_rev shape mismatch"); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); - if (dim0(c) != J) return shape_error("matmul_upper_rev shape mismatch"); + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); - if (dim0(U) != N || dim1(U) != J) return shape_error("matmul_upper_rev shape mismatch"); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); - if (dim0(V) != N || dim1(V) != J) return shape_error("matmul_upper_rev shape mismatch"); + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); - if (dim0(Y) != N || dim1(Y) != nrhs) return shape_error("matmul_upper_rev shape mismatch"); + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); - if (dim0(Z) != N || dim1(Z) != nrhs) return shape_error("matmul_upper_rev shape mismatch"); + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); - if (dim0(bZ) != N || dim1(bZ) != nrhs) return shape_error("matmul_upper_rev shape mismatch"); + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); #define FIXED_SIZE_MAP(SIZE) \ @@ -877,13 +845,13 @@ ffi::Error General_matmul_lowerImpl( ) { - const auto N = dim0(t1); + const auto N = dim<0>(t1); - const auto M = dim0(t2); + const auto M = dim<0>(t2); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); @@ -907,16 +875,8 @@ ffi::Error General_matmul_lowerImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y.typed_data(), M, 1); \ - } else { \ - Eigen::Map> Y_(Y.typed_data(), M, nrhs); \ - } \ - if (nrhs == 1) { \ - Eigen::Map Z_(Z->typed_data(), N, 1); \ - } else { \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ - } \ + Eigen::Map> Y_(Y.typed_data(), M, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ @@ -959,13 +919,13 @@ ffi::Error General_matmul_upperImpl( ) { - const auto N = dim0(t1); + const auto N = dim<0>(t1); - const auto M = dim0(t2); + const auto M = dim<0>(t2); - const auto J = dim0(c); + const auto J = dim<0>(c); - const auto nrhs = dim1(Y); + const auto nrhs = dim<1>(Y); @@ -989,16 +949,8 @@ ffi::Error General_matmul_upperImpl( Eigen::Map c_(c.typed_data(), J, 1); \ Eigen::Map::value>> U_(U.typed_data(), N, J); \ Eigen::Map::value>> V_(V.typed_data(), M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y.typed_data(), M, 1); \ - } else { \ - Eigen::Map> Y_(Y.typed_data(), M, nrhs); \ - } \ - if (nrhs == 1) { \ - Eigen::Map Z_(Z->typed_data(), N, 1); \ - } else { \ - Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ - } \ + Eigen::Map> Y_(Y.typed_data(), M, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ Z_.setZero(); \ F_.setZero(); \ diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index b0fb06d..3a9b73a 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -43,7 +43,7 @@ ffi::Error {{mod.name|capitalize}}Impl( ) { {# Dimension aliases for readability #} {% for dim in mod.dimensions %} - const auto {{dim.name}} = dim{{dim.coords[1]}}({{mod.inputs[dim.coords[0]].name}}); + const auto {{dim.name}} = dim<{{dim.coords[1]}}>({{mod.inputs[dim.coords[0]].name}}); {% endfor %} {%- set nrhs_arg = None -%} {%- for a in mod.inputs + mod.outputs + mod.extra_outputs -%} @@ -71,13 +71,9 @@ ffi::Error {{mod.name|capitalize}}Impl( {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J); \ {%- elif arg.shape|length == 2 and arg.shape[1] == "nrhs" %} - if (nrhs == 1) { \ - Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ - } else { \ - Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, nrhs); \ - } \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim<1>({{arg.name}})); \ {%- else %} - Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim1({{arg.name}})); \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim<1>({{arg.name}})); \ {%- endif %} {%- endfor %} {%- for arg in mod.outputs %} @@ -86,11 +82,7 @@ ffi::Error {{mod.name|capitalize}}Impl( {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ {%- elif arg.shape|length == 2 and arg.shape[1] == "nrhs" %} - if (nrhs == 1) { \ - Eigen::Map {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, 1); \ - } else { \ - Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, nrhs); \ - } \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- else %} Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- endif %} @@ -138,14 +130,14 @@ ffi::Error {{mod.name}}_revImpl( {%- endfor %} ) { {% for dim in mod.dimensions %} - const auto {{dim.name}} = dim{{dim.coords[1]}}({{mod.inputs[dim.coords[0]].name}}); + const auto {{dim.name}} = dim<{{dim.coords[1]}}>({{mod.inputs[dim.coords[0]].name}}); {% endfor %} {# Minimal shape checks #} {% for arg in mod.rev_inputs %} {%- if arg.shape|length == 1 %} - if (dim0({{arg.name}}) != {{arg.shape[0]}}) return shape_error("{{mod.name}}_rev shape mismatch"); + if (dim<0>({{arg.name}}) != {{arg.shape[0]}}) return ffi::Error::InvalidArgument("{{mod.name}}_rev shape mismatch"); {%- elif arg.shape|length == 2 %} - if (dim0({{arg.name}}) != {{arg.shape[0]}} || dim1({{arg.name}}) != {{arg.shape[1]}}) return shape_error("{{mod.name}}_rev shape mismatch"); + if (dim<0>({{arg.name}}) != {{arg.shape[0]}} || dim<1>({{arg.name}}) != {{arg.shape[1]}}) return ffi::Error::InvalidArgument("{{mod.name}}_rev shape mismatch"); {%- endif %} {% endfor %} From 7985936fd45e949976738378745b9577046c9c4d Mon Sep 17 00:00:00 2001 From: MilesCranmerBot Date: Wed, 11 Feb 2026 22:39:56 +0000 Subject: [PATCH 22/28] Fix JAX 0.8 primitive impl path and PyMC JAX registration --- python/celerite2/jax/ops.py | 19 +++++++++++++++---- python/celerite2/pymc/__init__.py | 5 ----- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index cc8e7b5..66a40cf 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -24,15 +24,26 @@ from itertools import chain import numpy as np -from jax import core, ffi, lax +from jax import ffi, lax from jax import numpy as jnp from jax.core import ShapedArray from jax.extend.core import Primitive -from jax.interpreters import ad, mlir, xla +from jax.interpreters import ad, mlir from jax.lib import xla_client xla_ops = importlib.import_module("celerite2.jax.xla_ops") +try: + # jax<0.8 compatibility path + from jax.interpreters import xla as _xla_interpreter + + _apply_primitive = _xla_interpreter.apply_primitive +except (ImportError, AttributeError): + # jax>=0.8 moved apply_primitive out of jax.interpreters.xla + from jax._src import dispatch as _dispatch + + _apply_primitive = _dispatch.apply_primitive + def factor(t, c, a, U, V): d, W, S = factor_p.bind(t, c, a, U, V) @@ -179,7 +190,7 @@ def _build_op(name, spec): prim = Primitive(f"celerite2_{spec['name']}") prim.multiple_results = True - prim.def_impl(partial(xla.apply_primitive, prim)) + prim.def_impl(partial(_apply_primitive, prim)) prim.def_abstract_eval(partial(_abstract_eval, spec)) mlir.register_lowering( prim, partial(_lowering_rule, name, spec), platform="cpu" @@ -206,7 +217,7 @@ def _build_op(name, spec): ad.primitive_transposes[jvp_prim] = partial(_jvp_transpose, rev_prim, spec) # Handle reverse pass using custom op - rev_prim.def_impl(partial(xla.apply_primitive, rev_prim)) + rev_prim.def_impl(partial(_apply_primitive, rev_prim)) rev_prim.def_abstract_eval(partial(_rev_abstract_eval, spec)) mlir.register_lowering( rev_prim, diff --git a/python/celerite2/pymc/__init__.py b/python/celerite2/pymc/__init__.py index ee77b01..97a5712 100644 --- a/python/celerite2/pymc/__init__.py +++ b/python/celerite2/pymc/__init__.py @@ -23,8 +23,3 @@ def add_flag(current, new): from celerite2.pymc import terms # noqa from celerite2.pymc.celerite2 import GaussianProcess # noqa - -try: - from celerite2.pymc import jax_support # noqa -except ImportError: - pass From 63c1f05ad5eb864fa60ce931e5ac86942248f701 Mon Sep 17 00:00:00 2001 From: MilesCranmerBot Date: Wed, 11 Feb 2026 23:10:24 +0000 Subject: [PATCH 23/28] fix(jax): avoid removed jax.lib.xla_client import --- python/celerite2/jax/ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index 66a40cf..f3b3f7f 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -29,8 +29,6 @@ from jax.core import ShapedArray from jax.extend.core import Primitive from jax.interpreters import ad, mlir -from jax.lib import xla_client - xla_ops = importlib.import_module("celerite2.jax.xla_ops") try: From 60fe47aaea44d44d2f771cba148acf2786f4e116 Mon Sep 17 00:00:00 2001 From: MilesCranmerBot Date: Wed, 11 Feb 2026 23:26:59 +0000 Subject: [PATCH 24/28] build: require jax in PEP517 build env for JAX extension --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 337e10d..b8726ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] "Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" [build-system] -requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0"] +requires = ["scikit-build-core", "numpy", "pybind11", "jax>=0.8.0", "jaxlib>=0.8.0"] build-backend = "scikit_build_core.build" [tool.scikit-build] From 8922f2aac6772faec7a7904724d84c70b0cf75ca Mon Sep 17 00:00:00 2001 From: MilesCranmerBot Date: Thu, 12 Feb 2026 00:58:08 +0000 Subject: [PATCH 25/28] refactor(jax): drop pre-0.8 apply_primitive fallback --- python/celerite2/jax/ops.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index f3b3f7f..59fe181 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -31,16 +31,9 @@ from jax.interpreters import ad, mlir xla_ops = importlib.import_module("celerite2.jax.xla_ops") -try: - # jax<0.8 compatibility path - from jax.interpreters import xla as _xla_interpreter - - _apply_primitive = _xla_interpreter.apply_primitive -except (ImportError, AttributeError): - # jax>=0.8 moved apply_primitive out of jax.interpreters.xla - from jax._src import dispatch as _dispatch - - _apply_primitive = _dispatch.apply_primitive +# celerite2 requires jax>=0.8.0 (see pyproject.toml), where apply_primitive lives in +# jax._src.dispatch. +from jax._src.dispatch import apply_primitive as _apply_primitive def factor(t, c, a, U, V): From 3f99d459763f44257f201cc700037d8df38af989 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Feb 2026 07:51:43 +0000 Subject: [PATCH 26/28] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/celerite2/jax/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index 59fe181..85bccf3 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -29,6 +29,7 @@ from jax.core import ShapedArray from jax.extend.core import Primitive from jax.interpreters import ad, mlir + xla_ops = importlib.import_module("celerite2.jax.xla_ops") # celerite2 requires jax>=0.8.0 (see pyproject.toml), where apply_primitive lives in From 5fa9d915e840c8ae86c410081069fa948321485b Mon Sep 17 00:00:00 2001 From: MilesCranmerBot Date: Thu, 12 Feb 2026 14:07:44 +0000 Subject: [PATCH 27/28] build: avoid requiring jax on Python<3.11 for docs --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b8726ca..df7f0eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] "Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" [build-system] -requires = ["scikit-build-core", "numpy", "pybind11", "jax>=0.8.0", "jaxlib>=0.8.0"] +requires = ["scikit-build-core", "numpy", "pybind11", "jax>=0.8.0; python_version >= '3.11'", "jaxlib>=0.8.0"] build-backend = "scikit_build_core.build" [tool.scikit-build] From 29780043a0f63e19eca085f32eca5e60291eee3a Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Sat, 21 Feb 2026 15:10:53 -0500 Subject: [PATCH 28/28] Upgrade RTDs Python version from 3.10 to 3.11 --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 6381e94..c53cbeb 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ build: apt_packages: - fonts-liberation tools: - python: "3.10" + python: "3.11" python: install: