Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
d6bf4e8
fix: jax 0.8 compatibility
MilesCranmer Nov 23, 2025
314ba3e
fix: get pymc working on v5
MilesCranmer Nov 23, 2025
f13d3a4
deps: lower to python 3.10
MilesCranmer Nov 23, 2025
5790e09
fix: JIT compat of `_do_compute`
MilesCranmer Nov 23, 2025
b14435b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2025
8b36a42
deps: require python 3.11 for jaxlib compat
MilesCranmer Nov 23, 2025
461fb35
ci: skip bad nox arg
MilesCranmer Nov 23, 2025
9957276
ci: fix nox target version
MilesCranmer Nov 23, 2025
7d5d97d
deps: fix missing jaxlib install
MilesCranmer Nov 23, 2025
272344d
ci: fix missing jax install for pymc tests
MilesCranmer Nov 23, 2025
341c2f5
fix: generator will no longer add lines of spaces
MilesCranmer Nov 29, 2025
ffaff43
fix: re-generated xla_ops without api.h
MilesCranmer Nov 29, 2025
f6fc718
refactor: clean up error handling
MilesCranmer Nov 29, 2025
2644c1f
fix: `Primitive` import
MilesCranmer Nov 29, 2025
4fa899d
build: update cmake per review
MilesCranmer Nov 29, 2025
f527c78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2025
7679b85
refactor: put axis into template
MilesCranmer Nov 29, 2025
faa9b95
fix: add back nrhs==1 branch
MilesCranmer Nov 29, 2025
63dd782
style: remove extra padding
MilesCranmer Nov 29, 2025
09d3ac4
deps: remove upper bound
MilesCranmer Nov 29, 2025
7d71090
chore: re-run generator
MilesCranmer Nov 30, 2025
7985936
Fix JAX 0.8 primitive impl path and PyMC JAX registration
MilesCranmerBot Feb 11, 2026
63c1f05
fix(jax): avoid removed jax.lib.xla_client import
MilesCranmerBot Feb 11, 2026
60fe47a
build: require jax in PEP517 build env for JAX extension
MilesCranmerBot Feb 11, 2026
8922f2a
refactor(jax): drop pre-0.8 apply_primitive fallback
MilesCranmerBot Feb 12, 2026
3f99d45
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 12, 2026
5fa9d91
build: avoid requiring jax on Python<3.11 for docs
MilesCranmerBot Feb 12, 2026
2978004
Upgrade RTDs Python version from 3.10 to 3.11
dfm Feb 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]
os: ["ubuntu-latest"]
session:
- "core"
- "jax"
- "pymc3"
# - "pymc"
- "pymc"
- "pymc_jax"

steps:
Expand All @@ -47,7 +46,7 @@ jobs:

- name: Run tests
run: |
python -m nox --non-interactive --error-on-missing-interpreter \
python -m nox --non-interactive \
--session ${{ matrix.session }}-${{ matrix.python-version }}

tests-pymc:
Expand All @@ -68,12 +67,12 @@ jobs:
environment-name: test-env
create-args: >-
mamba
python=3.10
python=3.11

- name: Install nox
run: python -m pip install -U nox

- name: Run tests
run: |
python -m nox --non-interactive --error-on-missing-interpreter \
--session pymc_mamba-3.10
python -m nox --non-interactive \
--session pymc_mamba-3.11
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yml.off
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.11

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
- uses: actions/setup-python@v5
name: Install Python
with:
python-version: "3.10"
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install -U pip
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build:
apt_packages:
- fonts-liberation
tools:
python: "3.10"
python: "3.11"

python:
install:
Expand Down
21 changes: 18 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ project(

set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)
find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module)

include_directories(
"c++/include"
Expand All @@ -20,6 +21,20 @@ pybind11_add_module(backprop "python/celerite2/backprop.cpp")
target_compile_features(backprop PUBLIC cxx_std_14)
install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp")
target_compile_features(xla_ops PUBLIC cxx_std_14)
install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax")
option(BUILD_JAX "Build JAX extension (requires jaxlib headers)" ON)
if(BUILD_JAX)
execute_process(
COMMAND "${Python_EXECUTABLE}" "-c" "from jax import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE XLA_DIR
RESULT_VARIABLE JAXLIB_RES)
if(JAXLIB_RES EQUAL 0 AND NOT "${XLA_DIR}" STREQUAL "")
message(STATUS "Building JAX extension with XLA include: ${XLA_DIR}")
pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp")
target_compile_features(xla_ops PUBLIC cxx_std_17)
target_include_directories(xla_ops PUBLIC "${XLA_DIR}")
install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax")
else()
message(STATUS "Skipping JAX extension (jax.ffi include_dir not found)")
endif()
endif()
4 changes: 2 additions & 2 deletions c++/include/celerite2/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ void general_matmul_lower(const Eigen::MatrixBase<Input> &t1, // (

typedef typename LowRank::Scalar Scalar;
typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector;
typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner;
typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner;

Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols();

Expand Down Expand Up @@ -358,7 +358,7 @@ void general_matmul_upper(const Eigen::MatrixBase<Input> &t1, // (

typedef typename LowRank::Scalar Scalar;
typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector;
typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner;
typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner;

Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols();

Expand Down
25 changes: 3 additions & 22 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import nox

ALL_PYTHON_VS = ["3.8", "3.9", "3.10"]
ALL_PYTHON_VS = ["3.11"]
TEST_CMD = ["python", "-m", "pytest", "-v"]


Expand All @@ -23,22 +23,16 @@ def jax(session):
_session_run(session, "python/test/jax")


@nox.session(python=ALL_PYTHON_VS)
def pymc3(session):
session.install(".[test,pymc3]")
_session_run(session, "python/test/pymc3")


@nox.session(python=ALL_PYTHON_VS)
def pymc(session):
session.install(".[test,pymc]")
session.install(".[test,pymc,jax]")
_session_run(session, "python/test/pymc")


@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba")
def pymc_mamba(session):
session.conda_install("pymc", channel="conda-forge")
session.install(".[test,pymc]")
session.install(".[test,pymc,jax]")
_session_run(session, "python/test/pymc")


Expand All @@ -48,19 +42,6 @@ def pymc_jax(session):
_session_run(session, "python/test/pymc/test_pymc_ops.py")


@nox.session(python=ALL_PYTHON_VS)
def full(session):
session.install(".[test,jax,pymc3,pymc]")
_session_run(session, "python/test")


@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba")
def full_mamba(session):
session.conda_install("jax", "pymc3", "pymc", channel="conda-forge")
session.install(".[test]")
_session_run(session, "python/test")


@nox.session
def lint(session):
session.install("pre-commit")
Expand Down
14 changes: 6 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "celerite2"
description = "Fast and scalable Gaussian Processes in 1D"
authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }]
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.11"
license = { text = "MIT License" }
classifiers = [
"Development Status :: 4 - Beta",
Expand All @@ -18,10 +18,8 @@ dependencies = ["numpy"]

[project.optional-dependencies]
test = ["pytest", "scipy", "celerite"]
pymc3 = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"]
theano = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"]
pymc = ["pymc>=5.9.2"]
jax = ["jax"]
pymc = ["pymc>=5.26.1"]
jax = ["jax>=0.8.0"]
docs = [
"sphinx",
"sphinx-material",
Expand All @@ -31,19 +29,19 @@ docs = [
"matplotlib",
"scipy",
"emcee",
"pymc>=5",
"pymc>=5.26.1",
"tqdm",
"numpyro",
]
tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"]
tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"]

[project.urls]
"Homepage" = "https://celerite2.readthedocs.io"
"Source" = "https://github.com/exoplanet-dev/celerite2"
"Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues"

[build-system]
requires = ["scikit-build-core", "numpy", "pybind11"]
requires = ["scikit-build-core", "numpy", "pybind11", "jax>=0.8.0; python_version >= '3.11'", "jaxlib>=0.8.0"]
build-backend = "scikit_build_core.build"

[tool.scikit-build]
Expand Down
13 changes: 12 additions & 1 deletion python/celerite2/jax/celerite2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

__all__ = ["GaussianProcess", "ConditionalDistribution"]
from jax import lax
from jax import numpy as np

from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess
Expand Down Expand Up @@ -35,7 +36,17 @@ def _do_compute(self, quiet):
self._t, self._c, self._a, self._U, self._V
)
self._log_det = np.sum(np.log(self._d))
self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi))

def _bad(_):
return -np.inf, np.inf

def _good(_):
return self._log_det, -0.5 * (
self._log_det + self._size * np.log(2 * np.pi)
)

bad = np.any(self._d <= 0) | (~np.isfinite(self._log_det))
self._log_det, self._norm = lax.cond(bad, _bad, _good, operand=None)

def _check_sorted(self, t):
return t
Expand Down
55 changes: 23 additions & 32 deletions python/celerite2/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,24 @@
"general_matmul_lower",
"general_matmul_upper",
]
import importlib
import importlib.resources as resources
import json
from collections import OrderedDict
from functools import partial
from itertools import chain

import numpy as np
import pkg_resources
from jax import core, lax
from jax import ffi, lax
from jax import numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import ad, mlir, xla
from jax.lib import xla_client
from jax.extend.core import Primitive
from jax.interpreters import ad, mlir

from celerite2.jax import xla_ops
xla_ops = importlib.import_module("celerite2.jax.xla_ops")

xops = xla_client.ops
# celerite2 requires jax>=0.8.0 (see pyproject.toml), where apply_primitive lives in
# jax._src.dispatch.
from jax._src.dispatch import apply_primitive as _apply_primitive


def factor(t, c, a, U, V):
Expand Down Expand Up @@ -103,20 +105,7 @@ def _abstract_eval(spec, *args):
def _lowering_rule(name, spec, ctx: mlir.LoweringRuleContext, *args):
if any(a.dtype != np.float64 for a in chain(ctx.avals_in, ctx.avals_out)):
raise ValueError(f"{spec['name']} requires float64 precision")
shapes = [a.shape for a in ctx.avals_in]
dims = OrderedDict(
(s["name"], shapes[s["coords"][0]][s["coords"][1]])
for s in spec["dimensions"]
)
return mlir.custom_call(
name,
operands=tuple(mlir.ir_constant(np.int32(v)) for v in dims.values())
+ args,
operand_layouts=[()] * len(dims)
+ _default_layouts(aval.shape for aval in ctx.avals_in),
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
result_layouts=_default_layouts(aval.shape for aval in ctx.avals_out),
).results
return ffi.ffi_lowering(name)(ctx, *args)


def _default_layouts(shapes):
Expand Down Expand Up @@ -184,13 +173,16 @@ def _rev_lowering_rule(name, spec, ctx, *args):


def _build_op(name, spec):
xla_client.register_custom_call_target(
name, getattr(xla_ops, spec["name"])(), platform="cpu"
ffi.register_ffi_target(
name,
getattr(xla_ops, spec["name"])(),
platform="cpu",
api_version=1,
)

prim = core.Primitive(f"celerite2_{spec['name']}")
prim = Primitive(f"celerite2_{spec['name']}")
prim.multiple_results = True
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_impl(partial(_apply_primitive, prim))
prim.def_abstract_eval(partial(_abstract_eval, spec))
mlir.register_lowering(
prim, partial(_lowering_rule, name, spec), platform="cpu"
Expand All @@ -199,15 +191,16 @@ def _build_op(name, spec):
if not spec["has_rev"]:
return prim, None

xla_client.register_custom_call_target(
ffi.register_ffi_target(
name + "_rev",
getattr(xla_ops, f"{spec['name']}_rev")(),
platform="cpu",
api_version=1,
)

jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp")
jvp_prim = Primitive(f"celerite2_{spec['name']}_jvp")
jvp_prim.multiple_results = True
rev_prim = core.Primitive(f"celerite2_{spec['name']}_rev")
rev_prim = Primitive(f"celerite2_{spec['name']}_rev")
rev_prim.multiple_results = True

# Setup a dummy JVP rule
Expand All @@ -216,7 +209,7 @@ def _build_op(name, spec):
ad.primitive_transposes[jvp_prim] = partial(_jvp_transpose, rev_prim, spec)

# Handle reverse pass using custom op
rev_prim.def_impl(partial(xla.apply_primitive, rev_prim))
rev_prim.def_impl(partial(_apply_primitive, rev_prim))
rev_prim.def_abstract_eval(partial(_rev_abstract_eval, spec))
mlir.register_lowering(
rev_prim,
Expand All @@ -227,9 +220,7 @@ def _build_op(name, spec):
return prim, rev_prim


with open(
pkg_resources.resource_filename("celerite2", "definitions.json"), "r"
) as f:
with resources.files("celerite2").joinpath("definitions.json").open("r") as f:
definitions = {spec["name"]: spec for spec in json.load(f)}


Expand Down
Loading