Skip to content

Commit eb88023

Browse files
committed
Adds progress callback to corner plot and updates example to use scipy
1 parent 53c2618 commit eb88023

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

ratapi/examples/normal_reflectivity/custom_XY_DSPC.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""A custom XY model for a supported DSPC bilayer."""
22

3-
import math
4-
53
import numpy as np
4+
from scipy.special import erf
65

76

87
def custom_XY_DSPC(params, bulk_in, bulk_out, contrast):
@@ -114,7 +113,7 @@ def custom_XY_DSPC(params, bulk_in, bulk_out, contrast):
114113
totSLD = sldSilicon + sldOxide + sldHeadL + sldTails + sldHeadR + sldWat
115114

116115
# Make the SLD array for output
117-
SLD = [[a, b] for (a, b) in zip(z, totSLD)]
116+
SLD = np.column_stack((z, totSLD))
118117

119118
return SLD, subRough
120119

@@ -132,8 +131,8 @@ def layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R):
132131
a = (z - left) / ((2**0.5) * Sigma_L)
133132
b = (z - right) / ((2**0.5) * Sigma_R)
134133

135-
erf_a = np.array([math.erf(value) for value in a])
136-
erf_b = np.array([math.erf(value) for value in b])
134+
erf_a = erf(a)
135+
erf_b = erf(b)
137136

138137
VF = np.array((height / 2) * (erf_a - erf_b))
139138

ratapi/utils/plotting.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ def plot_corner(
652652
return_fig: bool = False,
653653
hist_kwargs: Union[dict, None] = None,
654654
hist2d_kwargs: Union[dict, None] = None,
655+
progress_callback: Union[Callable[[int, int], None], None] = None,
655656
):
656657
"""Create a corner plot from a Bayesian analysis.
657658
@@ -697,29 +698,32 @@ def plot_corner(
697698
hist2d_kwargs = {}
698699

699700
num_params = len(params)
701+
total_count = num_params + (num_params**2 - num_params) // 2
700702

701703
if fig is None:
702-
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10))
704+
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10), subplot_kw={"visible": False})
703705
else:
704706
fig.clf()
705-
axes = fig.subplots(num_params, num_params)
707+
axes = fig.subplots(num_params, num_params, subplot_kw={"visible": False})
706708

707709
# i is row, j is column
708-
for i, row_param in enumerate(params):
709-
for j, col_param in enumerate(params):
710-
current_axes: Axes = axes[i][j]
710+
current_count = 0
711+
for i in range(num_params):
712+
for j in range(i + 1):
713+
row_param = params[i]
714+
col_param = params[j]
715+
current_axes: Axes = axes if isinstance(axes, matplotlib.axes.Axes) else axes[i][j]
711716
current_axes.tick_params(which="both", labelsize="medium")
712717
current_axes.xaxis.offsetText.set_fontsize("small")
713718
current_axes.yaxis.offsetText.set_fontsize("small")
714-
719+
current_axes.set_visible(True)
715720
if i == j: # diagonal: histograms
716721
plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs)
717722
elif i > j: # lower triangle: 2d histograms
718723
plot_contour(
719724
results, x_param=col_param, y_param=row_param, smooth=smooth, axes=current_axes, **hist2d_kwargs
720725
)
721-
elif i < j: # upper triangle: no plot
722-
current_axes.set_visible(False)
726+
723727
# remove label if on inside of corner plot
724728
if j != 0:
725729
current_axes.get_yaxis().set_visible(False)
@@ -732,6 +736,9 @@ def plot_corner(
732736
current_axes.yaxis.offset_text_position = "center"
733737
current_axes.set_ylabel("")
734738
current_axes.set_xlabel("")
739+
if progress_callback is not None:
740+
current_count += 1
741+
progress_callback(current_count, total_count)
735742
if return_fig:
736743
return fig
737744
plt.show(block=block)
@@ -1153,7 +1160,7 @@ def plot_chain(
11531160
11541161
"""
11551162
chain = results.chain
1156-
nsimulations, nplots = chain.shape
1163+
nsimulations, _ = chain.shape
11571164
# skip is to evenly distribute points plotted
11581165
# all points will be plotted if maxpoints < nsimulations
11591166
skip = max(floor(nsimulations / maxpoints), 1)

0 commit comments

Comments
 (0)