Skip to content

Commit 08b3633

Browse files
committed
Add progress callback for panel plot helper
1 parent f848758 commit 08b3633

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

ratapi/utils/plotting.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,7 @@ def plot_contour(
916916
axes: Axes | None = None,
917917
block: bool = False,
918918
return_fig: bool = False,
919+
progress_callback: Callable[[int, int], None] | None = None,
919920
**hist2d_settings,
920921
):
921922
"""Plot a 2D histogram of two indexed chain parameters, with contours.
@@ -982,7 +983,10 @@ def plot_contour(
982983

983984

984985
def panel_plot_helper(
985-
plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None
986+
plot_func: Callable,
987+
indices: list[int],
988+
fig: matplotlib.figure.Figure | None = None,
989+
progress_callback: Callable[[int, int], None] | None = None,
986990
) -> matplotlib.figure.Figure:
987991
"""Generate a panel-based plot from a single plot function.
988992
@@ -994,6 +998,9 @@ def panel_plot_helper(
994998
The list of indices to pass into ``plot_func``.
995999
fig : matplotlib.figure.Figure, optional
9961000
The figure object to use for plot.
1001+
progress_callback: Union[Callable[[int, int], None], None]
1002+
Callback function for providing progress during plot creation
1003+
First argument is current completed sub plot and second is total number of sub plots
9971004
9981005
Returns
9991006
-------
@@ -1005,21 +1012,21 @@ def panel_plot_helper(
10051012
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
10061013

10071014
if fig is None:
1008-
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
1015+
fig = plt.subplots(nrows, ncols, figsize=(11, 10), subplot_kw={"visible": False})[0]
10091016
else:
10101017
fig.clf()
1011-
fig.subplots(nrows, ncols)
1018+
fig.subplots(nrows, ncols, subplot_kw={"visible": False})
10121019
axs = fig.get_axes()
1013-
1020+
current_plot = 0
10141021
for plot_num, index in enumerate(indices):
10151022
axs[plot_num].tick_params(which="both", labelsize="medium")
10161023
axs[plot_num].xaxis.offsetText.set_fontsize("small")
10171024
axs[plot_num].yaxis.offsetText.set_fontsize("small")
1025+
axs[plot_num].set_visible(True)
10181026
plot_func(axs[plot_num], index)
1019-
1020-
# blank unused plots
1021-
for i in range(nplots, len(axs)):
1022-
axs[i].set_visible(False)
1027+
if progress_callback is not None:
1028+
current_plot += 1
1029+
progress_callback(current_plot, nplots)
10231030

10241031
fig.tight_layout()
10251032
return fig
@@ -1036,6 +1043,7 @@ def plot_hists(
10361043
block: bool = False,
10371044
fig: matplotlib.figure.Figure | None = None,
10381045
return_fig: bool = False,
1046+
progress_callback: Callable[[int, int], None] | None = None,
10391047
**hist_settings,
10401048
):
10411049
"""Plot marginalised posteriors for several parameters from a Bayesian analysis.
@@ -1072,6 +1080,9 @@ def plot_hists(
10721080
The figure object to use for plot.
10731081
return_fig: bool, default False
10741082
If True, return the figure as an object instead of showing it.
1083+
progress_callback: Union[Callable[[int, int], None], None]
1084+
Callback function for providing progress during plot creation
1085+
First argument is current completed sub plot and second is total number of sub plots
10751086
hist_settings :
10761087
Settings passed to `np.histogram`. By default, the settings
10771088
passed are `bins = 25` and `density = True`.
@@ -1130,6 +1141,7 @@ def validate_dens_type(dens_type: str | None, param: str):
11301141
),
11311142
params,
11321143
fig,
1144+
progress_callback,
11331145
)
11341146
if return_fig:
11351147
return fig
@@ -1144,6 +1156,7 @@ def plot_chain(
11441156
block: bool = False,
11451157
fig: matplotlib.figure.Figure | None = None,
11461158
return_fig: bool = False,
1159+
progress_callback: Callable[[int, int], None] | None = None,
11471160
):
11481161
"""Plot the MCMC chain for each parameter of a Bayesian analysis.
11491162
@@ -1162,6 +1175,9 @@ def plot_chain(
11621175
The figure object to use for plot.
11631176
return_fig: bool, default False
11641177
If True, return the figure as an object instead of showing it.
1178+
progress_callback: Union[Callable[[int, int], None], None]
1179+
Callback function for providing progress during plot creation
1180+
First argument is current completed sub plot and second is total number of sub plots
11651181
11661182
Returns
11671183
-------
@@ -1187,7 +1203,7 @@ def plot_one_chain(axes: Axes, i: int):
11871203
axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip])
11881204
axes.set_title(results.fitNames[i], fontsize="small")
11891205

1190-
fig = panel_plot_helper(plot_one_chain, params, fig=fig)
1206+
fig = panel_plot_helper(plot_one_chain, params, fig, progress_callback)
11911207
if return_fig:
11921208
return fig
11931209
plt.show(block=block)

0 commit comments

Comments
 (0)