Skip to content

Commit 95aee8b

Browse files
committed
fixed scalefactor in 3d tests, added DREAM, and made histograms prettier
1 parent 5e862c0 commit 95aee8b

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

RATapi/examples/bayes_benchmark/bayes_benchmark.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,20 @@ def bayes_benchmark_2d(grid_size: int) -> (RAT.outputs.BayesResults, Calculation
5959
"""
6060
problem = RAT.utils.convert.r1_to_project_class(str(PWD / "defaultR1ProjectTemplate.mat"))
6161

62-
ns_controls = RAT.Controls(procedure="ns", calcSldDuringFit=True, nsTolerance=1, nLive=500, display="final")
63-
62+
ns_controls = RAT.Controls(procedure="ns", nsTolerance=1, nLive=500, display="final")
6463
_, ns_results = RAT.run(problem, ns_controls)
6564

65+
dream_controls = RAT.Controls(procedure="dream", display="final")
66+
_, dream_results = RAT.run(problem, dream_controls)
67+
6668
# now we get the parameters and use them to do a direct calculation
6769
rough_param = problem.parameters[0]
6870
roughness = np.linspace(rough_param.min, rough_param.max, grid_size)
6971

7072
back_param = problem.background_parameters[0]
7173
background = np.linspace(back_param.min, back_param.max, grid_size)
7274

73-
controls = RAT.Controls(procedure="calculate", calcSldDuringFit=True, display="off")
75+
controls = RAT.Controls(procedure="calculate", display="off")
7476

7577
def calculate_posterior(roughness_index: int, background_index: int) -> float:
7678
"""Calculate the posterior for an item in the roughness and background vectors.
@@ -100,7 +102,7 @@ def calculate_posterior(roughness_index: int, background_index: int) -> float:
100102
print("Calculating posterior directly...")
101103
probability_array = vectorized_calc_posterior(*np.indices((grid_size, grid_size), dtype=int))
102104

103-
return ns_results, CalculationResults(x_data=[roughness, background], distribution=probability_array)
105+
return ns_results, dream_results, CalculationResults(x_data=[roughness, background], distribution=probability_array)
104106

105107

106108
def bayes_benchmark_3d(grid_size: int) -> (RAT.outputs.BayesResults, CalculationResults):
@@ -124,11 +126,15 @@ def bayes_benchmark_3d(grid_size: int) -> (RAT.outputs.BayesResults, Calculation
124126
125127
"""
126128
problem = RAT.utils.convert.r1_to_project_class(str(PWD / "defaultR1ProjectTemplate.mat"))
129+
problem.scalefactors[0].min = 0.07
130+
problem.scalefactors[0].max = 0.13
127131

128-
ns_controls = RAT.Controls(procedure="ns", calcSldDuringFit=True, nsTolerance=1, nLive=500, display="final")
129-
132+
ns_controls = RAT.Controls(procedure="ns", nsTolerance=1, nLive=500, display="final")
130133
_, ns_results = RAT.run(problem, ns_controls)
131134

135+
dream_controls = RAT.Controls(procedure="dream", display="final")
136+
_, dream_results = RAT.run(problem, dream_controls)
137+
132138
# now we get the parameters and use them to do a direct calculation
133139
rough_param = problem.parameters[0]
134140
roughness = np.linspace(rough_param.min, rough_param.max, grid_size)
@@ -172,21 +178,29 @@ def calculate_posterior(roughness_index: int, background_index: int, scalefactor
172178
print("Calculating posterior directly...")
173179
probability_array = vectorized_calc_posterior(*np.indices((grid_size, grid_size, grid_size), dtype=int))
174180

175-
return ns_results, CalculationResults(x_data=[roughness, background, scalefactor], distribution=probability_array)
181+
return (
182+
ns_results,
183+
dream_results,
184+
CalculationResults(x_data=[roughness, background, scalefactor], distribution=probability_array),
185+
)
176186

177187

178-
def plot_posterior_comparison(ns_results: RAT.outputs.BayesResults, calc_results: CalculationResults):
188+
def plot_posterior_comparison(
189+
ns_results: RAT.outputs.BayesResults, dream_results: RAT.outputs.BayesResults, calc_results: CalculationResults
190+
):
179191
"""Create a grid of marginalised posteriors comparing different calculation methods.
180192
181193
Parameters
182194
----------
183195
ns_results : RAT.BayesResults
184196
The BayesResults object from a nested sampler calculation.
197+
dream_results : RAT.BayesResults
198+
The BayesResults object from a DREAM calculation.
185199
calc_results : CalculationResults
186200
The results from a direct calculation.
187201
"""
188202
num_params = calc_results.distribution.ndim
189-
fig, axes = plt.subplots(2, num_params)
203+
fig, axes = plt.subplots(3, num_params, figsize=(3 * num_params, 9))
190204

191205
def plot_marginalised_result(dimension: int, axes: plt.Axes, limits: tuple[float]):
192206
"""Plot a histogram of a marginalised posterior from the calculation results.
@@ -221,20 +235,21 @@ def plot_marginalised_result(dimension: int, axes: plt.Axes, limits: tuple[float
221235
# row 0 contains NS histograms for each parameter
222236
# row 1 contains direct calculation histograms for each parameter
223237
for i in range(0, num_params):
224-
RATplot.plot_one_hist(ns_results, i, smooth=False, axes=axes[0][i])
225-
plot_marginalised_result(i, axes[1][i], limits=axes[0][i].get_xlim())
238+
RATplot.plot_one_hist(ns_results, i, axes=axes[0][i])
239+
RATplot.plot_one_hist(dream_results, i, axes=axes[1][i])
240+
plot_marginalised_result(i, axes[2][i], limits=axes[0][i].get_xlim())
226241

227242
axes[0][0].set_ylabel("nested sampler")
228-
axes[1][0].set_ylabel("direct calculation")
243+
axes[1][0].set_ylabel("DREAM")
244+
axes[2][0].set_ylabel("direct calculation")
229245

230246
fig.tight_layout()
231247
fig.show()
232248

233249

234250
if __name__ == "__main__":
235-
ns_2d, calc_2d = bayes_benchmark_2d(30)
236-
ns_3d, calc_3d = bayes_benchmark_3d(40)
237-
238-
plot_posterior_comparison(ns_2d, calc_2d)
251+
ns_2d, dream_2d, calc_2d = bayes_benchmark_2d(30)
252+
ns_3d, dream_3d, calc_3d = bayes_benchmark_3d(40)
239253

240-
plot_posterior_comparison(ns_3d, calc_3d)
254+
plot_posterior_comparison(ns_2d, dream_2d, calc_2d)
255+
plot_posterior_comparison(ns_3d, dream_3d, calc_3d)

0 commit comments

Comments
 (0)