Skip to content

Commit 389fbea

Browse files
committed
revert
1 parent e73b35a commit 389fbea

File tree

1 file changed

+283
-14
lines changed

1 file changed

+283
-14
lines changed

ratapi/utils/plotting.py

Lines changed: 283 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,262 @@ def _y_update_offset_text_position(axis, _bboxes, bboxes2):
906906
axis.offsetText.set_position((x - x_offset, y))
907907

908908

909+
@assert_bayesian("Corner")
910+
def plot_corner(
911+
results: ratapi.outputs.BayesResults,
912+
params: list[int | str] | None = None,
913+
smooth: bool = True,
914+
block: bool = False,
915+
fig: matplotlib.figure.Figure | None = None,
916+
return_fig: bool = False,
917+
hist_kwargs: dict | None = None,
918+
hist2d_kwargs: dict | None = None,
919+
progress_callback: Callable[[int, int], None] | None = None,
920+
):
921+
"""Create a corner plot from a Bayesian analysis.
922+
923+
Parameters
924+
----------
925+
results : BayesResults
926+
The results from a Bayesian calculation.
927+
params : list[int or str], default None
928+
The indices or names of a subset of parameters if required.
929+
If None, uses all indices.
930+
smooth : bool, default True
931+
Whether to apply Gaussian smoothing to the corner plot.
932+
block : bool, default False
933+
Whether Python should block until the plot is closed.
934+
fig : matplotlib.figure.Figure, optional
935+
The figure object to use for plot.
936+
return_fig: bool, default False
937+
If True, return the figure as an object instead of showing it.
938+
hist_kwargs : dict
939+
Extra keyword arguments to pass to the 1d histograms.
940+
Default is {'density': True, 'bins': 25}
941+
hist2d_kwargs : dict
942+
Extra keyword arguments to pass to the 2d histograms.
943+
Default is {'density': True, 'bins': 25}
944+
progress_callback: Union[Callable[[int, int], None], None]
945+
Callback function for providing progress during plot creation
946+
First argument is current completed sub plot and second is total number of sub plots
947+
948+
Returns
949+
-------
950+
Figure or None
951+
If `return_fig` is True, return the figure - otherwise, return nothing.
952+
953+
"""
954+
fitname_to_index = partial(name_to_index, names=results.fitNames)
955+
956+
if params is None:
957+
params = range(0, len(results.fitNames))
958+
else:
959+
params = list(map(fitname_to_index, params))
960+
961+
# defaults are applied inside each function - just pass blank dicts for now
962+
if hist_kwargs is None:
963+
hist_kwargs = {}
964+
if hist2d_kwargs is None:
965+
hist2d_kwargs = {}
966+
967+
num_params = len(params)
968+
total_count = num_params + (num_params**2 - num_params) // 2
969+
970+
if fig is None:
971+
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10), subplot_kw={"visible": False})
972+
else:
973+
fig.clf()
974+
axes = fig.subplots(num_params, num_params, subplot_kw={"visible": False})
975+
976+
# i is row, j is column
977+
current_count = 0
978+
for i in range(num_params):
979+
for j in range(i + 1):
980+
row_param = params[i]
981+
col_param = params[j]
982+
current_axes: Axes = axes if isinstance(axes, matplotlib.axes.Axes) else axes[i][j]
983+
current_axes.tick_params(which="both", labelsize="medium")
984+
current_axes.xaxis.offsetText.set_fontsize("small")
985+
current_axes.yaxis.offsetText.set_fontsize("small")
986+
current_axes.set_visible(True)
987+
if i == j: # diagonal: histograms
988+
plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs)
989+
elif i > j: # lower triangle: 2d histograms
990+
plot_contour(
991+
results, x_param=col_param, y_param=row_param, smooth=smooth, axes=current_axes, **hist2d_kwargs
992+
)
993+
994+
# remove label if on inside of corner plot
995+
if j != 0:
996+
current_axes.get_yaxis().set_visible(False)
997+
if i != len(params) - 1:
998+
current_axes.get_xaxis().set_visible(False)
999+
# make labels invisible as titles cover that
1000+
current_axes.yaxis._update_offset_text_position = types.MethodType(
1001+
_y_update_offset_text_position, current_axes.yaxis
1002+
)
1003+
current_axes.yaxis.offset_text_position = "center"
1004+
current_axes.set_ylabel("")
1005+
current_axes.set_xlabel("")
1006+
if progress_callback is not None:
1007+
current_count += 1
1008+
progress_callback(current_count, total_count)
1009+
if return_fig:
1010+
return fig
1011+
plt.show(block=block)
1012+
1013+
1014+
@assert_bayesian("Histogram")
1015+
def plot_one_hist(
1016+
results: ratapi.outputs.BayesResults,
1017+
param: int | str,
1018+
smooth: bool = True,
1019+
sigma: float | None = None,
1020+
estimated_density: Literal["normal", "lognor", "kernel", None] = None,
1021+
axes: Axes | None = None,
1022+
block: bool = False,
1023+
return_fig: bool = False,
1024+
**hist_settings,
1025+
):
1026+
"""Plot the marginalised posterior for a parameter of a Bayesian analysis.
1027+
1028+
Parameters
1029+
----------
1030+
results : BayesResults
1031+
The results from a Bayesian calculation.
1032+
param : Union[int, str]
1033+
Either the index or name of a parameter.
1034+
block : bool, default False
1035+
Whether Python should block until the plot is closed.
1036+
smooth : bool, default True
1037+
Whether to apply Gaussian smoothing to the histogram.
1038+
Defaults to True.
1039+
sigma: float or None, default None
1040+
If given, is used as the sigma-parameter for the Gaussian smoothing.
1041+
If None, the default (1/3rd of parameter chain standard deviation) is used.
1042+
estimated_density : 'normal', 'lognor', 'kernel' or None, default None
1043+
If None (default), ignore. Else, add an estimated density
1044+
of the given form on top of the histogram by the following estimations:
1045+
'normal': normal Gaussian.
1046+
'lognor': Log-normal probability density.
1047+
'kernel': kernel density estimation.
1048+
axes: Axes or None, default None
1049+
If provided, plot on the given Axes object.
1050+
block : bool, default False
1051+
Whether Python should block until the plot is closed.
1052+
return_fig: bool, default False
1053+
If True, return the figure as an object instead of showing it.
1054+
**hist_settings :
1055+
Settings passed to `np.histogram`. By default, the settings
1056+
passed are `bins = 25` and `density = True`.
1057+
1058+
Returns
1059+
-------
1060+
Figure or None
1061+
If `return_fig` is True, return the figure - otherwise, return nothing.
1062+
1063+
"""
1064+
chain = results.chain
1065+
param = name_to_index(param, results.fitNames)
1066+
1067+
if axes is None:
1068+
fig, axes = plt.subplots(1, 1)
1069+
else:
1070+
fig = None
1071+
1072+
# apply default settings if not set by user
1073+
default_settings = {"bins": 25, "density": True}
1074+
hist_settings = {**default_settings, **hist_settings}
1075+
1076+
parameter_chain = chain[:, param]
1077+
counts, bins = np.histogram(parameter_chain, **hist_settings)
1078+
mean_y = np.mean(parameter_chain)
1079+
sd_y = np.std(parameter_chain)
1080+
1081+
if smooth:
1082+
if sigma is None:
1083+
sigma = sd_y / 2
1084+
counts = gaussian_filter1d(counts, sigma)
1085+
axes.hist(
1086+
bins[:-1],
1087+
bins,
1088+
weights=counts,
1089+
edgecolor="black",
1090+
linewidth=1.2,
1091+
color="white",
1092+
)
1093+
1094+
axes.set_title(results.fitNames[param], loc="left", fontsize="medium")
1095+
1096+
if estimated_density:
1097+
dx = bins[1] - bins[0]
1098+
if estimated_density == "normal":
1099+
t = np.linspace(mean_y - 3.5 * sd_y, mean_y + 3.5 * sd_y)
1100+
axes.plot(t, norm.pdf(t, loc=mean_y, scale=sd_y**2))
1101+
elif estimated_density == "lognor":
1102+
t = np.linspace(bins[0] - 0.5 * dx, bins[-1] + 2 * dx)
1103+
axes.plot(t, lognorm.pdf(t, np.mean(np.log(parameter_chain)), np.std(np.log(parameter_chain))))
1104+
elif estimated_density == "kernel":
1105+
t = np.linspace(bins[0] - 2 * dx, bins[-1] + 2 * dx, 200)
1106+
kde = gaussian_kde(parameter_chain)
1107+
axes.plot(t, kde.evaluate(t))
1108+
else:
1109+
raise ValueError(
1110+
f"{estimated_density} is not a supported estimated density function."
1111+
" Supported functions are 'normal' 'lognor' or 'kernel'."
1112+
)
1113+
1114+
# adding the estimated density extends the figure range - reset it to histogram range
1115+
x_range = hist_settings.get("range", (parameter_chain.min(), parameter_chain.max()))
1116+
axes.set_xlim(x_range)
1117+
1118+
if fig is not None:
1119+
if return_fig:
1120+
return fig
1121+
plt.show(block=block)
1122+
1123+
1124+
def _y_update_offset_text_position(axis, _bboxes, bboxes2):
1125+
"""Update the position of the Y axis offset text using the provided bounding boxes.
1126+
1127+
Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334.
1128+
1129+
Parameters
1130+
----------
1131+
axis : matplotlib.axis.YAxis
1132+
Y axis to update.
1133+
_bboxes : List
1134+
list of bounding boxes
1135+
bboxes2 : List
1136+
list of bounding boxes
1137+
"""
1138+
x, y = axis.offsetText.get_position()
1139+
1140+
if axis.offset_text_position == "left":
1141+
# y in axes coords, x in display coords
1142+
axis.offsetText.set_transform(
1143+
mtransforms.blended_transform_factory(axis.axes.transAxes, mtransforms.IdentityTransform())
1144+
)
1145+
1146+
top = axis.axes.bbox.ymax
1147+
y = top + axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0
1148+
1149+
else:
1150+
# x & y in display coords
1151+
axis.offsetText.set_transform(mtransforms.IdentityTransform())
1152+
1153+
# Northwest of upper-right corner of right-hand extent of tick labels
1154+
if bboxes2:
1155+
bbox = mtransforms.Bbox.union(bboxes2)
1156+
else:
1157+
bbox = axis.axes.bbox
1158+
center = bbox.ymin + (bbox.ymax - bbox.ymin) / 2
1159+
x = bbox.xmin - axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0
1160+
y = center
1161+
x_offset = 110
1162+
axis.offsetText.set_position((x - x_offset, y))
1163+
1164+
9091165
@assert_bayesian("Contour")
9101166
def plot_contour(
9111167
results: ratapi.outputs.BayesResults,
@@ -982,7 +1238,10 @@ def plot_contour(
9821238

9831239

9841240
def panel_plot_helper(
985-
plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None
1241+
plot_func: Callable,
1242+
indices: list[int],
1243+
fig: matplotlib.figure.Figure | None = None,
1244+
progress_callback: Callable[[int, int], None] | None = None,
9861245
) -> matplotlib.figure.Figure:
9871246
"""Generate a panel-based plot from a single plot function.
9881247
@@ -994,6 +1253,9 @@ def panel_plot_helper(
9941253
The list of indices to pass into ``plot_func``.
9951254
fig : matplotlib.figure.Figure, optional
9961255
The figure object to use for plot.
1256+
progress_callback: Union[Callable[[int, int], None], None]
1257+
Callback function for providing progress during plot creation
1258+
First argument is current completed sub plot and second is total number of sub plots
9971259
9981260
Returns
9991261
-------
@@ -1005,21 +1267,19 @@ def panel_plot_helper(
10051267
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
10061268

10071269
if fig is None:
1008-
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
1270+
fig = plt.subplots(nrows, ncols, figsize=(11, 10), subplot_kw={"visible": False})[0]
10091271
else:
10101272
fig.clf()
1011-
fig.subplots(nrows, ncols)
1273+
fig.subplots(nrows, ncols, subplot_kw={"visible": False})
10121274
axs = fig.get_axes()
1013-
1014-
for plot_num, index in enumerate(indices):
1015-
axs[plot_num].tick_params(which="both", labelsize="medium")
1016-
axs[plot_num].xaxis.offsetText.set_fontsize("small")
1017-
axs[plot_num].yaxis.offsetText.set_fontsize("small")
1018-
plot_func(axs[plot_num], index)
1019-
1020-
# blank unused plots
1021-
for i in range(nplots, len(axs)):
1022-
axs[i].set_visible(False)
1275+
for index, plot_num in enumerate(indices):
1276+
axs[index].tick_params(which="both", labelsize="medium")
1277+
axs[index].xaxis.offsetText.set_fontsize("small")
1278+
axs[index].yaxis.offsetText.set_fontsize("small")
1279+
axs[index].set_visible(True)
1280+
plot_func(axs[index], plot_num)
1281+
if progress_callback is not None:
1282+
progress_callback(index, nplots)
10231283

10241284
fig.tight_layout()
10251285
return fig
@@ -1036,6 +1296,7 @@ def plot_hists(
10361296
block: bool = False,
10371297
fig: matplotlib.figure.Figure | None = None,
10381298
return_fig: bool = False,
1299+
progress_callback: Callable[[int, int], None] | None = None,
10391300
**hist_settings,
10401301
):
10411302
"""Plot marginalised posteriors for several parameters from a Bayesian analysis.
@@ -1072,6 +1333,9 @@ def plot_hists(
10721333
The figure object to use for plot.
10731334
return_fig: bool, default False
10741335
If True, return the figure as an object instead of showing it.
1336+
progress_callback: Union[Callable[[int, int], None], None]
1337+
Callback function for providing progress during plot creation
1338+
First argument is current completed sub plot and second is total number of sub plots
10751339
hist_settings :
10761340
Settings passed to `np.histogram`. By default, the settings
10771341
passed are `bins = 25` and `density = True`.
@@ -1130,6 +1394,7 @@ def validate_dens_type(dens_type: str | None, param: str):
11301394
),
11311395
params,
11321396
fig,
1397+
progress_callback,
11331398
)
11341399
if return_fig:
11351400
return fig
@@ -1144,6 +1409,7 @@ def plot_chain(
11441409
block: bool = False,
11451410
fig: matplotlib.figure.Figure | None = None,
11461411
return_fig: bool = False,
1412+
progress_callback: Callable[[int, int], None] | None = None,
11471413
):
11481414
"""Plot the MCMC chain for each parameter of a Bayesian analysis.
11491415
@@ -1162,6 +1428,9 @@ def plot_chain(
11621428
The figure object to use for plot.
11631429
return_fig: bool, default False
11641430
If True, return the figure as an object instead of showing it.
1431+
progress_callback: Union[Callable[[int, int], None], None]
1432+
Callback function for providing progress during plot creation
1433+
First argument is current completed sub plot and second is total number of sub plots
11651434
11661435
Returns
11671436
-------
@@ -1187,7 +1456,7 @@ def plot_one_chain(axes: Axes, i: int):
11871456
axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip])
11881457
axes.set_title(results.fitNames[i], fontsize="small")
11891458

1190-
fig = panel_plot_helper(plot_one_chain, params, fig=fig)
1459+
fig = panel_plot_helper(plot_one_chain, params, fig, progress_callback)
11911460
if return_fig:
11921461
return fig
11931462
plt.show(block=block)

0 commit comments

Comments
 (0)