@@ -652,6 +652,7 @@ def plot_corner(
652652 return_fig : bool = False ,
653653 hist_kwargs : Union [dict , None ] = None ,
654654 hist2d_kwargs : Union [dict , None ] = None ,
655+ progress_callback : Union [Callable [[int , int ], None ], None ] = None ,
655656):
656657 """Create a corner plot from a Bayesian analysis.
657658
@@ -697,29 +698,32 @@ def plot_corner(
697698 hist2d_kwargs = {}
698699
699700 num_params = len (params )
701+ total_count = num_params + (num_params ** 2 - num_params ) // 2
700702
701703 if fig is None :
702- fig , axes = plt .subplots (num_params , num_params , figsize = (11 , 10 ))
704+ fig , axes = plt .subplots (num_params , num_params , figsize = (11 , 10 ), subplot_kw = { "visible" : False } )
703705 else :
704706 fig .clf ()
705- axes = fig .subplots (num_params , num_params )
707+ axes = fig .subplots (num_params , num_params , subplot_kw = { "visible" : False } )
706708
707709 # i is row, j is column
708- for i , row_param in enumerate (params ):
709- for j , col_param in enumerate (params ):
710- current_axes : Axes = axes [i ][j ]
710+ current_count = 0
711+ for i in range (num_params ):
712+ for j in range (i + 1 ):
713+ row_param = params [i ]
714+ col_param = params [j ]
715+ current_axes : Axes = axes if isinstance (axes , matplotlib .axes .Axes ) else axes [i ][j ]
711716 current_axes .tick_params (which = "both" , labelsize = "medium" )
712717 current_axes .xaxis .offsetText .set_fontsize ("small" )
713718 current_axes .yaxis .offsetText .set_fontsize ("small" )
714-
719+ current_axes . set_visible ( True )
715720 if i == j : # diagonal: histograms
716721 plot_one_hist (results , param = row_param , smooth = smooth , axes = current_axes , ** hist_kwargs )
717722 elif i > j : # lower triangle: 2d histograms
718723 plot_contour (
719724 results , x_param = col_param , y_param = row_param , smooth = smooth , axes = current_axes , ** hist2d_kwargs
720725 )
721- elif i < j : # upper triangle: no plot
722- current_axes .set_visible (False )
726+
723727 # remove label if on inside of corner plot
724728 if j != 0 :
725729 current_axes .get_yaxis ().set_visible (False )
@@ -732,6 +736,9 @@ def plot_corner(
732736 current_axes .yaxis .offset_text_position = "center"
733737 current_axes .set_ylabel ("" )
734738 current_axes .set_xlabel ("" )
739+ if progress_callback is not None :
740+ current_count += 1
741+ progress_callback (current_count , total_count )
735742 if return_fig :
736743 return fig
737744 plt .show (block = block )
@@ -1153,7 +1160,7 @@ def plot_chain(
11531160
11541161 """
11551162 chain = results .chain
1156- nsimulations , nplots = chain .shape
1163+ nsimulations , _ = chain .shape
11571164 # skip is to evenly distribute points plotted
11581165 # all points will be plotted if maxpoints < nsimulations
11591166 skip = max (floor (nsimulations / maxpoints ), 1 )
0 commit comments