diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f1635cc5..8618edff 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10"] + python-version: ["3.11"] x64: ["0"] include: - python-version: "3.13" diff --git a/.readthedocs.yaml b/.readthedocs.yaml index bf326a49..d41760bc 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -5,7 +5,7 @@ build: apt_packages: - fonts-liberation tools: - python: "3.10" + python: "3.11" python: install: diff --git a/pyproject.toml b/pyproject.toml index f612ca1e..276767e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "tinygp" description = "The tiniest of Gaussian Process libraries" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license = { text = "MIT" } classifiers = [ "Development Status :: 4 - Beta", @@ -45,11 +45,11 @@ source = "vcs" version-file = "src/tinygp/tinygp_version.py" [tool.black] -target-version = ["py39"] +target-version = ["py312"] line-length = 88 [tool.ruff] -target-version = "py39" +target-version = "py312" line-length = 88 [tool.ruff.lint] @@ -60,6 +60,7 @@ ignore = [ "PLR0913", # Allow many arguments to functions "PLR0915", # Allow many statements "PLR2004", # Allow magic numbers in comparisons + "B905", # Allow zip() without explicit `strict=` parameter ] exclude = [] diff --git a/src/tinygp/gp.py b/src/tinygp/gp.py index 6fa1d865..1feac47d 100644 --- a/src/tinygp/gp.py +++ b/src/tinygp/gp.py @@ -2,12 +2,11 @@ __all__ = ["GaussianProcess"] -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial from typing import ( TYPE_CHECKING, Any, - Callable, NamedTuple, ) diff --git a/src/tinygp/kernels/base.py b/src/tinygp/kernels/base.py index 1adcb88a..8df5077b 100644 --- a/src/tinygp/kernels/base.py +++ b/src/tinygp/kernels/base.py @@ -12,8 +12,8 @@ ] from abc import abstractmethod -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Callable, Union +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, Any import equinox as eqx import jax @@ -24,7 +24,7 @@ if TYPE_CHECKING: from tinygp.solvers.solver import Solver -Axis = Union[int, Sequence[int]] +Axis = int | Sequence[int] class Kernel(eqx.Module): diff --git a/src/tinygp/means.py b/src/tinygp/means.py index 9811d634..383f9cfe 100644 --- a/src/tinygp/means.py +++ b/src/tinygp/means.py @@ -13,7 +13,7 @@ __all__ = ["Mean", "Conditioned"] from abc import abstractmethod -from typing import Callable +from collections.abc import Callable import equinox as eqx import jax @@ -39,18 +39,18 @@ class Mean(MeanBase): signature. """ - value: JAXArray | None = None + value: JAXArray func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True) def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]): if callable(value): self.func = value + self.value = jax.numpy.zeros(()) # avoids undefined traced values else: self.value = value def __call__(self, X: JAXArray) -> JAXArray: - if self.value is None: - assert self.func is not None + if self.func is not None: return self.func(X) return self.value diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index b38aa5fe..ac1fdf18 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -18,8 +18,9 @@ import dataclasses from abc import abstractmethod +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any import equinox as eqx import jax diff --git a/src/tinygp/solvers/quasisep/general.py b/src/tinygp/solvers/quasisep/general.py index e2a0c4f6..6e233d0b 100644 --- a/src/tinygp/solvers/quasisep/general.py +++ b/src/tinygp/solvers/quasisep/general.py @@ -16,8 +16,9 @@ __all__ = ["GeneralQSM"] +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any import equinox as eqx import jax diff --git a/src/tinygp/transforms.py b/src/tinygp/transforms.py index cdd357f5..65fe1b38 100644 --- a/src/tinygp/transforms.py +++ b/src/tinygp/transforms.py @@ -8,9 +8,9 @@ __all__ = ["Transform", "Linear", "Cholesky", "Subspace"] -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Any, Callable +from typing import Any import equinox as eqx import jax.numpy as jnp diff --git a/tests/test_gp.py b/tests/test_gp.py index 0c478e0c..add8ad91 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -24,7 +24,7 @@ def data(random): def test_sample(data): X, _ = data - with jax.experimental.enable_x64(True): + with jax.enable_x64(True): gp = GaussianProcess( kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x) ) diff --git a/tests/test_kernels/test_kernels.py b/tests/test_kernels/test_kernels.py index 6e05883d..817c11dd 100644 --- a/tests/test_kernels/test_kernels.py +++ b/tests/test_kernels/test_kernels.py @@ -71,7 +71,7 @@ def test_ops(data): def test_conditioned(data): x1, x2 = data - with jax.experimental.enable_x64(): # type: ignore + with jax.enable_x64(True): # type: ignore k1 = 1.5 * kernels.Matern32(2.5) k2 = 0.9 * kernels.ExpSineSquared(scale=1.5, gamma=0.3) K = k1(x1, x1) + 0.1 * jnp.eye(x1.shape[0])