Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]
os: ["ubuntu-latest"]
session:
- "core"
- "jax"
- "pymc3"
# - "pymc"
- "pymc"
- "pymc_jax"

steps:
Expand All @@ -47,7 +46,7 @@ jobs:

- name: Run tests
run: |
python -m nox --non-interactive --error-on-missing-interpreter \
python -m nox --non-interactive \
--session ${{ matrix.session }}-${{ matrix.python-version }}

tests-pymc:
Expand All @@ -68,12 +67,12 @@ jobs:
environment-name: test-env
create-args: >-
mamba
python=3.10
python=3.11

- name: Install nox
run: python -m pip install -U nox

- name: Run tests
run: |
python -m nox --non-interactive --error-on-missing-interpreter \
--session pymc_mamba-3.10
python -m nox --non-interactive \
--session pymc_mamba-3.11
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yml.off
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.11

- name: Install dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
- uses: actions/setup-python@v5
name: Install Python
with:
python-version: "3.10"
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install -U pip
Expand Down
21 changes: 18 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ project(

set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)
find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module)

include_directories(
"c++/include"
Expand All @@ -20,6 +21,20 @@ pybind11_add_module(backprop "python/celerite2/backprop.cpp")
target_compile_features(backprop PUBLIC cxx_std_14)
install(TARGETS backprop LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp")
target_compile_features(xla_ops PUBLIC cxx_std_14)
install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax")
option(BUILD_JAX "Build JAX extension (requires jaxlib headers)" ON)
if(BUILD_JAX)
execute_process(
COMMAND "${Python_EXECUTABLE}" "-c" "from jax import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE XLA_DIR
RESULT_VARIABLE JAXLIB_RES)
if(JAXLIB_RES EQUAL 0 AND NOT "${XLA_DIR}" STREQUAL "")
message(STATUS "Building JAX extension with XLA include: ${XLA_DIR}")
pybind11_add_module(xla_ops "python/celerite2/jax/xla_ops.cpp")
target_compile_features(xla_ops PUBLIC cxx_std_17)
target_include_directories(xla_ops PUBLIC "${XLA_DIR}")
install(TARGETS xla_ops LIBRARY DESTINATION "${SKBUILD_PROJECT_NAME}/jax")
else()
message(STATUS "Skipping JAX extension (jax.ffi include_dir not found)")
endif()
endif()
4 changes: 2 additions & 2 deletions c++/include/celerite2/forward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ void general_matmul_lower(const Eigen::MatrixBase<Input> &t1, // (

typedef typename LowRank::Scalar Scalar;
typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector;
typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner;
typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my experience, this change will dramatically impact performance because Eigen won't be able to generate properly vectorized code for small systems. It's really useful to compile for specific sizes! Why did you make this change?

Copy link
Collaborator Author

@MilesCranmer MilesCranmer Nov 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I see. I couldn't get it working initially but this change seemed to do it. I didn't know it would hurt performance though so I'll fix it now.


Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols();

Expand Down Expand Up @@ -358,7 +358,7 @@ void general_matmul_upper(const Eigen::MatrixBase<Input> &t1, // (

typedef typename LowRank::Scalar Scalar;
typedef typename Eigen::internal::plain_col_type<Coeffs>::type CoeffVector;
typedef typename Eigen::Matrix<Scalar, LowRank::ColsAtCompileTime, RightHandSide::ColsAtCompileTime> Inner;
typedef typename Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> Inner;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.


Eigen::Index N = t1.rows(), M = t2.rows(), J = c.rows(), nrhs = Y.cols();

Expand Down
25 changes: 3 additions & 22 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import nox

ALL_PYTHON_VS = ["3.8", "3.9", "3.10"]
ALL_PYTHON_VS = ["3.11"]
TEST_CMD = ["python", "-m", "pytest", "-v"]


Expand All @@ -23,22 +23,16 @@ def jax(session):
_session_run(session, "python/test/jax")


@nox.session(python=ALL_PYTHON_VS)
def pymc3(session):
session.install(".[test,pymc3]")
_session_run(session, "python/test/pymc3")


@nox.session(python=ALL_PYTHON_VS)
def pymc(session):
session.install(".[test,pymc]")
session.install(".[test,pymc,jax]")
_session_run(session, "python/test/pymc")


@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba")
def pymc_mamba(session):
session.conda_install("pymc", channel="conda-forge")
session.install(".[test,pymc]")
session.install(".[test,pymc,jax]")
_session_run(session, "python/test/pymc")


Expand All @@ -48,19 +42,6 @@ def pymc_jax(session):
_session_run(session, "python/test/pymc/test_pymc_ops.py")


@nox.session(python=ALL_PYTHON_VS)
def full(session):
session.install(".[test,jax,pymc3,pymc]")
_session_run(session, "python/test")


@nox.session(python=ALL_PYTHON_VS, venv_backend="mamba")
def full_mamba(session):
session.conda_install("jax", "pymc3", "pymc", channel="conda-forge")
session.install(".[test]")
_session_run(session, "python/test")


@nox.session
def lint(session):
session.install("pre-commit")
Expand Down
14 changes: 6 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "celerite2"
description = "Fast and scalable Gaussian Processes in 1D"
authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }]
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.11"
license = { text = "MIT License" }
classifiers = [
"Development Status :: 4 - Beta",
Expand All @@ -18,10 +18,8 @@ dependencies = ["numpy"]

[project.optional-dependencies]
test = ["pytest", "scipy", "celerite"]
pymc3 = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"]
theano = ["pymc3>=3.9", "numpy<1.22", "xarray<2023.10.0"]
pymc = ["pymc>=5.9.2"]
jax = ["jax"]
pymc = ["pymc>=5.26.1"]
jax = ["jax>=0.8.0"]
docs = [
"sphinx",
"sphinx-material",
Expand All @@ -31,19 +29,19 @@ docs = [
"matplotlib",
"scipy",
"emcee",
"pymc>=5",
"pymc>=5.26.1",
"tqdm",
"numpyro",
]
tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5", "tqdm", "numpyro"]
tutorials = ["matplotlib", "scipy", "emcee", "pymc>=5.26.1", "tqdm", "numpyro"]

[project.urls]
"Homepage" = "https://celerite2.readthedocs.io"
"Source" = "https://github.com/exoplanet-dev/celerite2"
"Bug Tracker" = "https://github.com/exoplanet-dev/celerite2/issues"

[build-system]
requires = ["scikit-build-core", "numpy", "pybind11"]
requires = ["scikit-build-core", "numpy", "pybind11", "jaxlib>=0.8.0"]
build-backend = "scikit_build_core.build"

[tool.scikit-build]
Expand Down
13 changes: 12 additions & 1 deletion python/celerite2/jax/celerite2.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

__all__ = ["GaussianProcess", "ConditionalDistribution"]
from jax import lax
from jax import numpy as np

from celerite2.core import BaseConditionalDistribution, BaseGaussianProcess
Expand Down Expand Up @@ -35,7 +36,17 @@ def _do_compute(self, quiet):
self._t, self._c, self._a, self._U, self._V
)
self._log_det = np.sum(np.log(self._d))
self._norm = -0.5 * (self._log_det + self._size * np.log(2 * np.pi))

def _bad(_):
return -np.inf, np.inf

def _good(_):
return self._log_det, -0.5 * (
self._log_det + self._size * np.log(2 * np.pi)
)

bad = np.any(self._d <= 0) | (~np.isfinite(self._log_det))
self._log_det, self._norm = lax.cond(bad, _bad, _good, operand=None)

def _check_sorted(self, t):
return t
Expand Down
46 changes: 17 additions & 29 deletions python/celerite2/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,21 @@
"general_matmul_lower",
"general_matmul_upper",
]
import importlib
import importlib.resources as resources
import json
from collections import OrderedDict
from functools import partial
from itertools import chain

import numpy as np
import pkg_resources
from jax import core, lax
from jax import core, ffi, lax
from jax import numpy as jnp
from jax.core import ShapedArray
from jax.extend.core import Primitive
from jax.interpreters import ad, mlir, xla
from jax.lib import xla_client

from celerite2.jax import xla_ops

xops = xla_client.ops
xla_ops = importlib.import_module("celerite2.jax.xla_ops")


def factor(t, c, a, U, V):
Expand Down Expand Up @@ -103,20 +102,7 @@ def _abstract_eval(spec, *args):
def _lowering_rule(name, spec, ctx: mlir.LoweringRuleContext, *args):
if any(a.dtype != np.float64 for a in chain(ctx.avals_in, ctx.avals_out)):
raise ValueError(f"{spec['name']} requires float64 precision")
shapes = [a.shape for a in ctx.avals_in]
dims = OrderedDict(
(s["name"], shapes[s["coords"][0]][s["coords"][1]])
for s in spec["dimensions"]
)
return mlir.custom_call(
name,
operands=tuple(mlir.ir_constant(np.int32(v)) for v in dims.values())
+ args,
operand_layouts=[()] * len(dims)
+ _default_layouts(aval.shape for aval in ctx.avals_in),
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
result_layouts=_default_layouts(aval.shape for aval in ctx.avals_out),
).results
return ffi.ffi_lowering(name)(ctx, *args)


def _default_layouts(shapes):
Expand Down Expand Up @@ -184,11 +170,14 @@ def _rev_lowering_rule(name, spec, ctx, *args):


def _build_op(name, spec):
xla_client.register_custom_call_target(
name, getattr(xla_ops, spec["name"])(), platform="cpu"
ffi.register_ffi_target(
name,
getattr(xla_ops, spec["name"])(),
platform="cpu",
api_version=1,
)

prim = core.Primitive(f"celerite2_{spec['name']}")
prim = Primitive(f"celerite2_{spec['name']}")
prim.multiple_results = True
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(partial(_abstract_eval, spec))
Expand All @@ -199,15 +188,16 @@ def _build_op(name, spec):
if not spec["has_rev"]:
return prim, None

xla_client.register_custom_call_target(
ffi.register_ffi_target(
name + "_rev",
getattr(xla_ops, f"{spec['name']}_rev")(),
platform="cpu",
api_version=1,
)

jvp_prim = core.Primitive(f"celerite2_{spec['name']}_jvp")
jvp_prim = Primitive(f"celerite2_{spec['name']}_jvp")
jvp_prim.multiple_results = True
rev_prim = core.Primitive(f"celerite2_{spec['name']}_rev")
rev_prim = Primitive(f"celerite2_{spec['name']}_rev")
rev_prim.multiple_results = True

# Setup a dummy JVP rule
Expand All @@ -227,9 +217,7 @@ def _build_op(name, spec):
return prim, rev_prim


with open(
pkg_resources.resource_filename("celerite2", "definitions.json"), "r"
) as f:
with resources.files("celerite2").joinpath("definitions.json").open("r") as f:
definitions = {spec["name"]: spec for spec in json.load(f)}


Expand Down
Loading
Loading