diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 3bfda5c..0c0ebdb 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -19,13 +19,12 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: ["3.10"]
+ python-version: ["3.11"]
os: ["ubuntu-latest"]
session:
- "core"
- "jax"
- - "pymc3"
- # - "pymc"
+ - "pymc"
- "pymc_jax"
steps:
@@ -47,7 +46,7 @@ jobs:
- name: Run tests
run: |
- python -m nox --non-interactive --error-on-missing-interpreter \
+ python -m nox --non-interactive \
--session ${{ matrix.session }}-${{ matrix.python-version }}
tests-pymc:
@@ -68,12 +67,12 @@ jobs:
environment-name: test-env
create-args: >-
mamba
- python=3.10
+ python=3.11
- name: Install nox
run: python -m pip install -U nox
- name: Run tests
run: |
- python -m nox --non-interactive --error-on-missing-interpreter \
- --session pymc_mamba-3.10
+ python -m nox --non-interactive \
+ --session pymc_mamba-3.11
diff --git a/.github/workflows/tutorials.yml.off b/.github/workflows/tutorials.yml.off
index 74a3603..cd0278c 100644
--- a/.github/workflows/tutorials.yml.off
+++ b/.github/workflows/tutorials.yml.off
@@ -26,7 +26,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
- python-version: 3.8
+ python-version: "3.11"
- name: Install dependencies
run: |
diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml
index fa3b02e..49ed81a 100644
--- a/.github/workflows/wheels.yml
+++ b/.github/workflows/wheels.yml
@@ -44,7 +44,7 @@ jobs:
- uses: actions/setup-python@v5
name: Install Python
with:
- python-version: "3.10"
+ python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install -U pip
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 6381e94..c53cbeb 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -8,7 +8,7 @@ build:
apt_packages:
- fonts-liberation
tools:
- python: "3.10"
+ python: "3.11"
python:
install:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a4c9f09..dedab48 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -6,6 +6,7 @@ project(
set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)
+find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module)
include_directories(
"c++/include"
@@ -20,6 +21,20 @@ pybind11_add_module(backprop "python/celerite2/backprop.cpp")
target_compile_features(backprop PUBLIC cxx_std_14)
install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
-pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp")
-target_compile_features(xla_ops PUBLIC cxx_std_14)
-install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax")
+option(BUILD_JAX "Build JAX extension (requires jaxlib headers)" ON)
+if(BUILD_JAX)
+ execute_process(
+ COMMAND "${Python_EXECUTABLE}" "-c" "from jax import ffi; print(ffi.include_dir())"
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ OUTPUT_VARIABLE XLA_DIR
+ RESULT_VARIABLE JAXLIB_RES)
+ if(JAXLIB_RES EQUAL 0 AND NOT "${XLA_DIR}" STREQUAL "")
+ message(STATUS "Building JAX extension with XLA include: ${XLA_DIR}")
+ pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp")
+ target_compile_features(xla_ops PUBLIC cxx_std_17)
+ target_include_directories(xla_ops PUBLIC "${XLA_DIR}")
+ install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax")
+ else()
+ message(STATUS "Skipping JAX extension (jax.ffi include_dir not found)")
+ endif()
+endif()
diff --git a/c++/include/celerite2/forward.hpp b/c++/include/celerite2/forward.hpp
index b5b9a39..74e5f4e 100644
--- a/c++/include/celerite2/forward.hpp
+++ b/c++/include/celerite2/forward.hpp
@@ -297,7 +297,7 @@ void general_matmul_lower(const Eigen::MatrixBase &t1, // (
typedef typename LowRank::Scalar Scalar;
typedef typename Eigen::internal::plain_col_type::type CoeffVector;
- typedef typename Eigen::Matrix Inner;
+ typedef typename Eigen::Matrix Inner;
Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols();
@@ -358,7 +358,7 @@ void general_matmul_upper(const Eigen::MatrixBase &t1, // (
typedef typename LowRank::Scalar Scalar;
typedef typename Eigen::internal::plain_col_type::type CoeffVector;
- typedef typename Eigen::Matrix Inner;
+ typedef typename Eigen::Matrix Inner;
Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols();
diff --git a/noxfile.py b/noxfile.py
index 7d9bfe8..f075c02 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -1,6 +1,6 @@
import nox
-ALL_PYTHON_VS = ["3.8", "3.9", "3.10"]
+ALL_PYTHON_VS = ["3.11"]
TEST_CMD = ["python", "-m", "pytest", "-v"]
@@ -23,22 +23,16 @@ def jax(session):
_session_run(session, "python/test/jax")
-@nox.session(python=ALL_PYTHON_VS)
-def pymc3(session):
- session.install(".[test,pymc3]")
- _session_run(session, "python/test/pymc3")
-
-
@nox.session(python=ALL_PYTHON_VS)
def pymc(session):
- session.install(".[test,pymc]")
+ session.install(".[test,pymc,jax]")
_session_run(session, "python/test/pymc")
@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba")
def pymc_mamba(session):
session.conda_install("pymc", channel="conda-forge")
- session.install(".[test,pymc]")
+ session.install(".[test,pymc,jax]")
_session_run(session, "python/test/pymc")
@@ -48,19 +42,6 @@ def pymc_jax(session):
_session_run(session, "python/test/pymc/test_pymc_ops.py")
-@nox.session(python=ALL_PYTHON_VS)
-def full(session):
- session.install(".[test,jax,pymc3,pymc]")
- _session_run(session, "python/test")
-
-
-@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba")
-def full_mamba(session):
- session.conda_install("jax", "pymc3", "pymc", channel="conda-forge")
- session.install(".[test]")
- _session_run(session, "python/test")
-
-
@nox.session
def lint(session):
session.install("pre-commit")
diff --git a/pyproject.toml b/pyproject.toml
index dc9ea68..df7f0eb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "celerite2"
description = "Fast and scalable Gaussian Processes in 1D"
authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }]
readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.11"
license = { text = "MIT License" }
classifiers = [
"Development Status :: 4 - Beta",
@@ -18,10 +18,8 @@ dependencies = ["numpy"]
[project.optional-dependencies]
test = ["pytest", "scipy", "celerite"]
-pymc3 = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"]
-theano = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"]
-pymc = ["pymc>=5.9.2"]
-jax = ["jax"]
+pymc = ["pymc>=5.26.1"]
+jax = ["jax>=0.8.0"]
docs = [
"sphinx",
"sphinx-material",
@@ -31,11 +29,11 @@ docs = [
"matplotlib",
"scipy",
"emcee",
- "pymc>=5",
+ "pymc>=5.26.1",
"tqdm",
"numpyro",
]
-tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"]
+tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"]
[project.urls]
"Homepage" = "https://celerite2.readthedocs.io"
@@ -43,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"]
"Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues"
[build-system]
-requires = ["scikit-build-core", "numpy", "pybind11"]
+requires = ["scikit-build-core", "numpy", "pybind11", "jax>=0.8.0", "jaxlib>=0.8.0"]
build-backend = "scikit_build_core.build"
[tool.scikit-build]
diff --git a/python/celerite2/jax/celerite2.py b/python/celerite2/jax/celerite2.py
index b3cb9fc..a3e8122 100644
--- a/python/celerite2/jax/celerite2.py
+++ b/python/celerite2/jax/celerite2.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
__all__ = ["GaussianProcess", "ConditionalDistribution"]
+from jax import lax
from jax import numpy as np
from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess
@@ -35,7 +36,17 @@ def _do_compute(self, quiet):
self._t, self._c, self._a, self._U, self._V
)
self._log_det = np.sum(np.log(self._d))
- self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi))
+
+ def _bad(_):
+ return -np.inf, np.inf
+
+ def _good(_):
+ return self._log_det, -0.5 * (
+ self._log_det + self._size * np.log(2 * np.pi)
+ )
+
+ bad = np.any(self._d <= 0) | (~np.isfinite(self._log_det))
+ self._log_det, self._norm = lax.cond(bad, _bad, _good, operand=None)
def _check_sorted(self, t):
return t
diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py
index ffbc5f8..85bccf3 100644
--- a/python/celerite2/jax/ops.py
+++ b/python/celerite2/jax/ops.py
@@ -17,22 +17,24 @@
"general_matmul_lower",
"general_matmul_upper",
]
+import importlib
+import importlib.resources as resources
import json
-from collections import OrderedDict
from functools import partial
from itertools import chain
import numpy as np
-import pkg_resources
-from jax import core, lax
+from jax import ffi, lax
from jax import numpy as jnp
from jax.core import ShapedArray
-from jax.interpreters import ad, mlir, xla
-from jax.lib import xla_client
+from jax.extend.core import Primitive
+from jax.interpreters import ad, mlir
-from celerite2.jax import xla_ops
+xla_ops = importlib.import_module("celerite2.jax.xla_ops")
-xops = xla_client.ops
+# celerite2 requires jax>=0.8.0 (see pyproject.toml), where apply_primitive lives in
+# jax._src.dispatch.
+from jax._src.dispatch import apply_primitive as _apply_primitive
def factor(t, c, a, U, V):
@@ -103,20 +105,7 @@ def _abstract_eval(spec, *args):
def _lowering_rule(name, spec, ctx: mlir.LoweringRuleContext, *args):
if any(a.dtype != np.float64 for a in chain(ctx.avals_in, ctx.avals_out)):
raise ValueError(f"{spec['name']} requires float64 precision")
- shapes = [a.shape for a in ctx.avals_in]
- dims = OrderedDict(
- (s["name"], shapes[s["coords"][0]][s["coords"][1]])
- for s in spec["dimensions"]
- )
- return mlir.custom_call(
- name,
- operands=tuple(mlir.ir_constant(np.int32(v)) for v in dims.values())
- + args,
- operand_layouts=[()] * len(dims)
- + _default_layouts(aval.shape for aval in ctx.avals_in),
- result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
- result_layouts=_default_layouts(aval.shape for aval in ctx.avals_out),
- ).results
+ return ffi.ffi_lowering(name)(ctx, *args)
def _default_layouts(shapes):
@@ -184,13 +173,16 @@ def _rev_lowering_rule(name, spec, ctx, *args):
def _build_op(name, spec):
- xla_client.register_custom_call_target(
- name, getattr(xla_ops, spec["name"])(), platform="cpu"
+ ffi.register_ffi_target(
+ name,
+ getattr(xla_ops, spec["name"])(),
+ platform="cpu",
+ api_version=1,
)
- prim = core.Primitive(f"celerite2_{spec['name']}")
+ prim = Primitive(f"celerite2_{spec['name']}")
prim.multiple_results = True
- prim.def_impl(partial(xla.apply_primitive, prim))
+ prim.def_impl(partial(_apply_primitive, prim))
prim.def_abstract_eval(partial(_abstract_eval, spec))
mlir.register_lowering(
prim, partial(_lowering_rule, name, spec), platform="cpu"
@@ -199,15 +191,16 @@ def _build_op(name, spec):
if not spec["has_rev"]:
return prim, None
- xla_client.register_custom_call_target(
+ ffi.register_ffi_target(
name + "_rev",
getattr(xla_ops, f"{spec['name']}_rev")(),
platform="cpu",
+ api_version=1,
)
- jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp")
+ jvp_prim = Primitive(f"celerite2_{spec['name']}_jvp")
jvp_prim.multiple_results = True
- rev_prim = core.Primitive(f"celerite2_{spec['name']}_rev")
+ rev_prim = Primitive(f"celerite2_{spec['name']}_rev")
rev_prim.multiple_results = True
# Setup a dummy JVP rule
@@ -216,7 +209,7 @@ def _build_op(name, spec):
ad.primitive_transposes[jvp_prim] = partial(_jvp_transpose, rev_prim, spec)
# Handle reverse pass using custom op
- rev_prim.def_impl(partial(xla.apply_primitive, rev_prim))
+ rev_prim.def_impl(partial(_apply_primitive, rev_prim))
rev_prim.def_abstract_eval(partial(_rev_abstract_eval, spec))
mlir.register_lowering(
rev_prim,
@@ -227,9 +220,7 @@ def _build_op(name, spec):
return prim, rev_prim
-with open(
- pkg_resources.resource_filename("celerite2", "definitions.json"), "r"
-) as f:
+with resources.files("celerite2").joinpath("definitions.json").open("r") as f:
definitions = {spec["name"]: spec for spec in json.load(f)}
diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp
index 73a2a55..5d44178 100644
--- a/python/celerite2/jax/xla_ops.cpp
+++ b/python/celerite2/jax/xla_ops.cpp
@@ -1,579 +1,1011 @@
// NOTE: This file was autogenerated
// NOTE: Changes should be made to the template
+// Generated JAX FFI bindings for celerite2.
+// Regenerate with: python python/spec/generate.py
+
#include
-#include
-#include
+#include
#include
-#include
+
+#include "xla/ffi/api/ffi.h"
+
#include "../driver.hpp"
namespace py = pybind11;
+namespace ffi = xla::ffi;
using namespace celerite2::driver;
+// Helpers
+template
+inline Eigen::Index dim(const Buffer& buf) {
+ return static_cast(buf.dimensions()[Axis]);
+}
+template
+inline Eigen::Index flat_cols(const Buffer& buf) {
+ const auto& dims = buf.dimensions();
+ Eigen::Index cols = 1;
+ for (size_t i = 1; i < dims.size(); ++i)
+ cols *= static_cast(dims[i]);
+ return cols;
+}
-auto factor (void *out_tuple, const void **in) {
- void **out = reinterpret_cast(out_tuple);
- const int N = *reinterpret_cast(in[0]);
- const int J = *reinterpret_cast(in[1]);
+// === AUTO-GENERATED KERNELS ===
- const double *t = reinterpret_cast(in[2]);
- const double *c = reinterpret_cast(in[3]);
- const double *a = reinterpret_cast(in[4]);
- const double *U = reinterpret_cast(in[5]);
- const double *V = reinterpret_cast(in[6]);
- double *d = reinterpret_cast(out[0]);
- double *W = reinterpret_cast(out[1]);
- double *S = reinterpret_cast(out[2]);
-#define FIXED_SIZE_MAP(SIZE) \
- { \
- Eigen::Map t_(t, N, 1); \
- Eigen::Map> c_(c, J, 1); \
- Eigen::Map a_(a, N, 1); \
- Eigen::Map::value>> U_(U, N, J); \
- Eigen::Map::value>> V_(V, N, J); \
- Eigen::Map d_(d, N, 1); \
- Eigen::Map::value>> W_(W, N, J); \
- Eigen::Map::value>> S_(S, N, J * J); \
- Eigen::Index flag = celerite2::core::factor(t_, c_, a_, U_, V_, d_, W_, S_); \
- if (flag) d_.setZero(); \
- }
- UNWRAP_CASES_MOST
-#undef FIXED_SIZE_MAP
-}
-auto factor_rev (void *out_tuple, const void **in) {
- void **out = reinterpret_cast(out_tuple);
- const int N = *reinterpret_cast(in[0]);
- const int J = *reinterpret_cast(in[1]);
-
- const double *t = reinterpret_cast(in[2]);
- const double *c = reinterpret_cast(in[3]);
- const double *a = reinterpret_cast(in[4]);
- const double *U = reinterpret_cast(in[5]);
- const double *V = reinterpret_cast(in[6]);
- const double *d = reinterpret_cast(in[7]);
- const double *W = reinterpret_cast(in[8]);
- const double *S = reinterpret_cast(in[9]);
- const double *bd = reinterpret_cast(in[10]);
- const double *bW = reinterpret_cast(in[11]);
- double *bt = reinterpret_cast(out[0]);
- double *bc = reinterpret_cast(out[1]);
- double *ba = reinterpret_cast(out[2]);
- double *bU = reinterpret_cast(out[3]);
- double *bV = reinterpret_cast(out[4]);
+ffi::Error FactorImpl(
+ ffi::Buffer t,
+ ffi::Buffer c,
+ ffi::Buffer a,
+ ffi::Buffer U,
+ ffi::Buffer V,
+ ffi::ResultBuffer d,
+ ffi::ResultBuffer W,
+ ffi::ResultBuffer S
+) {
+
+
+ const auto N = dim<0>(t);
+
+ const auto J = dim<0>(c);
+
+
+
+ if (dim<0>(t) != N) return ffi::Error::InvalidArgument("factor shape mismatch");
+
+ if (dim<0>(c) != J) return ffi::Error::InvalidArgument("factor shape mismatch");
+
+ if (dim<0>(a) != N) return ffi::Error::InvalidArgument("factor shape mismatch");
+
+ if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("factor shape mismatch");
+
+ if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("factor shape mismatch");
+
#define FIXED_SIZE_MAP(SIZE) \
- { \
- Eigen::Map t_(t, N, 1); \
- Eigen::Map> c_(c, J, 1); \
- Eigen::Map a_(a, N, 1); \
- Eigen::Map::value>> U_(U, N, J); \
- Eigen::Map::value>> V_(V, N, J); \
- Eigen::Map d_(d, N, 1); \
- Eigen::Map::value>> W_(W, N, J); \
- Eigen::Map::value>> S_(S, N, J * J); \
- Eigen::Map bd_(bd, N, 1); \
- Eigen::Map::value>> bW_(bW, N, J); \
- Eigen::Map bt_(bt, N, 1); \
- Eigen::Map> bc_(bc, J, 1); \
- Eigen::Map ba_(ba, N, 1); \
- Eigen::Map::value>> bU_(bU, N, J); \
- Eigen::Map::value>> bV_(bV, N, J); \
- celerite2::core::factor_rev(t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_, bc_, ba_, bU_, bV_); \
- }
- UNWRAP_CASES_FEW
+ { \
+ Eigen::Map t_(t.typed_data(), N, 1); \
+ Eigen::Map c_(c.typed_data(), J, 1); \
+ Eigen::Map a_(a.typed_data(), N, 1); \
+ Eigen::Map::value>> U_(U.typed_data(), N, J); \
+ Eigen::Map::value>> V_(V.typed_data(), N, J); \
+ Eigen::Map d_(d->typed_data(), N, 1); \
+ Eigen::Map::value>> W_(W->typed_data(), N, J); \
+ Eigen::Map::value>> S_(S->typed_data(), N, J * J); \
+ d_.setZero(); \
+ W_.setZero(); \
+ S_.setZero(); \
+ try { \
+ celerite2::core::factor( t_, c_, a_, U_, V_, d_,W_,S_); \
+ } catch (const std::exception& e) { \
+ return ffi::Error::Internal(e.what()); \
+ } \
+ }
+ UNWRAP_CASES_FEW
#undef FIXED_SIZE_MAP
+
+ return ffi::Error::Success();
}
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ factor, FactorImpl,
+ ffi::Ffi::Bind()
+ .Arg>() // t
+ .Arg>() // c
+ .Arg>() // a
+ .Arg>() // U
+ .Arg>() // V
+ .Ret>() // d
+ .Ret>() // W
+ .Ret>() // S
+);
+
+
+ffi::Error factor_revImpl(
+ ffi::Buffer t,
+ ffi::Buffer c,
+ ffi::Buffer a,
+ ffi::Buffer U,
+ ffi::Buffer V,
+ ffi::Buffer d,
+ ffi::Buffer W,
+ ffi::Buffer S,
+ ffi::Buffer bd,
+ ffi::Buffer bW,
+ ffi::ResultBuffer bt,
+ ffi::ResultBuffer bc,
+ ffi::ResultBuffer ba,
+ ffi::ResultBuffer bU,
+ ffi::ResultBuffer bV
+) {
+
+ const auto N = dim<0>(t);
+
+ const auto J = dim<0>(c);
+
+
-auto solve_lower (void *out_tuple, const void **in) {
- void **out = reinterpret_cast(out_tuple);
- const int N = *reinterpret_cast(in[0]);
- const int J = *reinterpret_cast(in[1]);
- const int nrhs = *reinterpret_cast(in[2]);
+ if (dim<0>(t) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+ if (dim<0>(c) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+ if (dim<0>(a) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+ if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+ if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+ if (dim<0>(d) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+ if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+
+ if (dim<0>(bd) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
+
+ if (dim<0>(bW) != N || dim<1>(bW) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch");
- const double *t = reinterpret_cast(in[3]);
- const double *c = reinterpret_cast(in[4]);
- const double *U = reinterpret_cast(in[5]);
- const double *W = reinterpret_cast(in[6]);
- const double *Y = reinterpret_cast(in[7]);
- double *Z = reinterpret_cast(out[0]);
- double *F = reinterpret_cast(out[1]);
#define FIXED_SIZE_MAP(SIZE) \
- { \
- Eigen::Map t_(t, N, 1); \
- Eigen::Map> c_(c, J, 1); \
- Eigen::Map::value>> U_(U, N, J); \
- Eigen::Map::value>> W_(W, N, J); \
- if (nrhs == 1) { \
- Eigen::Map Y_(Y, N, 1); \
- Eigen::Map Z_(Z, N, 1); \
- Eigen::Map::value>> F_(F, N, J); \
- Z_.setZero(); \
- celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \
- } else { \
- Eigen::Map> Y_(Y, N, nrhs); \
- Eigen::Map> Z_(Z, N, nrhs); \
- Eigen::Map> F_(F, N, J * nrhs); \
- Z_.setZero(); \
- celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \
+ { \
+ Eigen::Map t_(t.typed_data(), N, 1); \
+ Eigen::Map c_(c.typed_data(), J, 1); \
+ Eigen::Map a_(a.typed_data(), N, 1); \
+ Eigen::Map::value>> U_(U.typed_data(), N, J); \
+ Eigen::Map::value>> V_(V.typed_data(), N, J); \
+ Eigen::Map d_(d.typed_data(), N, 1); \
+ Eigen::Map::value>> W_(W.typed_data(), N, J); \
+ Eigen::Map::value>> S_(S.typed_data(), N, J * J); \
+ Eigen::Map bd_(bd.typed_data(), N, 1); \
+ Eigen::Map::value>> bW_(bW.typed_data(), N, J); \
+ Eigen::Map bt_(bt->typed_data(), N, 1); \
+ Eigen::Map bc_(bc->typed_data(), J, 1); \
+ Eigen::Map ba_(ba->typed_data(), N, 1); \
+ Eigen::Map::value>> bU_(bU->typed_data(), N, J); \
+ Eigen::Map::value>> bV_(bV->typed_data(), N, J); \
+ bt_.setZero(); \
+ bc_.setZero(); \
+ ba_.setZero(); \
+ bU_.setZero(); \
+ bV_.setZero(); \
+ try { \
+ celerite2::core::factor_rev( t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_,bc_,ba_,bU_,bV_); \
+ } catch (const std::exception& e) { \
+ return ffi::Error::Internal(e.what()); \
} \
- }
- UNWRAP_CASES_MOST
+ }
+ UNWRAP_CASES_FEW
#undef FIXED_SIZE_MAP
+
+ return ffi::Error::Success();
}
-auto solve_lower_rev (void *out_tuple, const void **in) {
- void **out = reinterpret_cast(out_tuple);
- const int N = *reinterpret_cast(in[0]);
- const int J = *reinterpret_cast(in[1]);
- const int nrhs = *reinterpret_cast(in[2]);
-
- const double *t = reinterpret_cast(in[3]);
- const double *c = reinterpret_cast(in[4]);
- const double *U = reinterpret_cast(in[5]);
- const double *W = reinterpret_cast(in[6]);
- const double *Y = reinterpret_cast(in[7]);
- const double *Z = reinterpret_cast(in[8]);
- const double *F = reinterpret_cast(in[9]);
- const double *bZ = reinterpret_cast(in[10]);
- double *bt = reinterpret_cast(out[0]);
- double *bc = reinterpret_cast(out[1]);
- double *bU = reinterpret_cast(out[2]);
- double *bW = reinterpret_cast(out[3]);
- double *bY = reinterpret_cast(out[4]);
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ factor_rev, factor_revImpl,
+ ffi::Ffi::Bind()
+ .Arg>() // t
+ .Arg>() // c
+ .Arg>() // a
+ .Arg>() // U
+ .Arg>() // V
+ .Arg>() // d
+ .Arg>() // W
+ .Arg>() // S
+ .Arg>() // bd
+ .Arg>() // bW
+ .Ret>() // bt
+ .Ret>() // bc
+ .Ret>() // ba
+ .Ret>() // bU
+ .Ret>() // bV
+);
+
+
+
+ffi::Error Solve_lowerImpl(
+ ffi::Buffer t,
+ ffi::Buffer c,
+ ffi::Buffer U,
+ ffi::Buffer W,
+ ffi::Buffer Y,
+ ffi::ResultBuffer Z,
+ ffi::ResultBuffer F
+) {
+
+
+ const auto N = dim<0>(t);
+
+ const auto J = dim<0>(c);
+
+ const auto nrhs = dim<1>(Y);
+
+
+
+ if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_lower shape mismatch");
+
+ if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch");
+
+ if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch");
+
+ if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch");
+
+ if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower shape mismatch");
+
#define FIXED_SIZE_MAP(SIZE) \
- { \
- Eigen::Map t_(t, N, 1); \
- Eigen::Map> c_(c, J, 1); \
- Eigen::Map::value>> U_(U, N, J); \
- Eigen::Map::value>> W_(W, N, J); \
- Eigen::Map bt_(bt, N, 1); \
- Eigen::Map> bc_(bc, J, 1); \
- Eigen::Map::value>> bU_(bU, N, J); \
- Eigen::Map::value>> bW_(bW, N, J); \
- if (nrhs == 1) { \
- Eigen::Map Y_(Y, N, 1); \
- Eigen::Map Z_(Z, N, 1); \
- Eigen::Map::value>> F_(F, N, J); \
- Eigen::Map bZ_(bZ, N, 1); \
- Eigen::Map bY_(bY, N, 1); \
- celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \
- } else { \
- Eigen::Map> Y_(Y, N, nrhs); \
- Eigen::Map> Z_(Z, N, nrhs); \
- Eigen::Map> F_(F, N, J * nrhs); \
- Eigen::Map> bZ_(bZ, N, nrhs); \
- Eigen::Map> bY_(bY, N, nrhs); \
- celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \
+ { \
+ Eigen::Map t_(t.typed_data(), N, 1); \
+ Eigen::Map c_(c.typed_data(), J, 1); \
+ Eigen::Map::value>> U_(U.typed_data(), N, J); \
+ Eigen::Map::value>> W_(W.typed_data(), N, J); \
+ Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \
+ Eigen::Map> Z_(Z->typed_data(), N, nrhs); \
+ Eigen::Map> F_(F->typed_data(), N, J*nrhs); \
+ Z_.setZero(); \
+ F_.setZero(); \
+ try { \
+ celerite2::core::solve_lower( t_, c_, U_, W_, Y_, Z_,F_); \
+ } catch (const std::exception& e) { \
+ return ffi::Error::Internal(e.what()); \
} \
- }
- UNWRAP_CASES_FEW
+ }
+ UNWRAP_CASES_FEW
#undef FIXED_SIZE_MAP
+
+ return ffi::Error::Success();
}
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ solve_lower, Solve_lowerImpl,
+ ffi::Ffi::Bind()
+ .Arg>() // t
+ .Arg>() // c
+ .Arg>() // U
+ .Arg>() // W
+ .Arg>() // Y
+ .Ret>() // Z
+ .Ret>() // F
+);
+
+
+ffi::Error solve_lower_revImpl(
+ ffi::Buffer t,
+ ffi::Buffer c,
+ ffi::Buffer U,
+ ffi::Buffer W,
+ ffi::Buffer Y,
+ ffi::Buffer Z,
+ ffi::Buffer F,
+ ffi::Buffer bZ,
+ ffi::ResultBuffer bt,
+ ffi::ResultBuffer bc,
+ ffi::ResultBuffer bU,
+ ffi::ResultBuffer bW,
+ ffi::ResultBuffer bY
+) {
+
+ const auto N = dim<0>(t);
+
+ const auto J = dim<0>(c);
+
+ const auto nrhs = dim<1>(Y);
+
+
+
+ if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch");
+
+ if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch");
+
+ if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch");
+
+ if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch");
+
+ if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch");
+
+ if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch");
-auto solve_upper (void *out_tuple, const void **in) {
- void **out = reinterpret_cast(out_tuple);
- const int N = *reinterpret_cast(in[0]);
- const int J = *reinterpret_cast(in[1]);
- const int nrhs = *reinterpret_cast(in[2]);
- const double *t = reinterpret_cast(in[3]);
- const double *c = reinterpret_cast(in[4]);
- const double *U = reinterpret_cast(in[5]);
- const double *W = reinterpret_cast(in[6]);
- const double *Y = reinterpret_cast(in[7]);
- double *Z = reinterpret_cast(out[0]);
- double *F = reinterpret_cast(out[1]);
+ if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch");
+
#define FIXED_SIZE_MAP(SIZE) \
- { \
- Eigen::Map t_(t, N, 1); \
- Eigen::Map> c_(c, J, 1); \
- Eigen::Map::value>> U_(U, N, J); \
- Eigen::Map::value>> W_(W, N, J); \
- if (nrhs == 1) { \
- Eigen::Map Y_(Y, N, 1); \
- Eigen::Map Z_(Z, N, 1); \
- Eigen::Map::value>> F_(F, N, J); \
- Z_.setZero(); \
- celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \
- } else { \
- Eigen::Map> Y_(Y, N, nrhs); \
- Eigen::Map> Z_(Z, N, nrhs); \
- Eigen::Map> F_(F, N, J * nrhs); \
- Z_.setZero(); \
- celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \
+ { \
+ Eigen::Map t_(t.typed_data(), N, 1); \
+ Eigen::Map c_(c.typed_data(), J, 1); \
+ Eigen::Map::value>> U_(U.typed_data(), N, J); \
+ Eigen::Map::value>> W_(W.typed_data(), N, J); \
+ Eigen::Map> Y_(Y.typed_data(), N, nrhs); \
+ Eigen::Map> Z_(Z.typed_data(), N, nrhs); \
+ Eigen::Map