@@ -906,6 +906,262 @@ def _y_update_offset_text_position(axis, _bboxes, bboxes2):
906906 axis .offsetText .set_position ((x - x_offset , y ))
907907
908908
909+ @assert_bayesian ("Corner" )
910+ def plot_corner (
911+ results : ratapi .outputs .BayesResults ,
912+ params : list [int | str ] | None = None ,
913+ smooth : bool = True ,
914+ block : bool = False ,
915+ fig : matplotlib .figure .Figure | None = None ,
916+ return_fig : bool = False ,
917+ hist_kwargs : dict | None = None ,
918+ hist2d_kwargs : dict | None = None ,
919+ progress_callback : Callable [[int , int ], None ] | None = None ,
920+ ):
921+ """Create a corner plot from a Bayesian analysis.
922+
923+ Parameters
924+ ----------
925+ results : BayesResults
926+ The results from a Bayesian calculation.
927+ params : list[int or str], default None
928+ The indices or names of a subset of parameters if required.
929+ If None, uses all indices.
930+ smooth : bool, default True
931+ Whether to apply Gaussian smoothing to the corner plot.
932+ block : bool, default False
933+ Whether Python should block until the plot is closed.
934+ fig : matplotlib.figure.Figure, optional
935+ The figure object to use for plot.
936+ return_fig: bool, default False
937+ If True, return the figure as an object instead of showing it.
938+ hist_kwargs : dict
939+ Extra keyword arguments to pass to the 1d histograms.
940+ Default is {'density': True, 'bins': 25}
941+ hist2d_kwargs : dict
942+ Extra keyword arguments to pass to the 2d histograms.
943+ Default is {'density': True, 'bins': 25}
944+ progress_callback: Union[Callable[[int, int], None], None]
945+ Callback function for providing progress during plot creation
946+ First argument is current completed sub plot and second is total number of sub plots
947+
948+ Returns
949+ -------
950+ Figure or None
951+ If `return_fig` is True, return the figure - otherwise, return nothing.
952+
953+ """
954+ fitname_to_index = partial (name_to_index , names = results .fitNames )
955+
956+ if params is None :
957+ params = range (0 , len (results .fitNames ))
958+ else :
959+ params = list (map (fitname_to_index , params ))
960+
961+ # defaults are applied inside each function - just pass blank dicts for now
962+ if hist_kwargs is None :
963+ hist_kwargs = {}
964+ if hist2d_kwargs is None :
965+ hist2d_kwargs = {}
966+
967+ num_params = len (params )
968+ total_count = num_params + (num_params ** 2 - num_params ) // 2
969+
970+ if fig is None :
971+ fig , axes = plt .subplots (num_params , num_params , figsize = (11 , 10 ), subplot_kw = {"visible" : False })
972+ else :
973+ fig .clf ()
974+ axes = fig .subplots (num_params , num_params , subplot_kw = {"visible" : False })
975+
976+ # i is row, j is column
977+ current_count = 0
978+ for i in range (num_params ):
979+ for j in range (i + 1 ):
980+ row_param = params [i ]
981+ col_param = params [j ]
982+ current_axes : Axes = axes if isinstance (axes , matplotlib .axes .Axes ) else axes [i ][j ]
983+ current_axes .tick_params (which = "both" , labelsize = "medium" )
984+ current_axes .xaxis .offsetText .set_fontsize ("small" )
985+ current_axes .yaxis .offsetText .set_fontsize ("small" )
986+ current_axes .set_visible (True )
987+ if i == j : # diagonal: histograms
988+ plot_one_hist (results , param = row_param , smooth = smooth , axes = current_axes , ** hist_kwargs )
989+ elif i > j : # lower triangle: 2d histograms
990+ plot_contour (
991+ results , x_param = col_param , y_param = row_param , smooth = smooth , axes = current_axes , ** hist2d_kwargs
992+ )
993+
994+ # remove label if on inside of corner plot
995+ if j != 0 :
996+ current_axes .get_yaxis ().set_visible (False )
997+ if i != len (params ) - 1 :
998+ current_axes .get_xaxis ().set_visible (False )
999+ # make labels invisible as titles cover that
1000+ current_axes .yaxis ._update_offset_text_position = types .MethodType (
1001+ _y_update_offset_text_position , current_axes .yaxis
1002+ )
1003+ current_axes .yaxis .offset_text_position = "center"
1004+ current_axes .set_ylabel ("" )
1005+ current_axes .set_xlabel ("" )
1006+ if progress_callback is not None :
1007+ current_count += 1
1008+ progress_callback (current_count , total_count )
1009+ if return_fig :
1010+ return fig
1011+ plt .show (block = block )
1012+
1013+
1014+ @assert_bayesian ("Histogram" )
1015+ def plot_one_hist (
1016+ results : ratapi .outputs .BayesResults ,
1017+ param : int | str ,
1018+ smooth : bool = True ,
1019+ sigma : float | None = None ,
1020+ estimated_density : Literal ["normal" , "lognor" , "kernel" , None ] = None ,
1021+ axes : Axes | None = None ,
1022+ block : bool = False ,
1023+ return_fig : bool = False ,
1024+ ** hist_settings ,
1025+ ):
1026+ """Plot the marginalised posterior for a parameter of a Bayesian analysis.
1027+
1028+ Parameters
1029+ ----------
1030+ results : BayesResults
1031+ The results from a Bayesian calculation.
1032+ param : Union[int, str]
1033+ Either the index or name of a parameter.
1034+ block : bool, default False
1035+ Whether Python should block until the plot is closed.
1036+ smooth : bool, default True
1037+ Whether to apply Gaussian smoothing to the histogram.
1038+ Defaults to True.
1039+ sigma: float or None, default None
1040+ If given, is used as the sigma-parameter for the Gaussian smoothing.
1041+ If None, the default (1/3rd of parameter chain standard deviation) is used.
1042+ estimated_density : 'normal', 'lognor', 'kernel' or None, default None
1043+ If None (default), ignore. Else, add an estimated density
1044+ of the given form on top of the histogram by the following estimations:
1045+ 'normal': normal Gaussian.
1046+ 'lognor': Log-normal probability density.
1047+ 'kernel': kernel density estimation.
1048+ axes: Axes or None, default None
1049+ If provided, plot on the given Axes object.
1050+ block : bool, default False
1051+ Whether Python should block until the plot is closed.
1052+ return_fig: bool, default False
1053+ If True, return the figure as an object instead of showing it.
1054+ **hist_settings :
1055+ Settings passed to `np.histogram`. By default, the settings
1056+ passed are `bins = 25` and `density = True`.
1057+
1058+ Returns
1059+ -------
1060+ Figure or None
1061+ If `return_fig` is True, return the figure - otherwise, return nothing.
1062+
1063+ """
1064+ chain = results .chain
1065+ param = name_to_index (param , results .fitNames )
1066+
1067+ if axes is None :
1068+ fig , axes = plt .subplots (1 , 1 )
1069+ else :
1070+ fig = None
1071+
1072+ # apply default settings if not set by user
1073+ default_settings = {"bins" : 25 , "density" : True }
1074+ hist_settings = {** default_settings , ** hist_settings }
1075+
1076+ parameter_chain = chain [:, param ]
1077+ counts , bins = np .histogram (parameter_chain , ** hist_settings )
1078+ mean_y = np .mean (parameter_chain )
1079+ sd_y = np .std (parameter_chain )
1080+
1081+ if smooth :
1082+ if sigma is None :
1083+ sigma = sd_y / 2
1084+ counts = gaussian_filter1d (counts , sigma )
1085+ axes .hist (
1086+ bins [:- 1 ],
1087+ bins ,
1088+ weights = counts ,
1089+ edgecolor = "black" ,
1090+ linewidth = 1.2 ,
1091+ color = "white" ,
1092+ )
1093+
1094+ axes .set_title (results .fitNames [param ], loc = "left" , fontsize = "medium" )
1095+
1096+ if estimated_density :
1097+ dx = bins [1 ] - bins [0 ]
1098+ if estimated_density == "normal" :
1099+ t = np .linspace (mean_y - 3.5 * sd_y , mean_y + 3.5 * sd_y )
1100+ axes .plot (t , norm .pdf (t , loc = mean_y , scale = sd_y ** 2 ))
1101+ elif estimated_density == "lognor" :
1102+ t = np .linspace (bins [0 ] - 0.5 * dx , bins [- 1 ] + 2 * dx )
1103+ axes .plot (t , lognorm .pdf (t , np .mean (np .log (parameter_chain )), np .std (np .log (parameter_chain ))))
1104+ elif estimated_density == "kernel" :
1105+ t = np .linspace (bins [0 ] - 2 * dx , bins [- 1 ] + 2 * dx , 200 )
1106+ kde = gaussian_kde (parameter_chain )
1107+ axes .plot (t , kde .evaluate (t ))
1108+ else :
1109+ raise ValueError (
1110+ f"{ estimated_density } is not a supported estimated density function."
1111+ " Supported functions are 'normal' 'lognor' or 'kernel'."
1112+ )
1113+
1114+ # adding the estimated density extends the figure range - reset it to histogram range
1115+ x_range = hist_settings .get ("range" , (parameter_chain .min (), parameter_chain .max ()))
1116+ axes .set_xlim (x_range )
1117+
1118+ if fig is not None :
1119+ if return_fig :
1120+ return fig
1121+ plt .show (block = block )
1122+
1123+
1124+ def _y_update_offset_text_position (axis , _bboxes , bboxes2 ):
1125+ """Update the position of the Y axis offset text using the provided bounding boxes.
1126+
1127+ Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334.
1128+
1129+ Parameters
1130+ ----------
1131+ axis : matplotlib.axis.YAxis
1132+ Y axis to update.
1133+ _bboxes : List
1134+ list of bounding boxes
1135+ bboxes2 : List
1136+ list of bounding boxes
1137+ """
1138+ x , y = axis .offsetText .get_position ()
1139+
1140+ if axis .offset_text_position == "left" :
1141+ # y in axes coords, x in display coords
1142+ axis .offsetText .set_transform (
1143+ mtransforms .blended_transform_factory (axis .axes .transAxes , mtransforms .IdentityTransform ())
1144+ )
1145+
1146+ top = axis .axes .bbox .ymax
1147+ y = top + axis .OFFSETTEXTPAD * axis .figure .dpi / 72.0
1148+
1149+ else :
1150+ # x & y in display coords
1151+ axis .offsetText .set_transform (mtransforms .IdentityTransform ())
1152+
1153+ # Northwest of upper-right corner of right-hand extent of tick labels
1154+ if bboxes2 :
1155+ bbox = mtransforms .Bbox .union (bboxes2 )
1156+ else :
1157+ bbox = axis .axes .bbox
1158+ center = bbox .ymin + (bbox .ymax - bbox .ymin ) / 2
1159+ x = bbox .xmin - axis .OFFSETTEXTPAD * axis .figure .dpi / 72.0
1160+ y = center
1161+ x_offset = 110
1162+ axis .offsetText .set_position ((x - x_offset , y ))
1163+
1164+
9091165@assert_bayesian ("Contour" )
9101166def plot_contour (
9111167 results : ratapi .outputs .BayesResults ,
@@ -982,7 +1238,10 @@ def plot_contour(
9821238
9831239
9841240def panel_plot_helper (
985- plot_func : Callable , indices : list [int ], fig : matplotlib .figure .Figure | None = None
1241+ plot_func : Callable ,
1242+ indices : list [int ],
1243+ fig : matplotlib .figure .Figure | None = None ,
1244+ progress_callback : Callable [[int , int ], None ] | None = None ,
9861245) -> matplotlib .figure .Figure :
9871246 """Generate a panel-based plot from a single plot function.
9881247
@@ -994,6 +1253,9 @@ def panel_plot_helper(
9941253 The list of indices to pass into ``plot_func``.
9951254 fig : matplotlib.figure.Figure, optional
9961255 The figure object to use for plot.
1256+ progress_callback: Union[Callable[[int, int], None], None]
1257+ Callback function for providing progress during plot creation
1258+ First argument is current completed sub plot and second is total number of sub plots
9971259
9981260 Returns
9991261 -------
@@ -1005,21 +1267,19 @@ def panel_plot_helper(
10051267 nrows , ncols = ceil (sqrt (nplots )), round (sqrt (nplots ))
10061268
10071269 if fig is None :
1008- fig = plt .subplots (nrows , ncols , figsize = (11 , 10 ))[0 ]
1270+ fig = plt .subplots (nrows , ncols , figsize = (11 , 10 ), subplot_kw = { "visible" : False } )[0 ]
10091271 else :
10101272 fig .clf ()
1011- fig .subplots (nrows , ncols )
1273+ fig .subplots (nrows , ncols , subplot_kw = { "visible" : False } )
10121274 axs = fig .get_axes ()
1013-
1014- for plot_num , index in enumerate (indices ):
1015- axs [plot_num ].tick_params (which = "both" , labelsize = "medium" )
1016- axs [plot_num ].xaxis .offsetText .set_fontsize ("small" )
1017- axs [plot_num ].yaxis .offsetText .set_fontsize ("small" )
1018- plot_func (axs [plot_num ], index )
1019-
1020- # blank unused plots
1021- for i in range (nplots , len (axs )):
1022- axs [i ].set_visible (False )
1275+ for index , plot_num in enumerate (indices ):
1276+ axs [index ].tick_params (which = "both" , labelsize = "medium" )
1277+ axs [index ].xaxis .offsetText .set_fontsize ("small" )
1278+ axs [index ].yaxis .offsetText .set_fontsize ("small" )
1279+ axs [index ].set_visible (True )
1280+ plot_func (axs [index ], plot_num )
1281+ if progress_callback is not None :
1282+ progress_callback (index , nplots )
10231283
10241284 fig .tight_layout ()
10251285 return fig
@@ -1036,6 +1296,7 @@ def plot_hists(
10361296 block : bool = False ,
10371297 fig : matplotlib .figure .Figure | None = None ,
10381298 return_fig : bool = False ,
1299+ progress_callback : Callable [[int , int ], None ] | None = None ,
10391300 ** hist_settings ,
10401301):
10411302 """Plot marginalised posteriors for several parameters from a Bayesian analysis.
@@ -1072,6 +1333,9 @@ def plot_hists(
10721333 The figure object to use for plot.
10731334 return_fig: bool, default False
10741335 If True, return the figure as an object instead of showing it.
1336+ progress_callback: Union[Callable[[int, int], None], None]
1337+ Callback function for providing progress during plot creation
1338+ First argument is current completed sub plot and second is total number of sub plots
10751339 hist_settings :
10761340 Settings passed to `np.histogram`. By default, the settings
10771341 passed are `bins = 25` and `density = True`.
@@ -1130,6 +1394,7 @@ def validate_dens_type(dens_type: str | None, param: str):
11301394 ),
11311395 params ,
11321396 fig ,
1397+ progress_callback ,
11331398 )
11341399 if return_fig :
11351400 return fig
@@ -1144,6 +1409,7 @@ def plot_chain(
11441409 block : bool = False ,
11451410 fig : matplotlib .figure .Figure | None = None ,
11461411 return_fig : bool = False ,
1412+ progress_callback : Callable [[int , int ], None ] | None = None ,
11471413):
11481414 """Plot the MCMC chain for each parameter of a Bayesian analysis.
11491415
@@ -1162,6 +1428,9 @@ def plot_chain(
11621428 The figure object to use for plot.
11631429 return_fig: bool, default False
11641430 If True, return the figure as an object instead of showing it.
1431+ progress_callback: Union[Callable[[int, int], None], None]
1432+ Callback function for providing progress during plot creation
1433+ First argument is current completed sub plot and second is total number of sub plots
11651434
11661435 Returns
11671436 -------
@@ -1187,7 +1456,7 @@ def plot_one_chain(axes: Axes, i: int):
11871456 axes .plot (range (0 , nsimulations , skip ), chain [:, i ][0 :nsimulations :skip ])
11881457 axes .set_title (results .fitNames [i ], fontsize = "small" )
11891458
1190- fig = panel_plot_helper (plot_one_chain , params , fig = fig )
1459+ fig = panel_plot_helper (plot_one_chain , params , fig , progress_callback )
11911460 if return_fig :
11921461 return fig
11931462 plt .show (block = block )
0 commit comments