Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.11"]
x64: ["0"]
include:
- python-version: "3.13"
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build:
apt_packages:
- fonts-liberation
tools:
python: "3.10"
python: "3.11"

python:
install:
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "tinygp"
description = "The tiniest of Gaussian Process libraries"
authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }]
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
license = { text = "MIT" }
classifiers = [
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -45,11 +45,11 @@ source = "vcs"
version-file = "src/tinygp/tinygp_version.py"

[tool.black]
target-version = ["py39"]
target-version = ["py312"]
line-length = 88

[tool.ruff]
target-version = "py39"
target-version = "py312"
line-length = 88

[tool.ruff.lint]
Expand All @@ -60,6 +60,7 @@ ignore = [
"PLR0913", # Allow many arguments to functions
"PLR0915", # Allow many statements
"PLR2004", # Allow magic numbers in comparisons
"B905", # Allow zip() without explicit `strict=` parameter
]
exclude = []

Expand Down
3 changes: 1 addition & 2 deletions src/tinygp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

__all__ = ["GaussianProcess"]

from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
NamedTuple,
)

Expand Down
6 changes: 3 additions & 3 deletions src/tinygp/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
]

from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any

import equinox as eqx
import jax
Expand All @@ -24,7 +24,7 @@
if TYPE_CHECKING:
from tinygp.solvers.solver import Solver

Axis = Union[int, Sequence[int]]
Axis = int | Sequence[int]


class Kernel(eqx.Module):
Expand Down
8 changes: 4 additions & 4 deletions src/tinygp/means.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__all__ = ["Mean", "Conditioned"]

from abc import abstractmethod
from typing import Callable
from collections.abc import Callable

import equinox as eqx
import jax
Expand All @@ -39,18 +39,18 @@ class Mean(MeanBase):
signature.
"""

value: JAXArray | None = None
value: JAXArray
func: Callable[[JAXArray], JAXArray] | None = eqx.field(default=None, static=True)

def __init__(self, value: JAXArray | Callable[[JAXArray], JAXArray]):
if callable(value):
self.func = value
self.value = jax.numpy.zeros(()) # avoids undefined traced values
else:
self.value = value

def __call__(self, X: JAXArray) -> JAXArray:
if self.value is None:
assert self.func is not None
if self.func is not None:
return self.func(X)
return self.value

Expand Down
3 changes: 2 additions & 1 deletion src/tinygp/solvers/quasisep/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

import dataclasses
from abc import abstractmethod
from collections.abc import Callable
from functools import wraps
from typing import Any, Callable
from typing import Any

import equinox as eqx
import jax
Expand Down
3 changes: 2 additions & 1 deletion src/tinygp/solvers/quasisep/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

__all__ = ["GeneralQSM"]

from collections.abc import Callable
from functools import wraps
from typing import Any, Callable
from typing import Any

import equinox as eqx
import jax
Expand Down
4 changes: 2 additions & 2 deletions src/tinygp/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

__all__ = ["Transform", "Linear", "Cholesky", "Subspace"]

from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, Callable
from typing import Any

import equinox as eqx
import jax.numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def data(random):
def test_sample(data):
X, _ = data

with jax.experimental.enable_x64(True):
with jax.enable_x64(True):
gp = GaussianProcess(
kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_ops(data):

def test_conditioned(data):
x1, x2 = data
with jax.experimental.enable_x64(): # type: ignore
with jax.enable_x64(True): # type: ignore
k1 = 1.5 * kernels.Matern32(2.5)
k2 = 0.9 * kernels.ExpSineSquared(scale=1.5, gamma=0.3)
K = k1(x1, x1) + 0.1 * jnp.eye(x1.shape[0])
Expand Down