From d588dac853603a4c7db78ff02850b531c9476d1e Mon Sep 17 00:00:00 2001 From: Hannes Holey Date: Tue, 27 Jan 2026 11:11:17 +0100 Subject: [PATCH 1/9] Replace deprecated jax.experimental.enable_x64 contextmanager --- src/tinygp/test_utils.py | 26 ++++++++++++++++++++++++++ tests/test_gp.py | 4 ++-- tests/test_kernels/test_kernels.py | 4 ++-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/tinygp/test_utils.py b/src/tinygp/test_utils.py index dd469d84..f40723ae 100644 --- a/src/tinygp/test_utils.py +++ b/src/tinygp/test_utils.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Any import jax @@ -30,3 +31,28 @@ def assert_pytrees_allclose(calculated: Any, expected: Any, *args: Any, **kwargs jax.tree_util.tree_map( lambda a, b: assert_allclose(a, b, *args, **kwargs), calculated, expected ) + + +def _as_context_manager(obj): + # If it's already a context manager + if hasattr(obj, "__enter__") and hasattr(obj, "__exit__"): + return obj + + # If it's a generator, wrap it + if hasattr(obj, "__iter__") and hasattr(obj, "send"): + return contextmanager(lambda: obj)() + + raise TypeError("Object is neither a context manager nor a generator") + + +@contextmanager +def jax_enable_x64(): + if hasattr(jax, "enable_x64"): + cm = jax.enable_x64(True) + else: + # deprecated in jax>=0.9 + from jax.experimental import enable_x64 as _enable_x64 + cm = _enable_x64() + + with _as_context_manager(cm): + yield diff --git a/tests/test_gp.py b/tests/test_gp.py index 0c478e0c..be77c11b 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -6,7 +6,7 @@ from numpy import random as np_random from tinygp import GaussianProcess, kernels -from tinygp.test_utils import assert_allclose +from tinygp.test_utils import assert_allclose, jax_enable_x64 @pytest.fixture @@ -24,7 +24,7 @@ def data(random): def test_sample(data): X, _ = data - with jax.experimental.enable_x64(True): + with jax_enable_x64(): gp = GaussianProcess( kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x) ) diff --git a/tests/test_kernels/test_kernels.py b/tests/test_kernels/test_kernels.py index 6e05883d..f41347fc 100644 --- a/tests/test_kernels/test_kernels.py +++ b/tests/test_kernels/test_kernels.py @@ -5,7 +5,7 @@ from tinygp import kernels, noise from tinygp.solvers import DirectSolver -from tinygp.test_utils import assert_allclose +from tinygp.test_utils import assert_allclose, jax_enable_x64 @pytest.fixture @@ -71,7 +71,7 @@ def test_ops(data): def test_conditioned(data): x1, x2 = data - with jax.experimental.enable_x64(): # type: ignore + with jax_enable_x64(): # type: ignore k1 = 1.5 * kernels.Matern32(2.5) k2 = 0.9 * kernels.ExpSineSquared(scale=1.5, gamma=0.3) K = k1(x1, x1) + 0.1 * jnp.eye(x1.shape[0]) From 9cb6245a34ecf9bc4413762d329a4fd1332bc3e0 Mon Sep 17 00:00:00 2001 From: Hannes Holey Date: Tue, 27 Jan 2026 11:19:11 +0100 Subject: [PATCH 2/9] Avoid None attributes in JIT-wrapped functions --- src/tinygp/means.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tinygp/means.py b/src/tinygp/means.py index 9811d634..0db37b17 100644 --- a/src/tinygp/means.py +++ b/src/tinygp/means.py @@ -39,18 +39,18 @@ class Mean(MeanBase): signature. """ - value: JAXArray | None = None + value: JAXArray func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True) def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]): if callable(value): self.func = value + self.value = jax.numpy.zeros(()) else: self.value = value def __call__(self, X: JAXArray) -> JAXArray: - if self.value is None: - assert self.func is not None + if self.func is not None: return self.func(X) return self.value From e4a7f69b260484c057aa6d8d4b6b493c411a550a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 10:50:52 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/tinygp/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tinygp/test_utils.py b/src/tinygp/test_utils.py index f40723ae..e55d211b 100644 --- a/src/tinygp/test_utils.py +++ b/src/tinygp/test_utils.py @@ -52,6 +52,7 @@ def jax_enable_x64(): else: # deprecated in jax>=0.9 from jax.experimental import enable_x64 as _enable_x64 + cm = _enable_x64() with _as_context_manager(cm): From 5dff3b0f0c67911d5a7a43b64a474bd1546e6984 Mon Sep 17 00:00:00 2001 From: Hannes Holey Date: Mon, 2 Feb 2026 09:37:17 +0100 Subject: [PATCH 4/9] Inline comment --- src/tinygp/means.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tinygp/means.py b/src/tinygp/means.py index 0db37b17..5f01c21c 100644 --- a/src/tinygp/means.py +++ b/src/tinygp/means.py @@ -45,7 +45,7 @@ class Mean(MeanBase): def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]): if callable(value): self.func = value - self.value = jax.numpy.zeros(()) + self.value = jax.numpy.zeros(()) # avoids undefined traced values else: self.value = value From f1e5ac034e95ba13cdbedc7bbd5de44d0906852c Mon Sep 17 00:00:00 2001 From: Hannes Holey Date: Mon, 2 Feb 2026 09:37:39 +0100 Subject: [PATCH 5/9] Drop Python 3.10 --- pyproject.toml | 6 +++--- src/tinygp/test_utils.py | 27 --------------------------- tests/test_gp.py | 4 ++-- tests/test_kernels/test_kernels.py | 4 ++-- 4 files changed, 7 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f612ca1e..6846ad9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "tinygp" description = "The tiniest of Gaussian Process libraries" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license = { text = "MIT" } classifiers = [ "Development Status :: 4 - Beta", @@ -45,11 +45,11 @@ source = "vcs" version-file = "src/tinygp/tinygp_version.py" [tool.black] -target-version = ["py39"] +target-version = ["py312"] line-length = 88 [tool.ruff] -target-version = "py39" +target-version = "py312" line-length = 88 [tool.ruff.lint] diff --git a/src/tinygp/test_utils.py b/src/tinygp/test_utils.py index e55d211b..dd469d84 100644 --- a/src/tinygp/test_utils.py +++ b/src/tinygp/test_utils.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from typing import Any import jax @@ -31,29 +30,3 @@ def assert_pytrees_allclose(calculated: Any, expected: Any, *args: Any, **kwargs jax.tree_util.tree_map( lambda a, b: assert_allclose(a, b, *args, **kwargs), calculated, expected ) - - -def _as_context_manager(obj): - # If it's already a context manager - if hasattr(obj, "__enter__") and hasattr(obj, "__exit__"): - return obj - - # If it's a generator, wrap it - if hasattr(obj, "__iter__") and hasattr(obj, "send"): - return contextmanager(lambda: obj)() - - raise TypeError("Object is neither a context manager nor a generator") - - -@contextmanager -def jax_enable_x64(): - if hasattr(jax, "enable_x64"): - cm = jax.enable_x64(True) - else: - # deprecated in jax>=0.9 - from jax.experimental import enable_x64 as _enable_x64 - - cm = _enable_x64() - - with _as_context_manager(cm): - yield diff --git a/tests/test_gp.py b/tests/test_gp.py index be77c11b..add8ad91 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -6,7 +6,7 @@ from numpy import random as np_random from tinygp import GaussianProcess, kernels -from tinygp.test_utils import assert_allclose, jax_enable_x64 +from tinygp.test_utils import assert_allclose @pytest.fixture @@ -24,7 +24,7 @@ def data(random): def test_sample(data): X, _ = data - with jax_enable_x64(): + with jax.enable_x64(True): gp = GaussianProcess( kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x) ) diff --git a/tests/test_kernels/test_kernels.py b/tests/test_kernels/test_kernels.py index f41347fc..817c11dd 100644 --- a/tests/test_kernels/test_kernels.py +++ b/tests/test_kernels/test_kernels.py @@ -5,7 +5,7 @@ from tinygp import kernels, noise from tinygp.solvers import DirectSolver -from tinygp.test_utils import assert_allclose, jax_enable_x64 +from tinygp.test_utils import assert_allclose @pytest.fixture @@ -71,7 +71,7 @@ def test_ops(data): def test_conditioned(data): x1, x2 = data - with jax_enable_x64(): # type: ignore + with jax.enable_x64(True): # type: ignore k1 = 1.5 * kernels.Matern32(2.5) k2 = 0.9 * kernels.ExpSineSquared(scale=1.5, gamma=0.3) K = k1(x1, x1) + 0.1 * jnp.eye(x1.shape[0]) From dda2568a0d160147c1867c22e0d4273ea7729033 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Feb 2026 08:39:35 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/tinygp/gp.py | 3 +-- src/tinygp/kernels/base.py | 4 ++-- src/tinygp/means.py | 2 +- src/tinygp/solvers/quasisep/core.py | 3 ++- src/tinygp/solvers/quasisep/general.py | 3 ++- src/tinygp/transforms.py | 4 ++-- 6 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/tinygp/gp.py b/src/tinygp/gp.py index 6fa1d865..1feac47d 100644 --- a/src/tinygp/gp.py +++ b/src/tinygp/gp.py @@ -2,12 +2,11 @@ __all__ = ["GaussianProcess"] -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial from typing import ( TYPE_CHECKING, Any, - Callable, NamedTuple, ) diff --git a/src/tinygp/kernels/base.py b/src/tinygp/kernels/base.py index 1adcb88a..21de863f 100644 --- a/src/tinygp/kernels/base.py +++ b/src/tinygp/kernels/base.py @@ -12,8 +12,8 @@ ] from abc import abstractmethod -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Callable, Union +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any, Union import equinox as eqx import jax diff --git a/src/tinygp/means.py b/src/tinygp/means.py index 5f01c21c..383f9cfe 100644 --- a/src/tinygp/means.py +++ b/src/tinygp/means.py @@ -13,7 +13,7 @@ __all__ = ["Mean", "Conditioned"] from abc import abstractmethod -from typing import Callable +from collections.abc import Callable import equinox as eqx import jax diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index b38aa5fe..ac1fdf18 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -18,8 +18,9 @@ import dataclasses from abc import abstractmethod +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any import equinox as eqx import jax diff --git a/src/tinygp/solvers/quasisep/general.py b/src/tinygp/solvers/quasisep/general.py index e2a0c4f6..6e233d0b 100644 --- a/src/tinygp/solvers/quasisep/general.py +++ b/src/tinygp/solvers/quasisep/general.py @@ -16,8 +16,9 @@ __all__ = ["GeneralQSM"] +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any import equinox as eqx import jax diff --git a/src/tinygp/transforms.py b/src/tinygp/transforms.py index cdd357f5..65fe1b38 100644 --- a/src/tinygp/transforms.py +++ b/src/tinygp/transforms.py @@ -8,9 +8,9 @@ __all__ = ["Transform", "Linear", "Cholesky", "Subspace"] -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Any, Callable +from typing import Any import equinox as eqx import jax.numpy as jnp From 9412064b69c32dbf2e005999dcedc0e8f5cb144b Mon Sep 17 00:00:00 2001 From: Hannes Holey Date: Mon, 2 Feb 2026 09:48:10 +0100 Subject: [PATCH 7/9] CI test on Python 3.11 --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f1635cc5..8618edff 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.11"] x64: ["0"] include: - python-version: "3.13" From 66ef95b2767db0f3c287b02553bc05a6f6dd0a2d Mon Sep 17 00:00:00 2001 From: Hannes Holey Date: Mon, 2 Feb 2026 10:06:27 +0100 Subject: [PATCH 8/9] Python 3.11 for readthedocs --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index bf326a49..d41760bc 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,7 +5,7 @@ build: apt_packages: - fonts-liberation tools: - python: "3.10" + python: "3.11" python: install: From 1933355dcee50cc33d61cab58210d06627e4bbff Mon Sep 17 00:00:00 2001 From: Hannes Holey Date: Mon, 2 Feb 2026 10:07:35 +0100 Subject: [PATCH 9/9] Fix linting issues --- pyproject.toml | 1 + src/tinygp/kernels/base.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6846ad9f..276767e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ ignore = [ "PLR0913", # Allow many arguments to functions "PLR0915", # Allow many statements "PLR2004", # Allow magic numbers in comparisons + "B905", # Allow zip() without explicit `strict=` parameter ] exclude = [] diff --git a/src/tinygp/kernels/base.py b/src/tinygp/kernels/base.py index 21de863f..8df5077b 100644 --- a/src/tinygp/kernels/base.py +++ b/src/tinygp/kernels/base.py @@ -13,7 +13,7 @@ from abc import abstractmethod from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import equinox as eqx import jax @@ -24,7 +24,7 @@ if TYPE_CHECKING: from tinygp.solvers.solver import Solver -Axis = Union[int, Sequence[int]] +Axis = int | Sequence[int] class Kernel(eqx.Module):