Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1bd7a08
BUG: catch that empty pickle files have non-zero size
ColmTalbot Oct 24, 2025
8bee9e0
FEAT: improve user pool passing
ColmTalbot Oct 24, 2025
479c100
FEAT: improve reweighting parallelisation
ColmTalbot Oct 24, 2025
978c696
Merge remote-tracking branch 'origin/main' into user-pools
ColmTalbot Oct 24, 2025
dcd05dd
FEAT: add parameters as argument to new pool
ColmTalbot Oct 24, 2025
86b993c
BUG: test that pool exists at cleanup
ColmTalbot Oct 24, 2025
c51e15c
BUG: test pool exists at closing
ColmTalbot Oct 24, 2025
c4654a9
REFACTOR: refactor run_sampler to simplify pool logic
ColmTalbot Oct 24, 2025
ba1df87
DEP: discourage setting up pool in sampler
ColmTalbot Oct 24, 2025
2025cf0
REFACTOR: remove top level multiprocessing import
ColmTalbot Oct 24, 2025
10e4267
BUG: make sure prior is passed to pool creation
ColmTalbot Oct 24, 2025
079d1c9
BUG: fix test failures
ColmTalbot Oct 30, 2025
1ab2fa4
TEST: fix reproducibility test
ColmTalbot Oct 30, 2025
cffe25c
BUG: fix a typo in conversion function test
ColmTalbot Oct 30, 2025
463af87
MAINT: don't create pool of size 1
ColmTalbot Oct 30, 2025
975cbfc
BUG: only include chunksize in multiprocessing map
ColmTalbot Nov 3, 2025
e95381e
DOC: add docstrings for pool functions
ColmTalbot Jan 26, 2026
c914a22
DOC: update pool docstrings
ColmTalbot Jan 26, 2026
4f92ce1
Address review comments
ColmTalbot Jan 26, 2026
eba6eec
TYPO: Fix typo in parameter description comments
ColmTalbot Feb 18, 2026
29ecb3b
Merge branch 'main' into user-pools
ColmTalbot Feb 18, 2026
82bcf1e
BUG: move definition of chunk size in reweighting
ColmTalbot Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 55 additions & 23 deletions bilby/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import os
from collections import namedtuple
from copy import copy
from functools import partial
from importlib import import_module
from itertools import product
import multiprocessing
from functools import partial

import numpy as np
import pandas as pd
import scipy.stats
Expand Down Expand Up @@ -192,7 +192,7 @@ def read_in_result_list(filename_list, invalid="warning"):

def get_weights_for_reweighting(
result, new_likelihood=None, new_prior=None, old_likelihood=None,
old_prior=None, resume_file=None, n_checkpoint=5000, npool=1):
old_prior=None, resume_file=None, n_checkpoint=5000, npool=1, pool=None):
""" Calculate the weights for reweight()

See bilby.core.result.reweight() for help with the inputs
Expand Down Expand Up @@ -235,20 +235,27 @@ def get_weights_for_reweighting(
basedir = os.path.split(resume_file)[0]
check_directory_exists_and_if_not_mkdir(basedir)

dict_samples = [{key: sample[key] for key in result.posterior}
for _, sample in result.posterior.iterrows()]
dict_samples = result.posterior.to_dict(orient="records")
n = len(dict_samples) - starting_index

# Helper function to compute likelihoods in parallel
def eval_pool(this_logl):
with multiprocessing.Pool(processes=npool) as pool:
chunksize = max(100, n // (2 * npool))
return list(tqdm(
pool.imap(partial(_safe_likelihood_call, this_logl),
dict_samples[starting_index:], chunksize=chunksize),
from .utils.parallel import bilby_pool

with bilby_pool(likelihood=this_logl, npool=npool) as my_pool:
if my_pool is None:
map_fn = map
else:
chunksize = max(100, n // (2 * npool))
map_fn = partial(my_pool.imap, chunksize=chunksize)
likelihood_fn = partial(_safe_likelihood_call, this_logl)

log_l = list(tqdm(
map_fn(likelihood_fn, dict_samples[starting_index:]),
desc='Computing likelihoods',
total=n)
)
total=n,
))
return log_l

if old_likelihood is None:
old_log_likelihood_array[starting_index:] = \
Expand Down Expand Up @@ -319,7 +326,7 @@ def rejection_sample(posterior, weights):
def reweight(result, label=None, new_likelihood=None, new_prior=None,
old_likelihood=None, old_prior=None, conversion_function=None, npool=1,
verbose_output=False, resume_file=None, n_checkpoint=5000,
use_nested_samples=False):
use_nested_samples=False, pool=None):
""" Reweight a result to a new likelihood/prior using rejection sampling

Parameters
Expand Down Expand Up @@ -382,7 +389,9 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None,
get_weights_for_reweighting(
result, new_likelihood=new_likelihood, new_prior=new_prior,
old_likelihood=old_likelihood, old_prior=old_prior,
resume_file=resume_file, n_checkpoint=n_checkpoint, npool=npool)
resume_file=resume_file, n_checkpoint=n_checkpoint,
npool=npool, pool=pool,
)

if use_nested_samples:
ln_weights += np.log(result.posterior["weights"])
Expand All @@ -409,10 +418,14 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None,

if conversion_function is not None:
data_frame = result.posterior
if "npool" in inspect.signature(conversion_function).parameters:
data_frame = conversion_function(data_frame, new_likelihood, new_prior, npool=npool)
else:
data_frame = conversion_function(data_frame, new_likelihood, new_prior)
parameters = inspect.signature(conversion_function).parameters
kwargs = dict()
for key, value in [
("likelihood", new_likelihood), ("priors", new_prior), ("npool", npool), ("pool", pool)
]:
if key in parameters:
kwargs[key] = value
data_frame = conversion_function(data_frame, **kwargs)
result.posterior = data_frame

if label:
Expand Down Expand Up @@ -765,6 +778,21 @@ def log_10_evidence_err(self):
def log_10_noise_evidence(self):
return self.log_noise_evidence / np.log(10)

@property
def sampler_kwargs(self):
return self._sampler_kwargs

@sampler_kwargs.setter
def sampler_kwargs(self, sampler_kwargs):
if sampler_kwargs is None:
sampler_kwargs = dict()
else:
sampler_kwargs = copy(sampler_kwargs)
if "pool" in sampler_kwargs:
# pool objects can't be neatly serialized
sampler_kwargs["pool"] = None
self._sampler_kwargs = sampler_kwargs

@property
def version(self):
return self._version
Expand Down Expand Up @@ -1530,7 +1558,7 @@ def _add_prior_fixed_values_to_posterior(posterior, priors):
return posterior

def samples_to_posterior(self, likelihood=None, priors=None,
conversion_function=None, npool=1):
conversion_function=None, npool=1, pool=None):
"""
Convert array of samples to posterior (a Pandas data frame)

Expand Down Expand Up @@ -1560,10 +1588,14 @@ def samples_to_posterior(self, likelihood=None, priors=None,
data_frame['log_prior'] = self.log_prior_evaluations

if conversion_function is not None:
if "npool" in inspect.signature(conversion_function).parameters:
data_frame = conversion_function(data_frame, likelihood, priors, npool=npool)
else:
data_frame = conversion_function(data_frame, likelihood, priors)
parameters = inspect.signature(conversion_function).parameters
kwargs = dict()
for key, value in [
("likelihood", likelihood), ("priors", priors), ("npool", npool), ("pool", pool)
]:
if key in parameters:
kwargs[key] = value
data_frame = conversion_function(data_frame, **kwargs)
self.posterior = data_frame

def calculate_prior_values(self, priors):
Expand Down
169 changes: 108 additions & 61 deletions bilby/core/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
loaded_modules_dict,
logger,
)
from ..utils.parallel import bilby_pool
from . import proposal
from .base_sampler import Sampler, SamplingMarginalisedParameterError

Expand Down Expand Up @@ -158,6 +159,7 @@ def run_sampler(
gzip=False,
result_class=None,
npool=1,
pool=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -266,36 +268,27 @@ def run_sampler(

likelihood = ZeroLikelihood(likelihood)

common_kwargs = dict(
likelihood=likelihood,
priors=priors,
outdir=outdir,
label=label,
injection_parameters=injection_parameters,
meta_data=meta_data,
use_ratio=use_ratio,
plot=plot,
result_class=result_class,
npool=npool,
pool=pool,
)

if isinstance(sampler, Sampler):
pass
elif isinstance(sampler, str):
sampler_class = get_sampler_class(sampler)
sampler = sampler_class(
likelihood,
priors=priors,
outdir=outdir,
label=label,
injection_parameters=injection_parameters,
meta_data=meta_data,
use_ratio=use_ratio,
plot=plot,
result_class=result_class,
npool=npool,
**kwargs,
)
sampler = sampler_class(**common_kwargs, **kwargs)
elif inspect.isclass(sampler):
sampler = sampler.__init__(
likelihood,
priors=priors,
outdir=outdir,
label=label,
use_ratio=use_ratio,
plot=plot,
injection_parameters=injection_parameters,
meta_data=meta_data,
npool=npool,
**kwargs,
)
sampler = sampler.__init__(**common_kwargs, **kwargs)
else:
raise ValueError(
"Provided sampler should be a Sampler object or name of a known "
Expand All @@ -305,42 +298,81 @@ def run_sampler(
if sampler.cached_result:
logger.warning("Using cached result")
result = sampler.cached_result
result = apply_conversion_function(
result=result,
likelihood=likelihood,
conversion_function=conversion_function,
npool=npool,
pool=pool,
)
Comment on lines +301 to +307
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this always reapply the conversion function? If so, I'm unsure if this is desirable, since rather than bilby quickly exiting when a run is already done, it will spend time doing the conversion. That said, I don't feel strongly about this, so happy to keep it. Thoughts?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will reapply it; this maintains the current behaviour. I'd be open to changing that behaviour, maybe by adding a flag to the result file to say whether the conversion has been applied, but I would say to do that as a separate change.

else:
# Run the sampler
start_time = datetime.datetime.now()
if command_line_args.bilby_test_mode:
result = sampler._run_test()
else:
result = sampler.run_sampler()
end_time = datetime.datetime.now()

# Some samplers calculate the sampling time internally
if result.sampling_time is None:
result.sampling_time = end_time - start_time
elif isinstance(result.sampling_time, (float, int)):
result.sampling_time = datetime.timedelta(result.sampling_time)

logger.info(f"Sampling time: {result.sampling_time}")
# Convert sampling time into seconds
result.sampling_time = result.sampling_time.total_seconds()

if sampler.use_ratio:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = result.log_evidence
result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
else:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = result.log_evidence - result.log_noise_evidence
with bilby_pool(
likelihood,
priors,
use_ratio=sampler.use_ratio,
search_parameter_keys=sampler.search_parameter_keys,
npool=npool,
pool=pool,
parameters=priors.sample(),
) as _pool:
start_time = datetime.datetime.now()
sampler.pool = _pool
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe to do? Depending on how the sampler uses the pool, are there settings where the pool the sampler has stored is not updated?

For example, if the sampler has constructed a likelihood using the pool.map from the initial input pool, will this break?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to work through what this would look like.

If the initial input pool is not None, all of the pool objects referenced here should be the same, so I don't think that will make a difference.

One potential issue is that if a specific sampler implementation handles the pool internally by itself this will create two pools and then in the best case, we have a bunch of extra processes we don't need. I think this is what nessai does, so maybe we should game through that specific case.

if command_line_args.bilby_test_mode:
result = sampler._run_test()
else:
result = sampler.run_sampler()
end_time = datetime.datetime.now()
result = finalize_result(
result=result,
likelihood=likelihood,
use_ratio=sampler.use_ratio,
start_time=start_time,
end_time=end_time,
)

if None not in [result.injection_parameters, conversion_function]:
result.injection_parameters = conversion_function(
result.injection_parameters
# Initial save of the sampler in case of failure in samples_to_posterior
if save:
result.save_to_file(extension=save, gzip=gzip, outdir=outdir)

result = apply_conversion_function(
result=result,
likelihood=likelihood,
conversion_function=conversion_function,
npool=npool,
pool=_pool,
)

# Initial save of the sampler in case of failure in samples_to_posterior
if save:
result.save_to_file(extension=save, gzip=gzip, outdir=outdir)
if save:
# The overwrite here ensures we overwrite the initially stored data
result.save_to_file(overwrite=True, extension=save, gzip=gzip, outdir=outdir)

if plot:
result.plot_corner()
logger.info(f"Summary of results:\n{result}")
return result


def apply_conversion_function(
result, likelihood, conversion_function, npool=None, pool=None
):
"""
Apply the conversion function to the injected parameters and posterior if the
posterior has not already been created from the stored samples.

Parameters
----------
result : bilby.core.result.Result
The result object from the sampler.
likelihood : bilby.Likelihood
The likelihood used during sampling.
conversion_function : function
The conversion function to apply.
npool : int, optional
The number of processes to use in a processing pool.
pool : multiprocessing.Pool, schwimmbad.MPIPool, optional
The pool to use for parallelisation, this overrides the :code:`npool` argument.
"""
if None not in [result.injection_parameters, conversion_function]:
result.injection_parameters = conversion_function(
result.injection_parameters,
Expand All @@ -354,15 +386,30 @@ def run_sampler(
priors=result.priors,
conversion_function=conversion_function,
npool=npool,
pool=pool,
)
return result

if save:
# The overwrite here ensures we overwrite the initially stored data
result.save_to_file(overwrite=True, extension=save, gzip=gzip, outdir=outdir)

if plot:
result.plot_corner()
logger.info(f"Summary of results:\n{result}")
def finalize_result(result, likelihood, use_ratio, start_time=None, end_time=None):
# Some samplers calculate the sampling time internally
if result.sampling_time is None and None not in [start_time, end_time]:
result.sampling_time = end_time - start_time
elif isinstance(result.sampling_time, (float, int)):
result.sampling_time = datetime.timedelta(result.sampling_time)

logger.info(f"Sampling time: {result.sampling_time}")
# Convert sampling time into seconds
result.sampling_time = result.sampling_time.total_seconds()

if use_ratio:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = result.log_evidence
result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
else:
result.log_noise_evidence = likelihood.noise_log_likelihood()
result.log_bayes_factor = result.log_evidence - result.log_noise_evidence

return result


Expand Down
Loading
Loading