diff --git a/docs/pages/arviz.ipynb b/docs/pages/arviz.ipynb index 8c1989f..fcd6e4d 100644 --- a/docs/pages/arviz.ipynb +++ b/docs/pages/arviz.ipynb @@ -26,10 +26,18 @@ "np.random.seed(11234)\n", "\n", "x = np.random.randn(2, 2000)\n", - "data = az.from_dict(\n", - " posterior={\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n", - " sample_stats={\"diverging\": x < -1.2},\n", - ")\n", + "try:\n", + " data = az.from_dict(\n", + " posterior={\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n", + " sample_stats={\"diverging\": x < -1.2},\n", + " )\n", + "except TypeError:\n", + " data = az.from_dict(\n", + " {\n", + " \"posterior\": {\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n", + " \"sample_stats\": {\"diverging\": x < -1.2},\n", + " },\n", + " )\n", "\n", "figure = corner.corner(data, divergences=True)" ] diff --git a/readthedocs.yaml b/readthedocs.yaml index b59224b..50c96ae 100644 --- a/readthedocs.yaml +++ b/readthedocs.yaml @@ -4,9 +4,9 @@ submodules: include: all build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.10" + python: "3.12" python: install: diff --git a/src/corner/arviz_corner.py b/src/corner/arviz_corner.py index 699365d..1b17fac 100644 --- a/src/corner/arviz_corner.py +++ b/src/corner/arviz_corner.py @@ -3,6 +3,7 @@ __all__ = ["arviz_corner"] import logging +import re from collections.abc import Mapping import numpy as np @@ -12,10 +13,38 @@ except ImportError: from arviz import convert_to_dataset -from arviz.utils import _var_names, get_coords - # Support multiple versions of arviz try: + # arviz < 1.0 + from arviz.utils import _var_names, get_coords +except ImportError: + # arviz >= 1.0: these functions were removed + + def _var_names(var_names, dataset, filter_vars=None): + if var_names is None: + return None + if filter_vars == "like": + return [ + v + for v in dataset.data_vars + if any(vn in v for vn in var_names) + ] + elif filter_vars == "regex": + return [ + v + for v in dataset.data_vars + if any(re.search(vn, v) for vn in var_names) + ] + return list(var_names) + + def get_coords(dataset, coords): + if not coords: + return dataset + return dataset.sel(coords) + + +try: + # Very old arviz from arviz.plots.plot_utils import ( make_label, xarray_to_ndarray, @@ -29,8 +58,14 @@ def _get_labels(plotters, labeller=None): ] except ImportError: - from arviz.labels import BaseLabeller - from arviz.sel_utils import xarray_to_ndarray, xarray_var_iter + try: + # Medium arviz (< 1.0) + from arviz.labels import BaseLabeller + from arviz.sel_utils import xarray_to_ndarray, xarray_var_iter + except ImportError: + # arviz >= 1.0 + from arviz import xarray_to_ndarray, xarray_var_iter + from arviz_base.labels import BaseLabeller def _get_labels(plotters, labeller=None): if labeller is None: diff --git a/tests/test_corner.py b/tests/test_corner.py index 4057361..9dd778e 100644 --- a/tests/test_corner.py +++ b/tests/test_corner.py @@ -39,10 +39,18 @@ def _run_corner( ) elif arviz: az = pytest.importorskip("arviz") - data = az.from_dict( - posterior={"x": data[None]}, - sample_stats={"diverging": data[None, :, 0] < 0.0}, - ) + try: + data = az.from_dict( + posterior={"x": data[None]}, + sample_stats={"diverging": data[None, :, 0] < 0.0}, + ) + except TypeError: + data = az.from_dict( + { + "posterior": {"x": data[None]}, + "sample_stats": {"diverging": data[None, :, 0] < 0.0}, + }, + ) kwargs["truths"] = {"x": np.random.randn(ndim)} elif arviz_preview: az = pytest.importorskip("arviz.preview") @@ -293,7 +301,7 @@ def test_top_ticks(): _run_corner(top_ticks=True) -@image_comparison(baseline_images=["pandas"], extensions=["png"]) +@image_comparison(baseline_images=["pandas"], extensions=["png"], tol=7) def test_pandas(): _run_corner(pandas=True)