diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 3bfda5c..0c0ebdb 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -19,13 +19,12 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.11"] os: ["ubuntu-latest"] session: - "core" - "jax" - - "pymc3" - # - "pymc" + - "pymc" - "pymc_jax" steps: @@ -47,7 +46,7 @@ jobs: - name: Run tests run: | - python -m nox --non-interactive --error-on-missing-interpreter \ + python -m nox --non-interactive \ --session ${{ matrix.session }}-${{ matrix.python-version }} tests-pymc: @@ -68,12 +67,12 @@ jobs: environment-name: test-env create-args: >- mamba - python=3.10 + python=3.11 - name: Install nox run: python -m pip install -U nox - name: Run tests run: | - python -m nox --non-interactive --error-on-missing-interpreter \ - --session pymc_mamba-3.10 + python -m nox --non-interactive \ + --session pymc_mamba-3.11 diff --git a/.github/workflows/tutorials.yml.off b/.github/workflows/tutorials.yml.off index 74a3603..cd0278c 100644 --- a/.github/workflows/tutorials.yml.off +++ b/.github/workflows/tutorials.yml.off @@ -26,7 +26,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.11 - name: Install dependencies run: | diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index fa3b02e..49ed81a 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -44,7 +44,7 @@ jobs: - uses: actions/setup-python@v5 name: Install Python with: - python-version: "3.10" + python-version: "3.11" - name: Install dependencies run: | python -m pip install -U pip diff --git a/CMakeLists.txt b/CMakeLists.txt index a4c9f09..dedab48 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,7 @@ project( set(PYBIND11_NEWPYTHON ON) find_package(pybind11 CONFIG REQUIRED) +find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) include_directories( "c++/include" @@ -20,6 +21,20 @@ pybind11_add_module(backprop "python/celerite2/backprop.cpp") target_compile_features(backprop PUBLIC cxx_std_14) install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) -pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp") -target_compile_features(xla_ops PUBLIC cxx_std_14) -install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax") +option(BUILD_JAX "Build JAX extension (requires jaxlib headers)" ON) +if(BUILD_JAX) + execute_process( + COMMAND "${Python_EXECUTABLE}" "-c" "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE XLA_DIR + RESULT_VARIABLE JAXLIB_RES) + if(JAXLIB_RES EQUAL 0 AND NOT "${XLA_DIR}" STREQUAL "") + message(STATUS "Building JAX extension with XLA include: ${XLA_DIR}") + pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp") + target_compile_features(xla_ops PUBLIC cxx_std_17) + target_include_directories(xla_ops PUBLIC "${XLA_DIR}") + install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax") + else() + message(STATUS "Skipping JAX extension (jax.ffi include_dir not found)") + endif() +endif() diff --git a/c++/include/celerite2/forward.hpp b/c++/include/celerite2/forward.hpp index b5b9a39..74e5f4e 100644 --- a/c++/include/celerite2/forward.hpp +++ b/c++/include/celerite2/forward.hpp @@ -297,7 +297,7 @@ void general_matmul_lower(const Eigen::MatrixBase &t1, // ( typedef typename LowRank::Scalar Scalar; typedef typename Eigen::internal::plain_col_type::type CoeffVector; - typedef typename Eigen::Matrix Inner; + typedef typename Eigen::Matrix Inner; Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols(); @@ -358,7 +358,7 @@ void general_matmul_upper(const Eigen::MatrixBase &t1, // ( typedef typename LowRank::Scalar Scalar; typedef typename Eigen::internal::plain_col_type::type CoeffVector; - typedef typename Eigen::Matrix Inner; + typedef typename Eigen::Matrix Inner; Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols(); diff --git a/noxfile.py b/noxfile.py index 7d9bfe8..f075c02 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,6 +1,6 @@ import nox -ALL_PYTHON_VS = ["3.8", "3.9", "3.10"] +ALL_PYTHON_VS = ["3.11"] TEST_CMD = ["python", "-m", "pytest", "-v"] @@ -23,22 +23,16 @@ def jax(session): _session_run(session, "python/test/jax") -@nox.session(python=ALL_PYTHON_VS) -def pymc3(session): - session.install(".[test,pymc3]") - _session_run(session, "python/test/pymc3") - - @nox.session(python=ALL_PYTHON_VS) def pymc(session): - session.install(".[test,pymc]") + session.install(".[test,pymc,jax]") _session_run(session, "python/test/pymc") @nox.session(python=ALL_PYTHON_VS, venv_backend="mamba") def pymc_mamba(session): session.conda_install("pymc", channel="conda-forge") - session.install(".[test,pymc]") + session.install(".[test,pymc,jax]") _session_run(session, "python/test/pymc") @@ -48,19 +42,6 @@ def pymc_jax(session): _session_run(session, "python/test/pymc/test_pymc_ops.py") -@nox.session(python=ALL_PYTHON_VS) -def full(session): - session.install(".[test,jax,pymc3,pymc]") - _session_run(session, "python/test") - - -@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba") -def full_mamba(session): - session.conda_install("jax", "pymc3", "pymc", channel="conda-forge") - session.install(".[test]") - _session_run(session, "python/test") - - @nox.session def lint(session): session.install("pre-commit") diff --git a/pyproject.toml b/pyproject.toml index dc9ea68..337e10d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "celerite2" description = "Fast and scalable Gaussian Processes in 1D" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.11" license = { text = "MIT License" } classifiers = [ "Development Status :: 4 - Beta", @@ -18,10 +18,8 @@ dependencies = ["numpy"] [project.optional-dependencies] test = ["pytest", "scipy", "celerite"] -pymc3 = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"] -theano = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"] -pymc = ["pymc>=5.9.2"] -jax = ["jax"] +pymc = ["pymc>=5.26.1"] +jax = ["jax>=0.8.0"] docs = [ "sphinx", "sphinx-material", @@ -31,11 +29,11 @@ docs = [ "matplotlib", "scipy", "emcee", - "pymc>=5", + "pymc>=5.26.1", "tqdm", "numpyro", ] -tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"] +tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"] [project.urls] "Homepage" = "https://celerite2.readthedocs.io" @@ -43,7 +41,7 @@ tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"] "Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues" [build-system] -requires = ["scikit-build-core", "numpy", "pybind11"] +requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0"] build-backend = "scikit_build_core.build" [tool.scikit-build] diff --git a/python/celerite2/jax/celerite2.py b/python/celerite2/jax/celerite2.py index b3cb9fc..a3e8122 100644 --- a/python/celerite2/jax/celerite2.py +++ b/python/celerite2/jax/celerite2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- __all__ = ["GaussianProcess", "ConditionalDistribution"] +from jax import lax from jax import numpy as np from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess @@ -35,7 +36,17 @@ def _do_compute(self, quiet): self._t, self._c, self._a, self._U, self._V ) self._log_det = np.sum(np.log(self._d)) - self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi)) + + def _bad(_): + return -np.inf, np.inf + + def _good(_): + return self._log_det, -0.5 * ( + self._log_det + self._size * np.log(2 * np.pi) + ) + + bad = np.any(self._d <= 0) | (~np.isfinite(self._log_det)) + self._log_det, self._norm = lax.cond(bad, _bad, _good, operand=None) def _check_sorted(self, t): return t diff --git a/python/celerite2/jax/ops.py b/python/celerite2/jax/ops.py index ffbc5f8..cc8e7b5 100644 --- a/python/celerite2/jax/ops.py +++ b/python/celerite2/jax/ops.py @@ -17,22 +17,21 @@ "general_matmul_lower", "general_matmul_upper", ] +import importlib +import importlib.resources as resources import json -from collections import OrderedDict from functools import partial from itertools import chain import numpy as np -import pkg_resources -from jax import core, lax +from jax import core, ffi, lax from jax import numpy as jnp from jax.core import ShapedArray +from jax.extend.core import Primitive from jax.interpreters import ad, mlir, xla from jax.lib import xla_client -from celerite2.jax import xla_ops - -xops = xla_client.ops +xla_ops = importlib.import_module("celerite2.jax.xla_ops") def factor(t, c, a, U, V): @@ -103,20 +102,7 @@ def _abstract_eval(spec, *args): def _lowering_rule(name, spec, ctx: mlir.LoweringRuleContext, *args): if any(a.dtype != np.float64 for a in chain(ctx.avals_in, ctx.avals_out)): raise ValueError(f"{spec['name']} requires float64 precision") - shapes = [a.shape for a in ctx.avals_in] - dims = OrderedDict( - (s["name"], shapes[s["coords"][0]][s["coords"][1]]) - for s in spec["dimensions"] - ) - return mlir.custom_call( - name, - operands=tuple(mlir.ir_constant(np.int32(v)) for v in dims.values()) - + args, - operand_layouts=[()] * len(dims) - + _default_layouts(aval.shape for aval in ctx.avals_in), - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - result_layouts=_default_layouts(aval.shape for aval in ctx.avals_out), - ).results + return ffi.ffi_lowering(name)(ctx, *args) def _default_layouts(shapes): @@ -184,11 +170,14 @@ def _rev_lowering_rule(name, spec, ctx, *args): def _build_op(name, spec): - xla_client.register_custom_call_target( - name, getattr(xla_ops, spec["name"])(), platform="cpu" + ffi.register_ffi_target( + name, + getattr(xla_ops, spec["name"])(), + platform="cpu", + api_version=1, ) - prim = core.Primitive(f"celerite2_{spec['name']}") + prim = Primitive(f"celerite2_{spec['name']}") prim.multiple_results = True prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(partial(_abstract_eval, spec)) @@ -199,15 +188,16 @@ def _build_op(name, spec): if not spec["has_rev"]: return prim, None - xla_client.register_custom_call_target( + ffi.register_ffi_target( name + "_rev", getattr(xla_ops, f"{spec['name']}_rev")(), platform="cpu", + api_version=1, ) - jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp") + jvp_prim = Primitive(f"celerite2_{spec['name']}_jvp") jvp_prim.multiple_results = True - rev_prim = core.Primitive(f"celerite2_{spec['name']}_rev") + rev_prim = Primitive(f"celerite2_{spec['name']}_rev") rev_prim.multiple_results = True # Setup a dummy JVP rule @@ -227,9 +217,7 @@ def _build_op(name, spec): return prim, rev_prim -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with resources.files("celerite2").joinpath("definitions.json").open("r") as f: definitions = {spec["name"]: spec for spec in json.load(f)} diff --git a/python/celerite2/jax/xla_ops.cpp b/python/celerite2/jax/xla_ops.cpp index 73a2a55..5d44178 100644 --- a/python/celerite2/jax/xla_ops.cpp +++ b/python/celerite2/jax/xla_ops.cpp @@ -1,579 +1,1011 @@ // NOTE: This file was autogenerated // NOTE: Changes should be made to the template +// Generated JAX FFI bindings for celerite2. +// Regenerate with: python python/spec/generate.py + #include -#include -#include +#include #include -#include + +#include "xla/ffi/api/ffi.h" + #include "../driver.hpp" namespace py = pybind11; +namespace ffi = xla::ffi; using namespace celerite2::driver; +// Helpers +template +inline Eigen::Index dim(const Buffer& buf) { + return static_cast(buf.dimensions()[Axis]); +} +template +inline Eigen::Index flat_cols(const Buffer& buf) { + const auto& dims = buf.dimensions(); + Eigen::Index cols = 1; + for (size_t i = 1; i < dims.size(); ++i) + cols *= static_cast(dims[i]); + return cols; +} -auto factor (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); +// === AUTO-GENERATED KERNELS === - const double *t = reinterpret_cast(in[2]); - const double *c = reinterpret_cast(in[3]); - const double *a = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - double *d = reinterpret_cast(out[0]); - double *W = reinterpret_cast(out[1]); - double *S = reinterpret_cast(out[2]); -#define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map a_(a, N, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map d_(d, N, 1); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map::value>> S_(S, N, J * J); \ - Eigen::Index flag = celerite2::core::factor(t_, c_, a_, U_, V_, d_, W_, S_); \ - if (flag) d_.setZero(); \ - } - UNWRAP_CASES_MOST -#undef FIXED_SIZE_MAP -} -auto factor_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - - const double *t = reinterpret_cast(in[2]); - const double *c = reinterpret_cast(in[3]); - const double *a = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *d = reinterpret_cast(in[7]); - const double *W = reinterpret_cast(in[8]); - const double *S = reinterpret_cast(in[9]); - const double *bd = reinterpret_cast(in[10]); - const double *bW = reinterpret_cast(in[11]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *ba = reinterpret_cast(out[2]); - double *bU = reinterpret_cast(out[3]); - double *bV = reinterpret_cast(out[4]); +ffi::Error FactorImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer a, + ffi::Buffer U, + ffi::Buffer V, + ffi::ResultBuffer d, + ffi::ResultBuffer W, + ffi::ResultBuffer S +) { + + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); + + if (dim<0>(a) != N) return ffi::Error::InvalidArgument("factor shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); + + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("factor shape mismatch"); + #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map a_(a, N, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map d_(d, N, 1); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map::value>> S_(S, N, J * J); \ - Eigen::Map bd_(bd, N, 1); \ - Eigen::Map::value>> bW_(bW, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map ba_(ba, N, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - celerite2::core::factor_rev(t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_, bc_, ba_, bU_, bV_); \ - } - UNWRAP_CASES_FEW + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map a_(a.typed_data(), N, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map d_(d->typed_data(), N, 1); \ + Eigen::Map::value>> W_(W->typed_data(), N, J); \ + Eigen::Map::value>> S_(S->typed_data(), N, J * J); \ + d_.setZero(); \ + W_.setZero(); \ + S_.setZero(); \ + try { \ + celerite2::core::factor( t_, c_, a_, U_, V_, d_,W_,S_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL( + factor, FactorImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // a + .Arg>() // U + .Arg>() // V + .Ret>() // d + .Ret>() // W + .Ret>() // S +); + + +ffi::Error factor_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer a, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer d, + ffi::Buffer W, + ffi::Buffer S, + ffi::Buffer bd, + ffi::Buffer bW, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer ba, + ffi::ResultBuffer bU, + ffi::ResultBuffer bV +) { + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + -auto solve_lower (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + if (dim<0>(a) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + if (dim<0>(d) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + + if (dim<0>(bd) != N) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); + + if (dim<0>(bW) != N || dim<1>(bW) != J) return ffi::Error::InvalidArgument("factor_rev shape mismatch"); - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_lower(t_, c_, U_, W_, Y_, Z_, F_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map a_(a.typed_data(), N, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map d_(d.typed_data(), N, 1); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map::value>> S_(S.typed_data(), N, J * J); \ + Eigen::Map bd_(bd.typed_data(), N, 1); \ + Eigen::Map::value>> bW_(bW.typed_data(), N, J); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map ba_(ba->typed_data(), N, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ + bt_.setZero(); \ + bc_.setZero(); \ + ba_.setZero(); \ + bU_.setZero(); \ + bV_.setZero(); \ + try { \ + celerite2::core::factor_rev( t_, c_, a_, U_, V_, d_, W_, S_, bd_, bW_, bt_,bc_,ba_,bU_,bV_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_MOST + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto solve_lower_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bW = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + factor_rev, factor_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // a + .Arg>() // U + .Arg>() // V + .Arg>() // d + .Arg>() // W + .Arg>() // S + .Arg>() // bd + .Arg>() // bW + .Ret>() // bt + .Ret>() // bc + .Ret>() // ba + .Ret>() // bU + .Ret>() // bV +); + + + +ffi::Error Solve_lowerImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower shape mismatch"); + #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bW_(bW, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::solve_lower_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ + Z_.setZero(); \ + F_.setZero(); \ + try { \ + celerite2::core::solve_lower( t_, c_, U_, W_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_FEW + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_lower, Solve_lowerImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error solve_lower_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bW, + ffi::ResultBuffer bY +) { + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); + + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); + + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); -auto solve_upper (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("solve_lower_rev shape mismatch"); + #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::solve_upper(t_, c_, U_, W_, Y_, Z_, F_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bW_(bW->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bW_.setZero(); \ + bY_.setZero(); \ + try { \ + celerite2::core::solve_lower_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_MOST + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto solve_upper_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *W = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bW = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_lower_rev, solve_lower_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bW + .Ret>() // bY +); + + + +ffi::Error Solve_upperImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_upper shape mismatch"); + #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> W_(W, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bW_(bW, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::solve_upper_rev(t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bW_, bY_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ + Z_.setZero(); \ + F_.setZero(); \ + try { \ + celerite2::core::solve_upper( t_, c_, U_, W_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_FEW + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_upper, Solve_upperImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error solve_upper_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer W, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bW, + ffi::ResultBuffer bY +) { + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + -auto matmul_lower (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); + + if (dim<0>(W) != N || dim<1>(W) != J) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); + + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); + + + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("solve_upper_rev shape mismatch"); - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_lower(t_, c_, U_, V_, Y_, Z_, F_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> W_(W.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bW_(bW->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bW_.setZero(); \ + bY_.setZero(); \ + try { \ + celerite2::core::solve_upper_rev( t_, c_, U_, W_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bW_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_MOST + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto matmul_lower_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bV = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + solve_upper_rev, solve_upper_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // W + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bW + .Ret>() // bY +); + + + +ffi::Error Matmul_lowerImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_lower shape mismatch"); + #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::matmul_lower_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ + Z_.setZero(); \ + F_.setZero(); \ + try { \ + celerite2::core::matmul_lower( t_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_FEW + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_lower, Matmul_lowerImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error matmul_lower_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bV, + ffi::ResultBuffer bY +) { -auto matmul_upper (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); + + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); + + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); + + + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("matmul_lower_rev shape mismatch"); - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::matmul_upper(t_, c_, U_, V_, Y_, Z_, F_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bV_.setZero(); \ + bY_.setZero(); \ + try { \ + celerite2::core::matmul_lower_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_MOST + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto matmul_upper_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - const int nrhs = *reinterpret_cast(in[2]); - - const double *t = reinterpret_cast(in[3]); - const double *c = reinterpret_cast(in[4]); - const double *U = reinterpret_cast(in[5]); - const double *V = reinterpret_cast(in[6]); - const double *Y = reinterpret_cast(in[7]); - const double *Z = reinterpret_cast(in[8]); - const double *F = reinterpret_cast(in[9]); - const double *bZ = reinterpret_cast(in[10]); - double *bt = reinterpret_cast(out[0]); - double *bc = reinterpret_cast(out[1]); - double *bU = reinterpret_cast(out[2]); - double *bV = reinterpret_cast(out[3]); - double *bY = reinterpret_cast(out[4]); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_lower_rev, matmul_lower_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bV + .Ret>() // bY +); + + + +ffi::Error Matmul_upperImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_upper shape mismatch"); + #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t_(t, N, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, N, J); \ - Eigen::Map bt_(bt, N, 1); \ - Eigen::Map> bc_(bc, J, 1); \ - Eigen::Map::value>> bU_(bU, N, J); \ - Eigen::Map::value>> bV_(bV, N, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, N, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, N, J); \ - Eigen::Map bZ_(bZ, N, 1); \ - Eigen::Map bY_(bY, N, 1); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ - } else { \ - Eigen::Map> Y_(Y, N, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, N, J * nrhs); \ - Eigen::Map> bZ_(bZ, N, nrhs); \ - Eigen::Map> bY_(bY, N, nrhs); \ - celerite2::core::matmul_upper_rev(t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_, bc_, bU_, bV_, bY_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), N, J*nrhs); \ + Z_.setZero(); \ + F_.setZero(); \ + try { \ + celerite2::core::matmul_upper( t_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_FEW + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_upper, Matmul_upperImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + +ffi::Error matmul_upper_revImpl( + ffi::Buffer t, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::Buffer Z, + ffi::Buffer F, + ffi::Buffer bZ, + ffi::ResultBuffer bt, + ffi::ResultBuffer bc, + ffi::ResultBuffer bU, + ffi::ResultBuffer bV, + ffi::ResultBuffer bY +) { + + const auto N = dim<0>(t); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t) != N) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); -auto general_matmul_lower (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int M = *reinterpret_cast(in[1]); - const int J = *reinterpret_cast(in[2]); - const int nrhs = *reinterpret_cast(in[3]); + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); + + if (dim<0>(V) != N || dim<1>(V) != J) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); + + if (dim<0>(Y) != N || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); + + if (dim<0>(Z) != N || dim<1>(Z) != nrhs) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); + + + if (dim<0>(bZ) != N || dim<1>(bZ) != nrhs) return ffi::Error::InvalidArgument("matmul_upper_rev shape mismatch"); - const double *t1 = reinterpret_cast(in[4]); - const double *t2 = reinterpret_cast(in[5]); - const double *c = reinterpret_cast(in[6]); - const double *U = reinterpret_cast(in[7]); - const double *V = reinterpret_cast(in[8]); - const double *Y = reinterpret_cast(in[9]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_(t1, N, 1); \ - Eigen::Map t2_(t2, M, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, M, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, M, J); \ - Z_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, M, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, M, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::general_matmul_lower(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + { \ + Eigen::Map t_(t.typed_data(), N, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), N, J); \ + Eigen::Map> Y_(Y.typed_data(), N, nrhs); \ + Eigen::Map> Z_(Z.typed_data(), N, nrhs); \ + Eigen::Map> F_(F.typed_data(), N, J*nrhs); \ + Eigen::Map> bZ_(bZ.typed_data(), N, nrhs); \ + Eigen::Map bt_(bt->typed_data(), N, 1); \ + Eigen::Map bc_(bc->typed_data(), J, 1); \ + Eigen::Map::value>> bU_(bU->typed_data(), N, J); \ + Eigen::Map::value>> bV_(bV->typed_data(), N, J); \ + Eigen::Map> bY_(bY->typed_data(), N, nrhs); \ + bt_.setZero(); \ + bc_.setZero(); \ + bU_.setZero(); \ + bV_.setZero(); \ + bY_.setZero(); \ + try { \ + celerite2::core::matmul_upper_rev( t_, c_, U_, V_, Y_, Z_, F_, bZ_, bt_,bc_,bU_,bV_,bY_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_MOST + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -auto general_matmul_upper (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int M = *reinterpret_cast(in[1]); - const int J = *reinterpret_cast(in[2]); - const int nrhs = *reinterpret_cast(in[3]); - - const double *t1 = reinterpret_cast(in[4]); - const double *t2 = reinterpret_cast(in[5]); - const double *c = reinterpret_cast(in[6]); - const double *U = reinterpret_cast(in[7]); - const double *V = reinterpret_cast(in[8]); - const double *Y = reinterpret_cast(in[9]); - double *Z = reinterpret_cast(out[0]); - double *F = reinterpret_cast(out[1]); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + matmul_upper_rev, matmul_upper_revImpl, + ffi::Ffi::Bind() + .Arg>() // t + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Arg>() // Z + .Arg>() // F + .Arg>() // bZ + .Ret>() // bt + .Ret>() // bc + .Ret>() // bU + .Ret>() // bV + .Ret>() // bY +); + + + +ffi::Error General_matmul_lowerImpl( + ffi::Buffer t1, + ffi::Buffer t2, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { + + + const auto N = dim<0>(t1); + + const auto M = dim<0>(t2); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t1) != N) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + + if (dim<0>(t2) != M) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + + if (dim<0>(V) != M || dim<1>(V) != J) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + + if (dim<0>(Y) != M || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_lower shape mismatch"); + #define FIXED_SIZE_MAP(SIZE) \ - { \ - Eigen::Map t1_(t1, N, 1); \ - Eigen::Map t2_(t2, M, 1); \ - Eigen::Map> c_(c, J, 1); \ - Eigen::Map::value>> U_(U, N, J); \ - Eigen::Map::value>> V_(V, M, J); \ - if (nrhs == 1) { \ - Eigen::Map Y_(Y, M, 1); \ - Eigen::Map Z_(Z, N, 1); \ - Eigen::Map::value>> F_(F, M, J); \ - Z_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ - } else { \ - Eigen::Map> Y_(Y, M, nrhs); \ - Eigen::Map> Z_(Z, N, nrhs); \ - Eigen::Map> F_(F, M, J * nrhs); \ - Z_.setZero(); \ - celerite2::core::general_matmul_upper(t1_, t2_, c_, U_, V_, Y_, Z_, F_); \ + { \ + Eigen::Map t1_(t1.typed_data(), N, 1); \ + Eigen::Map t2_(t2.typed_data(), M, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), M, J); \ + Eigen::Map> Y_(Y.typed_data(), M, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ + Z_.setZero(); \ + F_.setZero(); \ + try { \ + celerite2::core::general_matmul_lower( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ } \ - } - UNWRAP_CASES_MOST + } + UNWRAP_CASES_MOST #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } +XLA_FFI_DEFINE_HANDLER_SYMBOL( + general_matmul_lower, General_matmul_lowerImpl, + ffi::Ffi::Bind() + .Arg>() // t1 + .Arg>() // t2 + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + -// https://en.cppreference.com/w/cpp/numeric/bit_cast -template -typename std::enable_if::value && std::is_trivially_copyable::value, To>::type -bit_cast(const From &src) noexcept { - static_assert(std::is_trivially_constructible::value, - "This implementation additionally requires destination type to be trivially constructible"); +ffi::Error General_matmul_upperImpl( + ffi::Buffer t1, + ffi::Buffer t2, + ffi::Buffer c, + ffi::Buffer U, + ffi::Buffer V, + ffi::Buffer Y, + ffi::ResultBuffer Z, + ffi::ResultBuffer F +) { - To dst; - memcpy(&dst, &src, sizeof(To)); - return dst; + + const auto N = dim<0>(t1); + + const auto M = dim<0>(t2); + + const auto J = dim<0>(c); + + const auto nrhs = dim<1>(Y); + + + + if (dim<0>(t1) != N) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + + if (dim<0>(t2) != M) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + + if (dim<0>(c) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + + if (dim<0>(U) != N || dim<1>(U) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + + if (dim<0>(V) != M || dim<1>(V) != J) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + + if (dim<0>(Y) != M || dim<1>(Y) != nrhs) return ffi::Error::InvalidArgument("general_matmul_upper shape mismatch"); + + +#define FIXED_SIZE_MAP(SIZE) \ + { \ + Eigen::Map t1_(t1.typed_data(), N, 1); \ + Eigen::Map t2_(t2.typed_data(), M, 1); \ + Eigen::Map c_(c.typed_data(), J, 1); \ + Eigen::Map::value>> U_(U.typed_data(), N, J); \ + Eigen::Map::value>> V_(V.typed_data(), M, J); \ + Eigen::Map> Y_(Y.typed_data(), M, dim<1>(Y)); \ + Eigen::Map> Z_(Z->typed_data(), N, nrhs); \ + Eigen::Map> F_(F->typed_data(), M, J*nrhs); \ + Z_.setZero(); \ + F_.setZero(); \ + try { \ + celerite2::core::general_matmul_upper( t1_, t2_, c_, U_, V_, Y_, Z_,F_); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ + } + UNWRAP_CASES_MOST +#undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } -template -py::capsule encapsulate_function(T* fn) { - return py::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + general_matmul_upper, General_matmul_upperImpl, + ffi::Ffi::Bind() + .Arg>() // t1 + .Arg>() // t2 + .Arg>() // c + .Arg>() // U + .Arg>() // V + .Arg>() // Y + .Ret>() // Z + .Ret>() // F +); + + + +// Pybind -------------------------------------------------------------------- +template +py::capsule Encapsulate() { + return py::capsule(reinterpret_cast(Fn), "xla._CUSTOM_CALL_TARGET"); } PYBIND11_MODULE(xla_ops, m) { - m.def("factor", []() { - return encapsulate_function(factor); - }); - m.def("factor_rev", []() { - return encapsulate_function(factor_rev); - }); - m.def("solve_lower", []() { - return encapsulate_function(solve_lower); - }); - m.def("solve_lower_rev", []() { - return encapsulate_function(solve_lower_rev); - }); - m.def("solve_upper", []() { - return encapsulate_function(solve_upper); - }); - m.def("solve_upper_rev", []() { - return encapsulate_function(solve_upper_rev); - }); - m.def("matmul_lower", []() { - return encapsulate_function(matmul_lower); - }); - m.def("matmul_lower_rev", []() { - return encapsulate_function(matmul_lower_rev); - }); - m.def("matmul_upper", []() { - return encapsulate_function(matmul_upper); - }); - m.def("matmul_upper_rev", []() { - return encapsulate_function(matmul_upper_rev); - }); - m.def("general_matmul_lower", []() { - return encapsulate_function(general_matmul_lower); - }); - m.def("general_matmul_upper", []() { - return encapsulate_function(general_matmul_upper); - }); + + m.def("factor", &Encapsulate); + m.def("factor_rev", &Encapsulate); + + m.def("solve_lower", &Encapsulate); + m.def("solve_lower_rev", &Encapsulate); + + m.def("solve_upper", &Encapsulate); + m.def("solve_upper_rev", &Encapsulate); + + m.def("matmul_lower", &Encapsulate); + m.def("matmul_lower_rev", &Encapsulate); + + m.def("matmul_upper", &Encapsulate); + m.def("matmul_upper_rev", &Encapsulate); + + m.def("general_matmul_lower", &Encapsulate); + + m.def("general_matmul_upper", &Encapsulate); + } diff --git a/python/celerite2/pymc/ops.py b/python/celerite2/pymc/ops.py index 2cceef5..cce6928 100644 --- a/python/celerite2/pymc/ops.py +++ b/python/celerite2/pymc/ops.py @@ -11,14 +11,15 @@ "general_matmul_upper", ] +import importlib.resources as resources import json from itertools import chain import numpy as np -import pkg_resources import pytensor import pytensor.tensor as pt from pytensor.graph import basic, op +from pytensor.link.jax.dispatch import jax_funcify import celerite2.backprop as backprop import celerite2.driver as driver @@ -140,9 +141,7 @@ def grad(self, inputs, gradients): return self.rev_op(*chain(inputs, outputs, grads)) -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with resources.files("celerite2").joinpath("definitions.json").open("r") as f: definitions = {spec["name"]: spec for spec in json.load(f)} @@ -158,3 +157,70 @@ def grad(self, inputs, gradients): general_matmul_upper = _CeleriteOp( "general_matmul_upper_fwd", definitions["general_matmul_upper"] ) + + +# JAX conversion for PyTensor JAX linker ------------------------------------- +@jax_funcify.register(_CeleriteOp) +def _jax_funcify_celerite(op, node, **kwargs): + """Map celerite2 PyTensor ops to their JAX counterparts.""" + + # Lazy import to avoid circular import during module import + import celerite2.jax.ops as jax_ops + + def factor_fwd(t, c, a, U, V): + return jax_ops.factor_p.bind(t, c, a, U, V) + + def factor_rev(t, c, a, U, V, d, W, S, bd, bW): + return jax_ops.factor_rev_p.bind(t, c, a, U, V, d, W, S, bd, bW) + + def solve_lower_fwd(t, c, U, W, Y): + return jax_ops.solve_lower_p.bind(t, c, U, W, Y) + + def solve_lower_rev(t, c, U, W, Y, Z, F, bZ): + return jax_ops.solve_lower_rev_p.bind(t, c, U, W, Y, Z, F, bZ) + + def solve_upper_fwd(t, c, U, W, Y): + return jax_ops.solve_upper_p.bind(t, c, U, W, Y) + + def solve_upper_rev(t, c, U, W, Y, Z, F, bZ): + return jax_ops.solve_upper_rev_p.bind(t, c, U, W, Y, Z, F, bZ) + + def matmul_lower_fwd(t, c, U, V, Y): + return jax_ops.matmul_lower_p.bind(t, c, U, V, Y) + + def matmul_lower_rev(t, c, U, V, Y, Z, F, bZ): + return jax_ops.matmul_lower_rev_p.bind(t, c, U, V, Y, Z, F, bZ) + + def matmul_upper_fwd(t, c, U, V, Y): + return jax_ops.matmul_upper_p.bind(t, c, U, V, Y) + + def matmul_upper_rev(t, c, U, V, Y, Z, F, bZ): + return jax_ops.matmul_upper_rev_p.bind(t, c, U, V, Y, Z, F, bZ) + + def general_matmul_lower_fwd(t1, t2, c, U, V, Y): + return jax_ops.general_matmul_lower_p.bind(t1, t2, c, U, V, Y) + + def general_matmul_upper_fwd(t1, t2, c, U, V, Y): + return jax_ops.general_matmul_upper_p.bind(t1, t2, c, U, V, Y) + + mapping = { + "factor_fwd": factor_fwd, + "factor_rev": factor_rev, + "solve_lower_fwd": solve_lower_fwd, + "solve_lower_rev": solve_lower_rev, + "solve_upper_fwd": solve_upper_fwd, + "solve_upper_rev": solve_upper_rev, + "matmul_lower_fwd": matmul_lower_fwd, + "matmul_lower_rev": matmul_lower_rev, + "matmul_upper_fwd": matmul_upper_fwd, + "matmul_upper_rev": matmul_upper_rev, + "general_matmul_lower_fwd": general_matmul_lower_fwd, + "general_matmul_upper_fwd": general_matmul_upper_fwd, + } + + try: + return mapping[op.name] + except KeyError: + raise NotImplementedError( + f"No JAX conversion registered for {op.name}" + ) diff --git a/python/celerite2/pymc3/ops.py b/python/celerite2/pymc3/ops.py index c945858..6d42ece 100644 --- a/python/celerite2/pymc3/ops.py +++ b/python/celerite2/pymc3/ops.py @@ -11,11 +11,11 @@ "general_matmul_upper", ] +import importlib.resources as resources import json from itertools import chain import numpy as np -import pkg_resources import theano import theano.tensor as tt from theano.graph import basic, op @@ -140,9 +140,7 @@ def grad(self, inputs, gradients): return self.rev_op(*chain(inputs, outputs, grads)) -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with resources.files("celerite2").joinpath("definitions.json").open("r") as f: definitions = {spec["name"]: spec for spec in json.load(f)} diff --git a/python/spec/generate.py b/python/spec/generate.py index 14323ad..341e588 100644 --- a/python/spec/generate.py +++ b/python/spec/generate.py @@ -5,7 +5,6 @@ import os from pathlib import Path -import pkg_resources from jinja2 import Environment, FileSystemLoader, select_autoescape base = Path(os.path.dirname(os.path.abspath(__file__))) @@ -13,11 +12,10 @@ env = Environment( loader=FileSystemLoader(base / "templates"), autoescape=select_autoescape(["cpp"]), + keep_trailing_newline=True, ) -with open( - pkg_resources.resource_filename("celerite2", "definitions.json"), "r" -) as f: +with open(base.parent / "celerite2" / "definitions.json", "r") as f: data = json.load(f) for n in range(len(data)): @@ -45,7 +43,10 @@ for name in ["driver.cpp", "backprop.cpp", "jax/xla_ops.cpp"]: template = env.get_template(name) result = template.render(spec=data) + clean_result = ( + "\n".join(line.rstrip() for line in result.splitlines()) + "\n" + ) with open(base.parent / "celerite2" / name, "w") as f: f.write("// NOTE: This file was autogenerated\n") f.write("// NOTE: Changes should be made to the template\n\n") - f.write(result) + f.write(clean_result) diff --git a/python/spec/templates/jax/xla_ops.cpp b/python/spec/templates/jax/xla_ops.cpp index bd22248..3a9b73a 100644 --- a/python/spec/templates/jax/xla_ops.cpp +++ b/python/spec/templates/jax/xla_ops.cpp @@ -1,169 +1,209 @@ +// Generated JAX FFI bindings for celerite2. +// Regenerate with: python python/spec/generate.py + #include -#include -#include +#include #include -#include + +#include "xla/ffi/api/ffi.h" + #include "../driver.hpp" namespace py = pybind11; +namespace ffi = xla::ffi; using namespace celerite2::driver; +// Helpers +template +inline Eigen::Index dim(const Buffer& buf) { + return static_cast(buf.dimensions()[Axis]); +} +template +inline Eigen::Index flat_cols(const Buffer& buf) { + const auto& dims = buf.dimensions(); + Eigen::Index cols = 1; + for (size_t i = 1; i < dims.size(); ++i) + cols *= static_cast(dims[i]); + return cols; +} + +// === AUTO-GENERATED KERNELS === {% for mod in spec %} -auto {{mod.name}} (void *out_tuple, const void **in) { - {% set input_index = mod.dimensions|length - 1 -%} - void **out = reinterpret_cast(out_tuple); - {%- for dim in mod.dimensions %} - const int {{dim.name}} = *reinterpret_cast(in[{{loop.index - 1}}]); - {%- endfor %} - {% for arg in mod.inputs %} - const double *{{arg.name}} = reinterpret_cast(in[{{ input_index + loop.index }}]); - {%- endfor %} - {%- for arg in mod.outputs + mod.extra_outputs %} - double *{{arg.name}} = reinterpret_cast(out[{{ loop.index - 1 }}]); - {%- endfor %} + +ffi::Error {{mod.name|capitalize}}Impl( +{%- for arg in mod.inputs %} + ffi::Buffer {{arg.name}}{% if not loop.last or mod.outputs or mod.extra_outputs %},{% endif %} +{%- endfor -%} +{%- for arg in mod.outputs %} + ffi::ResultBuffer {{arg.name}}{% if not loop.last or mod.extra_outputs %},{% endif %} +{%- endfor -%} +{%- for arg in mod.extra_outputs %} + ffi::ResultBuffer {{arg.name}}{% if not loop.last %},{% endif %} +{%- endfor %} +) { + {# Dimension aliases for readability #} + {% for dim in mod.dimensions %} + const auto {{dim.name}} = dim<{{dim.coords[1]}}>({{mod.inputs[dim.coords[0]].name}}); + {% endfor %} + {%- set nrhs_arg = None -%} + {%- for a in mod.inputs + mod.outputs + mod.extra_outputs -%} + {%- if a.shape|length >= 2 and a.shape[-1] == "nrhs" and nrhs_arg is none -%} + {%- set nrhs_arg = a -%} + {%- endif -%} + {%- endfor -%} + {%- if nrhs_arg is not none %} + const auto nrhs = dim<{{ nrhs_arg.shape|length - 1 }}>({{ nrhs_arg.name }}); + {%- endif %} + {# Minimal shape checks - rely on driver.hpp order helper #} + {% for arg in mod.inputs %} + {%- if arg.shape|length == 1 %} + if (dim<0>({{arg.name}}) != {{arg.shape[0]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); + {%- elif arg.shape|length == 2 %} + if (dim<0>({{arg.name}}) != {{arg.shape[0]}} || dim<1>({{arg.name}}) != {{arg.shape[1]}}) return ffi::Error::InvalidArgument("{{mod.name}} shape mismatch"); + {%- endif %} + {% endfor %} #define FIXED_SIZE_MAP(SIZE) \ - { \ - {%- for arg in mod.inputs + mod.outputs + mod.extra_outputs %} - {%- if arg.shape|length == 1 -%} - {%- if arg.shape[0] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, J, 1); \ + { \ + {%- for arg in mod.inputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "nrhs" %} + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim<1>({{arg.name}})); \ {%- else %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {%- endif -%} - {%- elif arg.shape|length == 2 -%} - {%- if arg.shape[1] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {%- endif -%} - {%- else -%} - {%- if arg.shape[2] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J * J); \ - {%- endif -%} - {%- endif -%} - {% endfor %} - {%- if mod.name == "factor" %} - Eigen::Index flag = celerite2::core::{{mod.name}}({% for val in mod.inputs + mod.outputs + mod.extra_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - if (flag) d_.setZero(); \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, dim<1>({{arg.name}})); \ + {%- endif %} + {%- endfor %} + {%- for arg in mod.outputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "nrhs" %} + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- else %} - if (nrhs == 1) { \ - {% for arg in mod.inputs + mod.outputs + mod.extra_outputs %} - {%- if arg.shape|length == 2 and arg.shape[1] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {% elif arg.shape|length == 3 and arg.shape[2] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {% endif -%} - {% endfor -%} - {% for arg in mod.outputs -%} - {{arg.name}}_.setZero(); \ - {% endfor -%} - celerite2::core::{{mod.name}}({% for val in mod.inputs + mod.outputs + mod.extra_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } else { \ - {% for arg in mod.inputs + mod.outputs + mod.extra_outputs %} - {%- if (arg.shape|length == 2 and arg.shape[1] == "nrhs") or (arg.shape|length == 3 and arg.shape[2] == "nrhs") -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, {{arg.shape[1:]|join(" * ")}}); \ - {% endif -%} - {% endfor -%} - {% for arg in mod.outputs -%} - {{arg.name}}_.setZero(); \ - {% endfor -%} - celerite2::core::{{mod.name}}({% for val in mod.inputs + mod.outputs + mod.extra_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- endif %} - } - UNWRAP_CASES_MOST -#undef FIXED_SIZE_MAP -} -{%- if mod.has_rev %} -auto {{mod.name}}_rev (void *out_tuple, const void **in) { - void **out = reinterpret_cast(out_tuple); - const int N = *reinterpret_cast(in[0]); - const int J = *reinterpret_cast(in[1]); - {%- if mod.name == "factor" -%} - {%- set input_index = 1 -%} + {%- endfor %} + {%- for arg in mod.extra_outputs %} + {%- if arg.shape|length == 3 and arg.shape[1] == "J" and arg.shape[2] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J * J); \ {%- else %} - const int nrhs = *reinterpret_cast(in[2]); - {%- set input_index = 2 -%} + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{ '*'.join(arg.shape[1:]) }}); \ {%- endif %} - {% for arg in mod.rev_inputs %} - const double *{{arg.name}} = reinterpret_cast(in[{{ input_index + loop.index }}]); {%- endfor %} - {%- for arg in mod.rev_outputs %} - double *{{arg.name}} = reinterpret_cast(out[{{ loop.index - 1 }}]); + {%- for arg in mod.outputs + mod.extra_outputs %} + {{arg.name}}_.setZero(); \ {%- endfor %} + try { \ + celerite2::core::{{mod.name}}({%- for arg in mod.inputs %} {{arg.name}}_{% if not loop.last or mod.outputs or mod.extra_outputs %},{% endif %}{%- endfor %}{%- if mod.outputs or mod.extra_outputs %} {% endif %}{%- for arg in mod.outputs + mod.extra_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ + } + UNWRAP_CASES_{{ "FEW" if mod.has_rev else "MOST" }} +#undef FIXED_SIZE_MAP + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + {{mod.name}}, {{mod.name|capitalize}}Impl, + ffi::Ffi::Bind() +{%- for arg in mod.inputs %} + .Arg>() // {{arg.name}} +{%- endfor %} +{%- for arg in mod.outputs + mod.extra_outputs %} + .Ret>() // {{arg.name}} +{%- endfor %} +); +{% if mod.has_rev %} + +ffi::Error {{mod.name}}_revImpl( +{%- for arg in mod.rev_inputs %} + ffi::Buffer {{arg.name}}{% if not loop.last or mod.rev_outputs %},{% endif %} +{%- endfor -%} +{%- for arg in mod.rev_outputs %} + ffi::ResultBuffer {{arg.name}}{% if not loop.last %},{% endif %} +{%- endfor %} +) { + {% for dim in mod.dimensions %} + const auto {{dim.name}} = dim<{{dim.coords[1]}}>({{mod.inputs[dim.coords[0]].name}}); + {% endfor %} + {# Minimal shape checks #} + {% for arg in mod.rev_inputs %} + {%- if arg.shape|length == 1 %} + if (dim<0>({{arg.name}}) != {{arg.shape[0]}}) return ffi::Error::InvalidArgument("{{mod.name}}_rev shape mismatch"); + {%- elif arg.shape|length == 2 %} + if (dim<0>({{arg.name}}) != {{arg.shape[0]}} || dim<1>({{arg.name}}) != {{arg.shape[1]}}) return ffi::Error::InvalidArgument("{{mod.name}}_rev shape mismatch"); + {%- endif %} + {% endfor %} #define FIXED_SIZE_MAP(SIZE) \ - { \ - {%- for arg in mod.rev_inputs + mod.rev_outputs %} - {%- if arg.shape|length == 1 -%} - {%- if arg.shape[0] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, J, 1); \ + { \ + {%- for arg in mod.rev_inputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J); \ + {%- elif arg.shape|length == 3 and arg.shape[1] == "J" and arg.shape[2] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, J * J); \ + {%- elif arg.shape|length == 3 %} + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, {{ '*'.join(arg.shape[1:]) }}); \ {%- else %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {%- endif -%} - {%- elif arg.shape|length == 2 -%} - {%- if arg.shape[1] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {%- endif -%} - {%- else -%} - {%- if arg.shape[2] == "J" %} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J * J); \ - {%- endif -%} - {%- endif -%} - {% endfor %} - {%- if mod.name == "factor" %} - celerite2::core::{{mod.name}}_rev({% for val in mod.rev_inputs + mod.rev_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ + Eigen::Map> {{arg.name}}_({{arg.name}}.typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ + {%- endif %} + {%- endfor %} + {%- for arg in mod.rev_outputs %} + {%- if arg.shape|length == 1 %} + Eigen::Map {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, 1); \ + {%- elif arg.shape|length == 2 and arg.shape[1] == "J" %} + Eigen::Map::value>> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, J); \ {%- else %} - if (nrhs == 1) { \ - {% for arg in mod.rev_inputs + mod.rev_outputs %} - {%- if arg.shape|length == 2 and arg.shape[1] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::VectorXd> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, 1); \ - {% elif arg.shape|length == 3 and arg.shape[2] == "nrhs" -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix::value>> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, J); \ - {% endif -%} - {% endfor -%} - celerite2::core::{{mod.name}}_rev({% for val in mod.rev_inputs + mod.rev_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } else { \ - {% for arg in mod.rev_inputs + mod.rev_outputs %} - {%- if (arg.shape|length == 2 and arg.shape[1] == "nrhs") or (arg.shape|length == 3 and arg.shape[2] == "nrhs") -%} - Eigen::Map<{% if not arg.is_output %}const {% endif %}Eigen::Matrix> {{arg.name}}_({{arg.name}}, {{arg.shape[0]}}, {{arg.shape[1:]|join(" * ")}}); \ - {% endif -%} - {% endfor -%} - celerite2::core::{{mod.name}}_rev({% for val in mod.rev_inputs + mod.rev_outputs %}{{val.name}}_{%- if not loop.last %}, {% endif %}{% endfor %}); \ - } \ + Eigen::Map> {{arg.name}}_({{arg.name}}->typed_data(), {{arg.shape[0]}}, {{arg.shape[1]}}); \ {%- endif %} - } - UNWRAP_CASES_FEW + {%- endfor %} + {%- for arg in mod.rev_outputs %} + {{arg.name}}_.setZero(); \ + {%- endfor %} + try { \ + celerite2::core::{{mod.name}}_rev({%- for arg in mod.rev_inputs %} {{arg.name}}_{% if not loop.last or mod.rev_outputs %},{% endif %}{%- endfor %}{%- if mod.rev_outputs %} {% endif %}{%- for arg in mod.rev_outputs %}{{arg.name}}_{% if not loop.last %},{% endif %}{%- endfor %}); \ + } catch (const std::exception& e) { \ + return ffi::Error::Internal(e.what()); \ + } \ + } + UNWRAP_CASES_FEW #undef FIXED_SIZE_MAP + + return ffi::Error::Success(); } + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + {{mod.name}}_rev, {{mod.name}}_revImpl, + ffi::Ffi::Bind() +{%- for arg in mod.rev_inputs %} + .Arg>() // {{arg.name}} +{%- endfor %} +{%- for arg in mod.rev_outputs %} + .Ret>() // {{arg.name}} +{%- endfor %} +); {% endif %} {% endfor %} -// https://en.cppreference.com/w/cpp/numeric/bit_cast -template -typename std::enable_if::value && std::is_trivially_copyable::value, To>::type -bit_cast(const From &src) noexcept { - static_assert(std::is_trivially_constructible::value, - "This implementation additionally requires destination type to be trivially constructible"); - - To dst; - memcpy(&dst, &src, sizeof(To)); - return dst; -} - -template -py::capsule encapsulate_function(T* fn) { - return py::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +// Pybind -------------------------------------------------------------------- +template +py::capsule Encapsulate() { + return py::capsule(reinterpret_cast(Fn), "xla._CUSTOM_CALL_TARGET"); } PYBIND11_MODULE(xla_ops, m) { - {%- for mod in spec %} - m.def("{{mod.name}}", []() { - return encapsulate_function({{mod.name}}); - }); - {%- if mod.has_rev %} - m.def("{{mod.name}}_rev", []() { - return encapsulate_function({{mod.name}}_rev); - }); - {%- endif %} - {%- endfor %} + {% for mod in spec %} + m.def("{{mod.name}}", &Encapsulate<{{mod.name}}>); + {%- if mod.has_rev %} + m.def("{{mod.name}}_rev", &Encapsulate<{{mod.name}}_rev>); + {%- endif %} + {% endfor %} }