|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | from collections.abc import Sequence |
15 | | -from typing import cast |
| 15 | +from typing import Literal |
16 | 16 |
|
17 | | -from arviz import InferenceData, dict_to_dataset |
18 | | -from rich.console import Console |
19 | | -from rich.progress import Progress |
20 | | - |
21 | | -import pymc |
| 17 | +from arviz import InferenceData |
| 18 | +from xarray import Dataset |
22 | 19 |
|
23 | 20 | from pymc.backends.arviz import ( |
24 | | - _DefaultTrace, |
| 21 | + apply_function_over_dataset, |
25 | 22 | coords_and_dims_for_inferencedata, |
26 | | - dataset_to_point_list, |
27 | 23 | ) |
28 | 24 | from pymc.model import Model, modelcontext |
29 | | -from pymc.pytensorf import PointFunc |
30 | | -from pymc.util import default_progress_theme |
31 | 25 |
|
32 | 26 | __all__ = ("compute_log_likelihood", "compute_log_prior") |
33 | 27 |
|
@@ -117,10 +111,10 @@ def compute_log_density( |
117 | 111 | var_names: Sequence[str] | None = None, |
118 | 112 | extend_inferencedata: bool = True, |
119 | 113 | model: Model | None = None, |
120 | | - kind="likelihood", |
| 114 | + kind: Literal["likelihood", "prior"] = "likelihood", |
121 | 115 | sample_dims: Sequence[str] = ("chain", "draw"), |
122 | 116 | progressbar=True, |
123 | | -): |
| 117 | +) -> InferenceData | Dataset: |
124 | 118 | """ |
125 | 119 | Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group |
126 | 120 | """ |
@@ -163,40 +157,20 @@ def compute_log_density( |
163 | 157 | outs=model.logp(vars=vars, sum=False), |
164 | 158 | on_unused_input="ignore", |
165 | 159 | ) |
166 | | - elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn) |
167 | 160 | finally: |
168 | 161 | model.rvs_to_values = original_rvs_to_values |
169 | 162 | model.rvs_to_transforms = original_rvs_to_transforms |
170 | 163 |
|
171 | | - # Ignore Deterministics |
172 | | - posterior_values = posterior[[rv.name for rv in model.free_RVs]] |
173 | | - posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims) |
174 | | - |
175 | | - n_pts = len(posterior_pts) |
176 | | - logdens_dict = _DefaultTrace(n_pts) |
177 | | - |
178 | | - with Progress(console=Console(theme=default_progress_theme)) as progress: |
179 | | - task = progress.add_task("Computing log density...", total=n_pts, visible=progressbar) |
180 | | - for idx in range(n_pts): |
181 | | - logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) |
182 | | - for rv_name, rv_logdens in zip(var_names, logdenss_pts): |
183 | | - logdens_dict.insert(rv_name, rv_logdens, idx) |
184 | | - progress.update(task, advance=1) |
185 | | - |
186 | | - logdens_trace = logdens_dict.trace_dict |
187 | | - for key, array in logdens_trace.items(): |
188 | | - logdens_trace[key] = array.reshape( |
189 | | - (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:]) |
190 | | - ) |
191 | | - |
192 | 164 | coords, dims = coords_and_dims_for_inferencedata(model) |
193 | | - logdens_dataset = dict_to_dataset( |
194 | | - logdens_trace, |
195 | | - library=pymc, |
| 165 | + |
| 166 | + logdens_dataset = apply_function_over_dataset( |
| 167 | + elemwise_logdens_fn, |
| 168 | + posterior[[rv.name for rv in model.free_RVs]], |
| 169 | + output_var_names=var_names, |
| 170 | + sample_dims=sample_dims, |
196 | 171 | dims=dims, |
197 | 172 | coords=coords, |
198 | | - default_dims=list(sample_dims), |
199 | | - skip_event_dims=True, |
| 173 | + progressbar=progressbar, |
200 | 174 | ) |
201 | 175 |
|
202 | 176 | if extend_inferencedata: |
|
0 commit comments