diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index 231dfa05cf..230dfcb62f 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index f85f8fc55b..e8ea8cbd1a 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - cachetools>=4.2.1 - cloudpickle - numpy>=1.25.0 diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..4b4ede3935 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -29,7 +29,8 @@ import xarray from arviz import InferenceData, concat, rcParams -from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires +from arviz.data.base import CoordSpec, DimSpec, requires +from arviz_base import dict_to_dataset from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable from rich.progress import Console @@ -305,14 +306,14 @@ def posterior_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, @@ -347,14 +348,14 @@ def sample_stats_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, @@ -367,7 +368,11 @@ def posterior_predictive_to_xarray(self): data = self.posterior_predictive dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) @requires(["predictions"]) @@ -376,7 +381,11 @@ def predictions_to_xarray(self): data = self.predictions dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) def priors_to_xarray(self): @@ -399,7 +408,7 @@ def priors_to_xarray(self): if var_names is None else dict_to_dataset_drop_incompatible_coords( {k: np.expand_dims(self.prior[k], 0) for k in var_names}, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, ) @@ -414,10 +423,10 @@ def observed_data_to_xarray(self): return None return dict_to_dataset( self.observations, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) @requires("model") @@ -429,10 +438,10 @@ def constant_data_to_xarray(self): xarray_dataset = dict_to_dataset( constant_data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) # provisional handling of scalars in constant @@ -707,9 +716,9 @@ def apply_function_over_dataset( return dict_to_dataset( out_trace, - library=pymc, + inference_library=pymc, dims=dims, coords=coords, - default_dims=list(sample_dims), + sample_dims=list(sample_dims), skip_event_dims=True, ) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 5afd398281..352224026e 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -267,7 +267,9 @@ def _save_sample_stats( sample_stats = dict_to_dataset( sample_stats_dict, attrs=sample_settings_dict, - library=pymc, + inference_library=pymc, + sample_dims=["chain"], + check_conventions=False, ) ikwargs: dict[str, Any] = {"model": model} diff --git a/requirements-dev.txt b/requirements-dev.txt index 22bcdaf9ea..25bc83223e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. +arviz-base arviz>=0.13.0 cachetools>=4.2.1 cloudpickle