Skip to content
16 changes: 12 additions & 4 deletions docs/pages/arviz.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@
"np.random.seed(11234)\n",
"\n",
"x = np.random.randn(2, 2000)\n",
"data = az.from_dict(\n",
" posterior={\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n",
" sample_stats={\"diverging\": x < -1.2},\n",
")\n",
"try:\n",
" data = az.from_dict(\n",
" posterior={\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n",
" sample_stats={\"diverging\": x < -1.2},\n",
" )\n",
"except TypeError:\n",
" data = az.from_dict(\n",
" {\n",
" \"posterior\": {\"x\": x, \"y\": np.random.randn(2, 2000, 2)},\n",
" \"sample_stats\": {\"diverging\": x < -1.2},\n",
" },\n",
" )\n",
"\n",
"figure = corner.corner(data, divergences=True)"
]
Expand Down
4 changes: 2 additions & 2 deletions readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ submodules:
include: all

build:
os: ubuntu-20.04
os: ubuntu-22.04
tools:
python: "3.10"
python: "3.12"

python:
install:
Expand Down
43 changes: 39 additions & 4 deletions src/corner/arviz_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__all__ = ["arviz_corner"]

import logging
import re
from collections.abc import Mapping

import numpy as np
Expand All @@ -12,10 +13,38 @@
except ImportError:
from arviz import convert_to_dataset

from arviz.utils import _var_names, get_coords

# Support multiple versions of arviz
try:
# arviz < 1.0
from arviz.utils import _var_names, get_coords
except ImportError:
# arviz >= 1.0: these functions were removed

def _var_names(var_names, dataset, filter_vars=None):
if var_names is None:
return None
if filter_vars == "like":
return [
v
for v in dataset.data_vars
if any(vn in v for vn in var_names)
]
elif filter_vars == "regex":
return [
v
for v in dataset.data_vars
if any(re.search(vn, v) for vn in var_names)
]
return list(var_names)

def get_coords(dataset, coords):
if not coords:
return dataset
return dataset.sel(coords)


try:
# Very old arviz
from arviz.plots.plot_utils import (
make_label,
xarray_to_ndarray,
Expand All @@ -29,8 +58,14 @@ def _get_labels(plotters, labeller=None):
]

except ImportError:
from arviz.labels import BaseLabeller
from arviz.sel_utils import xarray_to_ndarray, xarray_var_iter
try:
# Medium arviz (< 1.0)
from arviz.labels import BaseLabeller
from arviz.sel_utils import xarray_to_ndarray, xarray_var_iter
except ImportError:
# arviz >= 1.0
from arviz import xarray_to_ndarray, xarray_var_iter
from arviz_base.labels import BaseLabeller

def _get_labels(plotters, labeller=None):
if labeller is None:
Expand Down
18 changes: 13 additions & 5 deletions tests/test_corner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ def _run_corner(
)
elif arviz:
az = pytest.importorskip("arviz")
data = az.from_dict(
posterior={"x": data[None]},
sample_stats={"diverging": data[None, :, 0] < 0.0},
)
try:
data = az.from_dict(
posterior={"x": data[None]},
sample_stats={"diverging": data[None, :, 0] < 0.0},
)
except TypeError:
data = az.from_dict(
{
"posterior": {"x": data[None]},
"sample_stats": {"diverging": data[None, :, 0] < 0.0},
},
)
kwargs["truths"] = {"x": np.random.randn(ndim)}
elif arviz_preview:
az = pytest.importorskip("arviz.preview")
Expand Down Expand Up @@ -293,7 +301,7 @@ def test_top_ticks():
_run_corner(top_ticks=True)


@image_comparison(baseline_images=["pandas"], extensions=["png"])
@image_comparison(baseline_images=["pandas"], extensions=["png"], tol=7)
def test_pandas():
_run_corner(pandas=True)

Expand Down