From 1bd7a08d12de6410a07ea5d4e909b9835a534820 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 15:18:53 -0400 Subject: [PATCH 01/20] BUG: catch that empty pickle files have non-zero size --- bilby/core/sampler/ptemcee.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 231730b98..34f2bf323 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -418,7 +418,7 @@ def setup_sampler(self): if ( os.path.isfile(self.resume_file) - and os.path.getsize(self.resume_file) + and os.path.getsize(self.resume_file) > 5 and self.resume is True ): import dill From 8bee9e0ddc68b08c48af15dd966d6d8bcd4b5ce2 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 15:19:45 -0400 Subject: [PATCH 02/20] FEAT: improve user pool passing --- bilby/core/sampler/__init__.py | 40 +++++-- bilby/core/sampler/base_sampler.py | 43 ++++---- bilby/core/utils/parallel.py | 72 +++++++++++++ bilby/gw/conversion.py | 162 ++++++++++------------------- 4 files changed, 182 insertions(+), 135 deletions(-) create mode 100644 bilby/core/utils/parallel.py diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 3b9238836..319fa3ba7 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -158,6 +158,7 @@ def run_sampler( gzip=False, result_class=None, npool=1, + pool=None, **kwargs, ): """ @@ -281,6 +282,7 @@ def run_sampler( plot=plot, result_class=result_class, npool=npool, + pool=pool, **kwargs, ) elif inspect.isclass(sampler): @@ -294,6 +296,7 @@ def run_sampler( injection_parameters=injection_parameters, meta_data=meta_data, npool=npool, + pool=pool, **kwargs, ) else: @@ -308,10 +311,20 @@ def run_sampler( else: # Run the sampler start_time = datetime.datetime.now() - if command_line_args.bilby_test_mode: - result = sampler._run_test() - else: - result = sampler.run_sampler() + from ..utils.parallel import bilby_pool 
+ with bilby_pool( + likelihood, + priors, + use_ratio=sampler.use_ratio, + search_parameter_keys=sampler.search_parameter_keys, + npool=npool, + pool=pool, + ) as _pool: + sampler.pool = _pool + if command_line_args.bilby_test_mode: + result = sampler._run_test() + else: + result = sampler.run_sampler() end_time = datetime.datetime.now() # Some samplers calculate the sampling time internally @@ -349,12 +362,21 @@ def run_sampler( # Check if the posterior has already been created if getattr(result, "_posterior", None) is None: - result.samples_to_posterior( - likelihood=likelihood, - priors=result.priors, - conversion_function=conversion_function, + with bilby_pool( + likelihood, + priors, + use_ratio=sampler.use_ratio, + search_parameter_keys=sampler.search_parameter_keys, npool=npool, - ) + pool=pool, + ) as _pool: + result.samples_to_posterior( + likelihood=likelihood, + priors=result.priors, + conversion_function=conversion_function, + npool=npool, + pool=_pool, + ) if save: # The overwrite here ensures we overwrite the initially stored data diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 788b13198..fc82527b9 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -229,6 +229,7 @@ def __init__( soft_init=False, exit_code=130, npool=1, + pool=None, **kwargs, ): self.likelihood = likelihood @@ -241,6 +242,7 @@ def __init__( self.injection_parameters = injection_parameters self.meta_data = meta_data self.use_ratio = use_ratio + self.pool = pool self._npool = npool if not skip_import_verification: self._verify_external_sampler() @@ -748,34 +750,33 @@ def write_current_state_and_exit(self, signum=None, frame=None): sys.exit(self.exit_code) def _close_pool(self): - if getattr(self, "pool", None) is not None: + if ( + getattr(self, "pool", None) is not None + and not getattr(self, "_user_pool", True) + ): + from ..utils.parallel import close_pool logger.info("Starting to close worker 
pool.") - self.pool.close() - self.pool.join() + close_pool(self.pool) self.pool = None self.kwargs["pool"] = self.pool logger.info("Finished closing worker pool.") def _setup_pool(self): - if self.kwargs.get("pool", None) is not None: - logger.info("Using user defined pool.") - self.pool = self.kwargs["pool"] - elif self.npool is not None and self.npool > 1: - logger.info(f"Setting up multiproccesing pool with {self.npool} processes") - import multiprocessing - - self.pool = multiprocessing.Pool( - processes=self.npool, - initializer=_initialize_global_variables, - initargs=( - self.likelihood, - self.priors, - self._search_parameter_keys, - self.use_ratio, - ), - ) + from ..utils.parallel import create_pool + + if hasattr(self.pool, "map"): + self._user_pool = True else: - self.pool = None + self._user_pool = False + + self.pool = create_pool( + likelihood=self.likelihood, + priors=self.priors, + search_parameter_keys=self._search_parameter_keys, + use_ratio=self.use_ratio, + npool=self.npool, + pool=self.pool, + ) _initialize_global_variables( likelihood=self.likelihood, priors=self.priors, diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py new file mode 100644 index 000000000..aff1a3fc9 --- /dev/null +++ b/bilby/core/utils/parallel.py @@ -0,0 +1,72 @@ +import multiprocessing +from contextlib import contextmanager + +from .log import logger + + +def create_pool(likelihood, priors, use_ratio=None, search_parameter_keys=None, npool=None, pool=None): + from ...core.sampler.base_sampler import _initialize_global_variables + + _pool = None + if pool == "mpi": + try: + from schwimmbad import MPIPool + except ImportError: + raise ImportError("schwimmbad must be installed to use MPI pool") + + _initialize_global_variables( + likelihood=likelihood, + priors=priors, + search_parameter_keys=search_parameter_keys, + use_ratio=use_ratio, + ) + _pool = MPIPool(use_dill=True) + if _pool.is_master(): + logger.info(f"Created MPI pool with size 
{_pool.size}") + elif pool is not None: + _pool = pool + elif npool is not None: + _pool = multiprocessing.Pool( + processes=npool, + initializer=_initialize_global_variables, + initargs=(likelihood, priors, search_parameter_keys, use_ratio), + ) + logger.info(f"Created multiprocessing pool with size {npool}") + else: + _pool = None + return _pool + + +def close_pool(pool): + if hasattr(pool, "close"): + pool.close() + if hasattr(pool, "join"): + pool.join() + + +@contextmanager +def bilby_pool( + likelihood, priors, + use_ratio=None, + search_parameter_keys=None, + npool=None, + pool=None, +): + if hasattr(pool, "map"): + user_pool = True + else: + user_pool = False + + try: + _pool = create_pool( + likelihood=likelihood, + priors=priors, + search_parameter_keys=search_parameter_keys, + use_ratio=use_ratio, + npool=npool, + pool=pool, + ) + yield _pool + finally: + if not user_pool: + close_pool(_pool) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 1a309c105..f69cb7038 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -5,7 +5,6 @@ import os import sys -import multiprocessing import pickle from copy import deepcopy @@ -29,6 +28,7 @@ from ..core.likelihood import MarginalizedLikelihoodReconstructionError from ..core.utils import logger, solar_mass, gravitational_constant, speed_of_light, command_line_args, safe_file_dump +from ..core.utils.parallel import bilby_pool from ..core.prior import DeltaFunction from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions from .eos.eos import IntegrateTOV @@ -1631,7 +1631,7 @@ def binary_love_lambda_symmetric_to_lambda_1_lambda_2_automatic_marginalisation( def _generate_all_cbc_parameters(sample, defaults, base_conversion, - likelihood=None, priors=None, npool=1): + likelihood=None, priors=None, npool=1, pool=None): """Generate all cbc parameters, helper function for BBH/BNS""" output_sample = sample.copy() @@ -1654,13 +1654,13 @@ def _generate_all_cbc_parameters(sample, 
defaults, base_conversion, output_sample, _ = base_conversion(output_sample) if likelihood is not None: compute_per_detector_log_likelihoods( - samples=output_sample, likelihood=likelihood, npool=npool) + samples=output_sample, likelihood=likelihood, npool=npool, pool=pool) marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) if len(marginalized_parameters) > 0: try: generate_posterior_samples_from_marginalized_likelihood( - samples=output_sample, likelihood=likelihood, npool=npool) + samples=output_sample, likelihood=likelihood, npool=npool, pool=pool) except MarginalizedLikelihoodReconstructionError as e: logger.warning( "Marginalised parameter reconstruction failed with message " @@ -1694,7 +1694,7 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion, "Failed to generate sky frame parameters for type {}" .format(type(output_sample)) ) - compute_snrs(output_sample, likelihood, npool=npool) + compute_snrs(output_sample, likelihood=likelihood, npool=npool, pool=pool) for key, func in zip(["mass", "spin", "source frame"], [ generate_mass_parameters, generate_spin_parameters, generate_source_frame_parameters]): @@ -1712,7 +1712,7 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion, return output_sample -def generate_all_bbh_parameters(sample, likelihood=None, priors=None, npool=1): +def generate_all_bbh_parameters(sample, likelihood=None, priors=None, npool=1, pool=None): """ From either a single sample or a set of samples fill in all missing BBH parameters, in place. 
@@ -1739,11 +1739,11 @@ def generate_all_bbh_parameters(sample, likelihood=None, priors=None, npool=1): output_sample = _generate_all_cbc_parameters( sample, defaults=waveform_defaults, base_conversion=convert_to_lal_binary_black_hole_parameters, - likelihood=likelihood, priors=priors, npool=npool) + likelihood=likelihood, priors=priors, npool=npool, pool=pool) return output_sample -def generate_all_bns_parameters(sample, likelihood=None, priors=None, npool=1): +def generate_all_bns_parameters(sample, likelihood=None, priors=None, npool=1, pool=None): """ From either a single sample or a set of samples fill in all missing BNS parameters, in place. @@ -1775,7 +1775,7 @@ def generate_all_bns_parameters(sample, likelihood=None, priors=None, npool=1): output_sample = _generate_all_cbc_parameters( sample, defaults=waveform_defaults, base_conversion=convert_to_lal_binary_neutron_star_parameters, - likelihood=likelihood, priors=priors, npool=npool) + likelihood=likelihood, priors=priors, npool=npool, pool=pool) try: output_sample = generate_tidal_parameters(output_sample) except KeyError as e: @@ -2227,7 +2227,7 @@ def generate_source_frame_parameters(sample): return output_sample -def compute_snrs(sample, likelihood, npool=1): +def compute_snrs(sample, likelihood, npool=1, pool=None): """ Compute the optimal and matched filter snrs of all posterior samples and print it out. 
@@ -2255,23 +2255,13 @@ def compute_snrs(sample, likelihood, npool=1): logger.info('Computing SNRs for every sample.') fill_args = [(ii, row) for ii, row in sample.iterrows()] - if npool > 1: - from ..core.sampler.base_sampler import _initialize_global_variables - pool = multiprocessing.Pool( - processes=npool, - initializer=_initialize_global_variables, - initargs=(likelihood, None, None, False), - ) - logger.info( - "Using a pool with size {} for nsamples={}".format(npool, len(sample)) - ) - new_samples = pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)) - pool.close() - pool.join() - else: - from ..core.sampler.base_sampler import _sampling_convenience_dump - _sampling_convenience_dump.likelihood = likelihood - new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)] + with bilby_pool(likelihood=likelihood, npool=npool, pool=pool) as _pool: + if _pool is not None: + new_samples = _pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)) + else: + from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood + new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)] for ii, ifo in enumerate(likelihood.interferometers): snr_updates = dict() @@ -2303,7 +2293,7 @@ def _compute_snrs(args): return snrs -def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10): +def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10, pool=None): """ Calculate the log likelihoods in each detector. 
@@ -2345,49 +2335,31 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10) # Store samples to convert for checking cached_samples_dict["_samples"] = samples - # Set up the multiprocessing - if npool > 1: - from ..core.sampler.base_sampler import _initialize_global_variables - pool = multiprocessing.Pool( - processes=npool, - initializer=_initialize_global_variables, - initargs=(likelihood, None, None, False), - ) - logger.info( - "Using a pool with size {} for nsamples={}" - .format(npool, len(samples)) - ) - else: - from ..core.sampler.base_sampler import _sampling_convenience_dump - _sampling_convenience_dump.likelihood = likelihood - pool = None - fill_args = [(ii, row) for ii, row in samples.iterrows()] ii = 0 pbar = tqdm(total=len(samples), file=sys.stdout) - while ii < len(samples): - if ii in cached_samples_dict: - ii += block - pbar.update(block) - continue + with bilby_pool(likelihood=likelihood, npool=npool, pool=pool) as _pool: + while ii < len(samples): + if ii in cached_samples_dict: + ii += block + pbar.update(block) + continue + + if _pool is not None: + subset_samples = _pool.map(_compute_per_detector_log_likelihoods, + fill_args[ii: ii + block]) + else: + from ..core.sampler.base_sampler import _sampling_convenience_dump + _sampling_convenience_dump.likelihood = likelihood + subset_samples = [list(_compute_per_detector_log_likelihoods(xx)) + for xx in fill_args[ii: ii + block]] + + cached_samples_dict[ii] = subset_samples - if pool is not None: - subset_samples = pool.map(_compute_per_detector_log_likelihoods, - fill_args[ii: ii + block]) - else: - subset_samples = [list(_compute_per_detector_log_likelihoods(xx)) - for xx in fill_args[ii: ii + block]] - - cached_samples_dict[ii] = subset_samples - - ii += block - pbar.update(len(subset_samples)) + ii += block + pbar.update(len(subset_samples)) pbar.close() - if pool is not None: - pool.close() - pool.join() - new_samples = np.concatenate( [np.array(val) for key, val in 
cached_samples_dict.items() if key != "_samples"] ) @@ -2415,7 +2387,7 @@ def _compute_per_detector_log_likelihoods(args): def generate_posterior_samples_from_marginalized_likelihood( - samples, likelihood, npool=1, block=10, use_cache=True): + samples, likelihood, npool=1, block=10, use_cache=True, pool=None): """ Reconstruct the distance posterior from a run which used a likelihood which explicitly marginalised over time/distance/phase. @@ -2489,51 +2461,31 @@ def generate_posterior_samples_from_marginalized_likelihood( # Store samples to convert for checking cached_samples_dict["_samples"] = samples - # Set up the multiprocessing - if npool > 1: - from ..core.sampler.base_sampler import _initialize_global_variables - pool = multiprocessing.Pool( - processes=npool, - initializer=_initialize_global_variables, - initargs=(likelihood, None, None, False), - ) - logger.info( - "Using a pool with size {} for nsamples={}" - .format(npool, len(samples)) - ) - else: - from ..core.sampler.base_sampler import _sampling_convenience_dump - _sampling_convenience_dump.likelihood = likelihood - pool = None - seeds = generate_seeds(len(samples)) fill_args = [(ii, row, seed) for (ii, row), seed in zip(samples.iterrows(), seeds)] ii = 0 pbar = tqdm(total=len(samples), file=sys.stdout) - while ii < len(samples): - if ii in cached_samples_dict: - ii += block - pbar.update(block) - continue + with bilby_pool(likelihood=likelihood, npool=npool, pool=pool) as _pool: + while ii < len(samples): + if ii in cached_samples_dict: + ii += block + pbar.update(block) + continue - if pool is not None: - subset_samples = pool.map(fill_sample, fill_args[ii: ii + block]) - else: - subset_samples = [list(fill_sample(xx)) for xx in fill_args[ii: ii + block]] + if _pool is not None: + subset_samples = _pool.map(fill_sample, fill_args[ii: ii + block]) + else: + subset_samples = [list(fill_sample(xx)) for xx in fill_args[ii: ii + block]] - cached_samples_dict[ii] = subset_samples + 
cached_samples_dict[ii] = subset_samples - if use_cache: - safe_file_dump(cached_samples_dict, cache_filename, "pickle") + if use_cache: + safe_file_dump(cached_samples_dict, cache_filename, "pickle") - ii += block - pbar.update(len(subset_samples)) + ii += block + pbar.update(len(subset_samples)) pbar.close() - if pool is not None: - pool.close() - pool.join() - new_samples = np.concatenate( [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"] ) @@ -2585,7 +2537,7 @@ def identity_map_conversion(parameters): return parameters, [] -def identity_map_generation(sample, likelihood=None, priors=None, npool=1): +def identity_map_generation(sample, likelihood=None, priors=None, npool=1, pool=None): """An identity map generation function that handles marginalizations, SNRs, etc. correctly, but does not attempt e.g. conversions in mass or spins @@ -2610,13 +2562,13 @@ def identity_map_generation(sample, likelihood=None, priors=None, npool=1): if likelihood is not None: compute_per_detector_log_likelihoods( - samples=output_sample, likelihood=likelihood, npool=npool) + samples=output_sample, likelihood=likelihood, npool=npool, pool=pool) marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list()) if len(marginalized_parameters) > 0: try: generate_posterior_samples_from_marginalized_likelihood( - samples=output_sample, likelihood=likelihood, npool=npool) + samples=output_sample, likelihood=likelihood, npool=npool, pool=pool) except MarginalizedLikelihoodReconstructionError as e: logger.warning( "Marginalised parameter reconstruction failed with message " @@ -2625,7 +2577,7 @@ def identity_map_generation(sample, likelihood=None, priors=None, npool=1): ) if ("ra" in output_sample.keys() and "dec" in output_sample.keys() and "psi" in output_sample.keys()): - compute_snrs(output_sample, likelihood, npool=npool) + compute_snrs(output_sample, likelihood, npool=npool, pool=pool) else: logger.info( "Skipping SNR computation since 
samples have insufficient sky location information" From 479c10066010d2e12270348339f4058972b36b85 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 15:20:16 -0400 Subject: [PATCH 03/20] FEAT: improve reweighting parallelisation --- bilby/core/result.py | 85 ++++++++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 26 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index 98a8c8bd1..fd44b782a 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -6,8 +6,6 @@ from copy import copy from importlib import import_module from itertools import product -import multiprocessing -from functools import partial import numpy as np import pandas as pd import scipy.stats @@ -33,8 +31,11 @@ EXTENSIONS = ["json", "hdf5", "h5", "pickle", "pkl"] -def __eval_l(likelihood, params): - likelihood.parameters.update(params) +def __eval_l(sample): + from ..core.sampler.base_sampler import _sampling_convenience_dump + likelihood = _sampling_convenience_dump.likelihood + sample = dict(sample).copy() + likelihood.parameters.update(dict(sample).copy()) return likelihood.log_likelihood() @@ -196,7 +197,7 @@ def read_in_result_list(filename_list, invalid="warning"): def get_weights_for_reweighting( result, new_likelihood=None, new_prior=None, old_likelihood=None, - old_prior=None, resume_file=None, n_checkpoint=5000, npool=1): + old_prior=None, resume_file=None, n_checkpoint=5000, npool=1, pool=None): """ Calculate the weights for reweight() See bilby.core.result.reweight() for help with the inputs @@ -239,20 +240,27 @@ def get_weights_for_reweighting( starting_index = 0 - dict_samples = [{key: sample[key] for key in result.posterior} - for _, sample in result.posterior.iterrows()] + dict_samples = result.posterior.to_dict(orient="records") n = len(dict_samples) - starting_index # Helper function to compute likelihoods in parallel def eval_pool(this_logl): - with multiprocessing.Pool(processes=npool) as pool: - chunksize = max(100, 
n // (2 * npool)) - return list(tqdm( - pool.imap(partial(__eval_l, this_logl), - dict_samples[starting_index:], chunksize=chunksize), - desc='Computing likelihoods', - total=n) - ) + from .utils.parallel import create_pool, close_pool + + chunksize = max(100, n // (2 * npool)) + my_pool = create_pool(likelihood=this_logl, npool=npool) + if my_pool is None: + map_fn = map + else: + map_fn = my_pool.imap + + log_l = list(tqdm( + map_fn(__eval_l, dict_samples[starting_index:], chunksize=chunksize), + desc='Computing likelihoods', + total=n, + )) + close_pool(my_pool) + return log_l if old_likelihood is None: old_log_likelihood_array[starting_index:] = \ @@ -323,7 +331,7 @@ def rejection_sample(posterior, weights): def reweight(result, label=None, new_likelihood=None, new_prior=None, old_likelihood=None, old_prior=None, conversion_function=None, npool=1, verbose_output=False, resume_file=None, n_checkpoint=5000, - use_nested_samples=False): + use_nested_samples=False, pool=None): """ Reweight a result to a new likelihood/prior using rejection sampling Parameters @@ -386,7 +394,9 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, get_weights_for_reweighting( result, new_likelihood=new_likelihood, new_prior=new_prior, old_likelihood=old_likelihood, old_prior=old_prior, - resume_file=resume_file, n_checkpoint=n_checkpoint, npool=npool) + resume_file=resume_file, n_checkpoint=n_checkpoint, + npool=npool, pool=pool, + ) if use_nested_samples: ln_weights += np.log(result.posterior["weights"]) @@ -413,10 +423,14 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None, if conversion_function is not None: data_frame = result.posterior - if "npool" in inspect.signature(conversion_function).parameters: - data_frame = conversion_function(data_frame, new_likelihood, new_prior, npool=npool) - else: - data_frame = conversion_function(data_frame, new_likelihood, new_prior) + parameters = inspect.signature(conversion_function).parameters + kwargs 
= dict() + for key, value in [ + ("likelihood", new_likelihood), ("priors", new_prior), ("npool", npool), ("pool", pool) + ]: + if key in parameters: + kwargs[key] = value + data_frame = conversion_function(data_frame, **kwargs) result.posterior = data_frame if label: @@ -769,6 +783,21 @@ def log_10_evidence_err(self): def log_10_noise_evidence(self): return self.log_noise_evidence / np.log(10) + @property + def sampler_kwargs(self): + return self._sampler_kwargs + + @sampler_kwargs.setter + def sampler_kwargs(self, sampler_kwargs): + if sampler_kwargs is None: + sampler_kwargs = dict() + else: + sampler_kwargs = copy(sampler_kwargs) + if "pool" in sampler_kwargs: + # pool objects can't be neatly serialized + sampler_kwargs["pool"] = None + self._sampler_kwargs = sampler_kwargs + @property def version(self): return self._version @@ -1534,7 +1563,7 @@ def _add_prior_fixed_values_to_posterior(posterior, priors): return posterior def samples_to_posterior(self, likelihood=None, priors=None, - conversion_function=None, npool=1): + conversion_function=None, npool=1, pool=None): """ Convert array of samples to posterior (a Pandas data frame) @@ -1564,10 +1593,14 @@ def samples_to_posterior(self, likelihood=None, priors=None, data_frame['log_prior'] = self.log_prior_evaluations if conversion_function is not None: - if "npool" in inspect.signature(conversion_function).parameters: - data_frame = conversion_function(data_frame, likelihood, priors, npool=npool) - else: - data_frame = conversion_function(data_frame, likelihood, priors) + parameters = inspect.signature(conversion_function).parameters + kwargs = dict() + for key, value in [ + ("likelihood", likelihood), ("priors", priors), ("npool", npool), ("pool", pool) + ]: + if key in parameters: + kwargs[key] = value + data_frame = conversion_function(data_frame, **kwargs) self.posterior = data_frame def calculate_prior_values(self, priors): From dcd05dd023b4d4dc77643648e29c45fe67903fb8 Mon Sep 17 00:00:00 2001 From: Colm 
Talbot Date: Fri, 24 Oct 2025 15:27:15 -0400 Subject: [PATCH 04/20] FEAT: add parameters as argument to new pool --- bilby/core/utils/parallel.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index aff1a3fc9..3ff065311 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -4,9 +4,20 @@ from .log import logger -def create_pool(likelihood, priors, use_ratio=None, search_parameter_keys=None, npool=None, pool=None): +def create_pool( + likelihood, + priors, + use_ratio=None, + search_parameter_keys=None, + npool=None, + pool=None, + parameters=None, +): from ...core.sampler.base_sampler import _initialize_global_variables + if parameters is None: + parameters = dict() + _pool = None if pool == "mpi": try: @@ -19,6 +30,7 @@ def create_pool(likelihood, priors, use_ratio=None, search_parameter_keys=None, priors=priors, search_parameter_keys=search_parameter_keys, use_ratio=use_ratio, + parameters=parameters, ) _pool = MPIPool(use_dill=True) if _pool.is_master(): @@ -29,7 +41,7 @@ def create_pool(likelihood, priors, use_ratio=None, search_parameter_keys=None, _pool = multiprocessing.Pool( processes=npool, initializer=_initialize_global_variables, - initargs=(likelihood, priors, search_parameter_keys, use_ratio), + initargs=(likelihood, priors, search_parameter_keys, use_ratio, parameters), ) logger.info(f"Created multiprocessing pool with size {npool}") else: @@ -51,6 +63,7 @@ def bilby_pool( search_parameter_keys=None, npool=None, pool=None, + parameters=None, ): if hasattr(pool, "map"): user_pool = True @@ -65,6 +78,7 @@ def bilby_pool( use_ratio=use_ratio, npool=npool, pool=pool, + parameters=parameters, ) yield _pool finally: From 86b993ce31869c4b6637737b12950a7d1933f385 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 15:37:51 -0400 Subject: [PATCH 05/20] BUG: test that pool exists at cleanup --- bilby/core/utils/parallel.py | 
4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index 3ff065311..322f4d00e 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -52,6 +52,8 @@ def create_pool( def close_pool(pool): if hasattr(pool, "close"): pool.close() + else: + import IPython; IPython.embed() if hasattr(pool, "join"): pool.join() @@ -82,5 +84,5 @@ def bilby_pool( ) yield _pool finally: - if not user_pool: + if not user_pool and "_pool" in locals(): close_pool(_pool) From c51e15c5e683e365c96c48d84bd9074b9109a927 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 15:38:19 -0400 Subject: [PATCH 06/20] BUG: test pool exists at closing --- bilby/core/utils/parallel.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index 322f4d00e..0012173ff 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -52,8 +52,6 @@ def create_pool( def close_pool(pool): if hasattr(pool, "close"): pool.close() - else: - import IPython; IPython.embed() if hasattr(pool, "join"): pool.join() From c4654a909bc5aad695adeb429da4d91a5381efae Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 16:42:05 -0400 Subject: [PATCH 07/20] REFACTOR: refactor run_sampler to simplify pool logic --- bilby/core/sampler/__init__.py | 165 ++++++++++++++++++--------------- 1 file changed, 92 insertions(+), 73 deletions(-) diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index a24309c86..8b563c71a 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -13,6 +13,7 @@ ) from . 
import proposal from .base_sampler import Sampler, SamplingMarginalisedParameterError +from ..utils.parallel import bilby_pool class ImplementedSamplers: @@ -267,38 +268,27 @@ def run_sampler( likelihood = ZeroLikelihood(likelihood) + common_kwargs =dict( + likelihood=likelihood, + priors=priors, + outdir=outdir, + label=label, + injection_parameters=injection_parameters, + meta_data=meta_data, + use_ratio=use_ratio, + plot=plot, + result_class=result_class, + npool=npool, + pool=pool, + ) + if isinstance(sampler, Sampler): pass elif isinstance(sampler, str): sampler_class = get_sampler_class(sampler) - sampler = sampler_class( - likelihood, - priors=priors, - outdir=outdir, - label=label, - injection_parameters=injection_parameters, - meta_data=meta_data, - use_ratio=use_ratio, - plot=plot, - result_class=result_class, - npool=npool, - pool=pool, - **kwargs, - ) + sampler = sampler_class(**common_kwargs, **kwargs) elif inspect.isclass(sampler): - sampler = sampler.__init__( - likelihood, - priors=priors, - outdir=outdir, - label=label, - use_ratio=use_ratio, - plot=plot, - injection_parameters=injection_parameters, - meta_data=meta_data, - npool=npool, - pool=pool, - **kwargs, - ) + sampler = sampler.__init__(**common_kwargs, **kwargs) else: raise ValueError( "Provided sampler should be a Sampler object or name of a known " @@ -308,10 +298,15 @@ def run_sampler( if sampler.cached_result: logger.warning("Using cached result") result = sampler.cached_result + result = apply_conversion_function( + result=result, + likelihood=likelihood, + conversion_function=conversion_function, + npool=npool, + pool=pool, + ) else: # Run the sampler - start_time = datetime.datetime.now() - from ..utils.parallel import bilby_pool with bilby_pool( likelihood, priors, @@ -319,60 +314,26 @@ def run_sampler( search_parameter_keys=sampler.search_parameter_keys, npool=npool, pool=pool, + parameters=priors.sample(), ) as _pool: + start_time = datetime.datetime.now() sampler.pool = _pool if 
command_line_args.bilby_test_mode: result = sampler._run_test() else: result = sampler.run_sampler() - end_time = datetime.datetime.now() - - # Some samplers calculate the sampling time internally - if result.sampling_time is None: - result.sampling_time = end_time - start_time - elif isinstance(result.sampling_time, (float, int)): - result.sampling_time = datetime.timedelta(result.sampling_time) - - logger.info(f"Sampling time: {result.sampling_time}") - # Convert sampling time into seconds - result.sampling_time = result.sampling_time.total_seconds() - - if sampler.use_ratio: - result.log_noise_evidence = likelihood.noise_log_likelihood() - result.log_bayes_factor = result.log_evidence - result.log_evidence = result.log_bayes_factor + result.log_noise_evidence - else: - result.log_noise_evidence = likelihood.noise_log_likelihood() - result.log_bayes_factor = result.log_evidence - result.log_noise_evidence - - if None not in [result.injection_parameters, conversion_function]: - result.injection_parameters = conversion_function( - result.injection_parameters + end_time = datetime.datetime.now() + result = finalize_result( + result=result, likelihood=likelihood, start_time=start_time, end_time=end_time ) - # Initial save of the sampler in case of failure in samples_to_posterior - if save: - result.save_to_file(extension=save, gzip=gzip, outdir=outdir) + # Initial save of the sampler in case of failure in samples_to_posterior + if save: + result.save_to_file(extension=save, gzip=gzip, outdir=outdir) - if None not in [result.injection_parameters, conversion_function]: - result.injection_parameters = conversion_function( - result.injection_parameters, - likelihood=likelihood, - ) - - # Check if the posterior has already been created - if getattr(result, "_posterior", None) is None: - with bilby_pool( - likelihood, - priors, - use_ratio=sampler.use_ratio, - search_parameter_keys=sampler.search_parameter_keys, - npool=npool, - pool=pool, - ) as _pool: - 
result.samples_to_posterior( + result = apply_conversion_function( + result=result, likelihood=likelihood, - priors=result.priors, conversion_function=conversion_function, npool=npool, pool=_pool, @@ -388,6 +349,64 @@ def run_sampler( return result +def apply_conversion_function(result, likelihood, conversion_function, npool=None, pool=None): + """ + Apply the conversion function to the injected parameters and posterior if the + posterior has not already been created from the stored samples. + + Parameters + ---------- + result : bilby.core.result.Result + The result object from the sampler. + likelihood : bilby.Likelihood + The likelihood used during sampling. + conversion_function : function + The conversion function to apply. + npool : int, optional + The number of processes to use in a processing pool. + pool : multiprocessing.Pool, schwimmbad.MPIPool, optional + The pool to use for parallelisation, this overrides the :code:`npool` argument. + """ + if None not in [result.injection_parameters, conversion_function]: + result.injection_parameters = conversion_function( + result.injection_parameters, + likelihood=likelihood, + ) + + # Check if the posterior has already been created + if getattr(result, "_posterior", None) is None: + result.samples_to_posterior( + likelihood=likelihood, + priors=result.priors, + conversion_function=conversion_function, + npool=npool, + pool=pool, + ) + return result + + +def finalize_result(result, likelihood, start_time=None, end_time=None): + # Some samplers calculate the sampling time internally + if result.sampling_time is None and None not in [start_time, end_time]: + result.sampling_time = end_time - start_time + elif isinstance(result.sampling_time, (float, int)): + result.sampling_time = datetime.timedelta(result.sampling_time) + + logger.info(f"Sampling time: {result.sampling_time}") + # Convert sampling time into seconds + result.sampling_time = result.sampling_time.total_seconds() + + if sampler.use_ratio: + 
result.log_noise_evidence = likelihood.noise_log_likelihood() + result.log_bayes_factor = result.log_evidence + result.log_evidence = result.log_bayes_factor + result.log_noise_evidence + else: + result.log_noise_evidence = likelihood.noise_log_likelihood() + result.log_bayes_factor = result.log_evidence - result.log_noise_evidence + + return result + + def _check_marginalized_parameters_not_sampled(likelihood, priors): for key in likelihood.marginalized_parameters: if key in priors: From ba1df8772779a9505d12efb799051eaed763ed18 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 16:44:26 -0400 Subject: [PATCH 08/20] DEP: discourage setting up pool in sampler --- bilby/core/sampler/base_sampler.py | 31 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index a86821570..29fa29816 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -20,6 +20,7 @@ command_line_args, logger, ) +from ..utils.parallel import close_pool, create_pool from ..utils.random import seed as set_seed @@ -769,7 +770,6 @@ def _close_pool(self): getattr(self, "pool", None) is not None and not getattr(self, "_user_pool", True) ): - from ..utils.parallel import close_pool logger.info("Starting to close worker pool.") close_pool(self.pool) self.pool = None @@ -777,28 +777,31 @@ def _close_pool(self): logger.info("Finished closing worker pool.") def _setup_pool(self): - from ..utils.parallel import create_pool - if hasattr(self.pool, "map"): self._user_pool = True else: self._user_pool = False - - self.pool = create_pool( - likelihood=self.likelihood, - priors=self.priors, - search_parameter_keys=self._search_parameter_keys, - use_ratio=self.use_ratio, - npool=self.npool, - pool=self.pool, - parameters=deepcopy(self.parameters), - ) + parameters = self.priors.sample() + self.pool = create_pool( + likelihood=self.likelihood, + 
priors=self.priors, + search_parameter_keys=self._search_parameter_keys, + use_ratio=self.use_ratio, + npool=self.npool, + pool=self.pool, + parameters=parameters, + ) + if self.pool is not None: + logger.warning( + "Setting up parallel pool in sampler is deprecated. Use " + "bilby.utils.parallel.bilby_pool context instead." + ) _initialize_global_variables( likelihood=self.likelihood, priors=self.priors, search_parameter_keys=self._search_parameter_keys, use_ratio=self.use_ratio, - parameters=deepcopy(self.parameters), + parameters=parameters, ) self.kwargs["pool"] = self.pool From 2025cf035983d5d0469745309a7a3ff9dde41267 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 16:44:40 -0400 Subject: [PATCH 09/20] REFACTOR: remove top level multiprocessing import --- bilby/core/utils/parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index 0012173ff..67324c4d4 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -1,4 +1,3 @@ -import multiprocessing from contextlib import contextmanager from .log import logger @@ -38,6 +37,8 @@ def create_pool( elif pool is not None: _pool = pool elif npool is not None: + import multiprocessing + _pool = multiprocessing.Pool( processes=npool, initializer=_initialize_global_variables, From 10e4267a5839e8003ee4ab2ed0a2bb6a1108019d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 24 Oct 2025 16:44:55 -0400 Subject: [PATCH 10/20] BUG: make sure prior is passed to pool creation --- bilby/gw/conversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 15ad42c32..60f41f458 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -2239,7 +2239,7 @@ def compute_snrs(sample, likelihood, npool=1, pool=None): logger.info('Computing SNRs for every sample.') fill_args = [(ii, row) for ii, row in sample.iterrows()] - with 
bilby_pool(likelihood=likelihood, npool=npool, pool=pool) as _pool: + with bilby_pool(likelihood=likelihood, priors=None, npool=npool, pool=pool) as _pool: if _pool is not None: new_samples = _pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)) else: @@ -2322,7 +2322,7 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10, fill_args = [(ii, row) for ii, row in samples.iterrows()] ii = 0 pbar = tqdm(total=len(samples), file=sys.stdout) - with bilby_pool(likelihood=likelihood, npool=npool, pool=pool) as _pool: + with bilby_pool(likelihood=likelihood, priors=None, npool=npool, pool=pool) as _pool: while ii < len(samples): if ii in cached_samples_dict: ii += block @@ -2448,7 +2448,7 @@ def generate_posterior_samples_from_marginalized_likelihood( fill_args = [(ii, row, seed) for (ii, row), seed in zip(samples.iterrows(), seeds)] ii = 0 pbar = tqdm(total=len(samples), file=sys.stdout) - with bilby_pool(likelihood=likelihood, npool=npool, pool=pool) as _pool: + with bilby_pool(likelihood=likelihood, priors=None, npool=npool, pool=pool) as _pool: while ii < len(samples): if ii in cached_samples_dict: ii += block From 079d1c9083a375acb4632653cecd39f40cbd39e1 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 30 Oct 2025 14:43:09 -0400 Subject: [PATCH 11/20] BUG: fix test failures --- bilby/core/result.py | 2 +- bilby/core/sampler/__init__.py | 18 ++++++++++++------ bilby/core/sampler/base_sampler.py | 10 ++++++---- bilby/core/utils/parallel.py | 4 ++-- bilby/gw/conversion.py | 12 ++++++++---- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index 0c820ab51..a0cece883 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -248,7 +248,7 @@ def eval_pool(this_logl): map_fn = map else: map_fn = my_pool.imap - + log_l = list(tqdm( map_fn(partial(_safe_likelihood_call, this_logl), dict_samples[starting_index:], chunksize=chunksize), desc='Computing likelihoods', diff --git 
a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 8b563c71a..a553a225a 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -11,9 +11,9 @@ loaded_modules_dict, logger, ) +from ..utils.parallel import bilby_pool from . import proposal from .base_sampler import Sampler, SamplingMarginalisedParameterError -from ..utils.parallel import bilby_pool class ImplementedSamplers: @@ -268,7 +268,7 @@ def run_sampler( likelihood = ZeroLikelihood(likelihood) - common_kwargs =dict( + common_kwargs = dict( likelihood=likelihood, priors=priors, outdir=outdir, @@ -324,7 +324,11 @@ def run_sampler( result = sampler.run_sampler() end_time = datetime.datetime.now() result = finalize_result( - result=result, likelihood=likelihood, start_time=start_time, end_time=end_time + result=result, + likelihood=likelihood, + use_ratio=sampler.use_ratio, + start_time=start_time, + end_time=end_time, ) # Initial save of the sampler in case of failure in samples_to_posterior @@ -349,7 +353,9 @@ def run_sampler( return result -def apply_conversion_function(result, likelihood, conversion_function, npool=None, pool=None): +def apply_conversion_function( + result, likelihood, conversion_function, npool=None, pool=None +): """ Apply the conversion function to the injected parameters and posterior if the posterior has not already been created from the stored samples. 
@@ -385,7 +391,7 @@ def apply_conversion_function(result, likelihood, conversion_function, npool=Non return result -def finalize_result(result, likelihood, start_time=None, end_time=None): +def finalize_result(result, likelihood, use_ratio, start_time=None, end_time=None): # Some samplers calculate the sampling time internally if result.sampling_time is None and None not in [start_time, end_time]: result.sampling_time = end_time - start_time @@ -396,7 +402,7 @@ def finalize_result(result, likelihood, start_time=None, end_time=None): # Convert sampling time into seconds result.sampling_time = result.sampling_time.total_seconds() - if sampler.use_ratio: + if use_ratio: result.log_noise_evidence = likelihood.noise_log_likelihood() result.log_bayes_factor = result.log_evidence result.log_evidence = result.log_bayes_factor + result.log_noise_evidence diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 29fa29816..4ae35d9ea 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -766,9 +766,8 @@ def write_current_state_and_exit(self, signum=None, frame=None): sys.exit(self.exit_code) def _close_pool(self): - if ( - getattr(self, "pool", None) is not None - and not getattr(self, "_user_pool", True) + if getattr(self, "pool", None) is not None and not getattr( + self, "_user_pool", True ): logger.info("Starting to close worker pool.") close_pool(self.pool) @@ -777,11 +776,14 @@ def _close_pool(self): logger.info("Finished closing worker pool.") def _setup_pool(self): + parameters = self.priors.sample() + if hasattr(self.pool, "map"): self._user_pool = True + elif self.npool in (1, None): + self._user_pool = False else: self._user_pool = False - parameters = self.priors.sample() self.pool = create_pool( likelihood=self.likelihood, priors=self.priors, diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index 67324c4d4..142c7231f 100644 --- a/bilby/core/utils/parallel.py +++ 
b/bilby/core/utils/parallel.py @@ -4,8 +4,8 @@ def create_pool( - likelihood, - priors, + likelihood=None, + priors=None, use_ratio=None, search_parameter_keys=None, npool=None, diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 60f41f458..72330a30c 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -2330,13 +2330,17 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10, continue if _pool is not None: - subset_samples = _pool.map(_compute_per_detector_log_likelihoods, - fill_args[ii: ii + block]) + subset_samples = _pool.map( + _compute_per_detector_log_likelihoods, + fill_args[ii: ii + block] + ) else: from ..core.sampler.base_sampler import _sampling_convenience_dump _sampling_convenience_dump.likelihood = likelihood - subset_samples = [list(_compute_per_detector_log_likelihoods(xx)) - for xx in fill_args[ii: ii + block]] + subset_samples = [ + list(_compute_per_detector_log_likelihoods(xx)) + for xx in fill_args[ii: ii + block] + ] cached_samples_dict[ii] = subset_samples From 1ab2fa4924c6326923d6727651b3c74328f07ade Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 30 Oct 2025 15:21:42 -0400 Subject: [PATCH 12/20] TEST: fix reproducibility test --- test/core/sampler/dynesty_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py index 4177c4fea..9ae552ff8 100644 --- a/test/core/sampler/dynesty_test.py +++ b/test/core/sampler/dynesty_test.py @@ -483,6 +483,7 @@ def _run_sampler(self, **kwargs): resume=False, dlogz=1.0, nlive=20, + sample="acceptance-walk", **kwargs, ) From cffe25c7e77958e3162ce8e895e5f62e708616a1 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 30 Oct 2025 15:29:58 -0400 Subject: [PATCH 13/20] BUG: fix a typo in conversion function test --- test/integration/sampler_run_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/sampler_run_test.py 
b/test/integration/sampler_run_test.py index f35cf07c6..471ddc8a6 100644 --- a/test/integration/sampler_run_test.py +++ b/test/integration/sampler_run_test.py @@ -91,7 +91,7 @@ def setUp(self): bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir") @staticmethod - def conversion_function(parameters, likelihood, prior): + def conversion_function(parameters, likelihood, priors): converted = parameters.copy() if "derived" not in converted: converted["derived"] = converted["m"] * converted["c"] From 463af87ec401ee69247adb800042f1f194429ad2 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Thu, 30 Oct 2025 15:31:02 -0400 Subject: [PATCH 14/20] MAINT: don't create pool of size 1 --- bilby/core/utils/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index 142c7231f..b52e6f7ff 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -36,7 +36,7 @@ def create_pool( logger.info(f"Created MPI pool with size {_pool.size}") elif pool is not None: _pool = pool - elif npool is not None: + elif npool not in (None, 1): import multiprocessing _pool = multiprocessing.Pool( From 975cbfc6ab60d7b881a2395db63a92c43a65c232 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 3 Nov 2025 09:49:28 -0500 Subject: [PATCH 15/20] BUG: only include chunksize in multiprocessing map --- bilby/core/result.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index a0cece883..ea7dc7551 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -247,10 +247,11 @@ def eval_pool(this_logl): if my_pool is None: map_fn = map else: - map_fn = my_pool.imap + map_fn = partial(my_pool.imap, chunksize=chunksize) + likelihood_fn = partial(_safe_likelihood_call, this_logl) log_l = list(tqdm( - map_fn(partial(_safe_likelihood_call, this_logl), dict_samples[starting_index:], chunksize=chunksize), + map_fn(likelihood_fn, 
dict_samples[starting_index:]), desc='Computing likelihoods', total=n, )) close_pool(my_pool) From e95381e3702f20a1d68ef88629d20bd6c7553006 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 26 Jan 2026 10:25:08 +0000 Subject: [PATCH 16/20] DOC: add docstrings for pool functions --- bilby/core/utils/parallel.py | 116 +++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index b52e6f7ff..491cfa2f9 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -12,6 +12,61 @@ def create_pool( pool=None, parameters=None, ): + """ + Create a parallel pool object that is initialized with variables typically + needed by Bilby for parallel tasks. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood, None + The likelihood to copy into each process + priors: bilby.core.prior.PriorDict, None + The Bilby prior dictionary to copy into each process + use_ratio: bool, None + Whether to evaluate the log_likelihood_ratio + search_parameter_keys: list[str], None + The names fo parameters being sampled over + npool: int, None + The number of processes to use for multiprocessing. + If a user pool is not provided and this is either :code:`1` or :code:`None`, + this function returns :code:`None`. + pool: pool-like, str, None + Either a premade pool object, or the pool kind (:code:`mpi`, :code:`multiprocessing`). + If a pre-made pool is passed, it is returned directly with no checks + performed. + parameters: dict, None + Parameters to pass through to the new processes, e.g., if default + parameters are to be passed. + + Returns + ======= + pool: schwimmbad.MPIPool, multiprocessing.Pool, None + Returns either a pool that can be used for mapping function calls. + Each process attached to the pool has been initialized with + the bilby.core.sampler._SamplingContainer.
+ + Examples + ======== + + >>> import numpy as np + >>> from bilby.core.likelihood import AnalyticalMultidimensionalCovariantGaussian + >>> from bilby.core.prior import Normal, PriorDict + >>> from bilby.core.utils.parallel import close_pool, create_pool + >>> from bilby.core.sampler.base_sampler import _sampling_convenience_dump + + >>> def parallel_likelihood_eval(parameters): + >>> likelihood = _sampling_convenience_dump.likelihood + >>> return likelihood.log_likelihood(parameters) + + >>> likelihood = AnalyticalMultidimensionalCovariantGaussian( + >>> mean=np.zeros(4), cov=np.eye(4) + >>> ) + >>> priors = PriorDict({f"x{ii}": Normal(0, 1) for ii in range(4)}) + >>> parameters = [priors.sample() for _ in range(10)] + >>> pool = create_pool(likelihood, priors, npool=4) + >>> log_ls = list(pool.map(parallel_likelihood_eval, parameters)) + >>> close_pool(pool) + """ from ...core.sampler.base_sampler import _initialize_global_variables if parameters is None: @@ -51,6 +106,11 @@ def create_pool( def close_pool(pool): + """ + Safely close a parallel pool. + If the pool has a :code:`close` method :code:`pool.close` will be called. + Then, if the pool has a :code:`join` method :code:`pool.join` will be called. + """ if hasattr(pool, "close"): pool.close() if hasattr(pool, "join"): @@ -66,6 +126,62 @@ def bilby_pool( pool=None, parameters=None, ): + """ + Yield a parallel pool object that is initialized with variables typically + needed by Bilby for parallel tasks that is automatically closed when closing + the context. + + Parameters + ========== + likelihood: bilby.core.likelihood.Likelihood, None + The likelihood to copy into each process + priors: bilby.core.prior.PriorDict, None + The Bilby prior dictionary to copy into each process + use_ratio: bool, None + Whether to evaluate the log_likelihood_ratio + search_parameter_keys: list[str], None + The names fo parameters being sampled over + npool: int, None + The number of processes to use for multiprocessing.
+ If a user pool is not provided and this is either :code:`1` or :code:`None`, + this function returns :code:`None`. + pool: pool-like, str, None + Either a premade pool object, or the pool kind (:code:`mpi`, :code:`multiprocessing`). + If a pre-made pool is passed, it is returned directly with no checks + performed. + parameters: dict, None + Parameters to pass through to the new processes, e.g., if default + parameters are to be passed. + + Yields + ====== + pool: schwimmbad.MPIPool, multiprocessing.Pool, None + Returns either a pool that can be used for mapping function calls. + Each process attached to the pool has been initialized with + the bilby.core.sampler._SamplingContainer. + + Examples + ======== + + >>> import numpy as np + >>> from bilby.core.likelihood import AnalyticalMultidimensionalCovariantGaussian + >>> from bilby.core.prior import Normal, PriorDict + >>> from bilby.core.utils.parallel import bilby_pool + >>> from bilby.core.sampler.base_sampler import _sampling_convenience_dump + + >>> def parallel_likelihood_eval(parameters): + >>> likelihood = _sampling_convenience_dump.likelihood + >>> return likelihood.log_likelihood(parameters) + + >>> likelihood = AnalyticalMultidimensionalCovariantGaussian( + >>> mean=np.zeros(4), cov=np.eye(4) + >>> ) + >>> priors = PriorDict({f"x{ii}": Normal(0, 1) for ii in range(4)}) + >>> parameters = [priors.sample() for _ in range(10)] + >>> with bilby_pool(likelihood, priors, npool=4) as pool: + >>> log_ls = list(pool.map(parallel_likelihood_eval, parameters)) + + """ if hasattr(pool, "map"): user_pool = True else: From c914a2252ac8728da45665462bcd399c8a2df2d8 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 26 Jan 2026 10:26:42 +0000 Subject: [PATCH 17/20] DOC: update pool docstrings --- bilby/core/utils/parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index 491cfa2f9..ef8d03b45 100644 ---
a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -43,7 +43,7 @@ def create_pool( pool: schwimmbad.MPIPool, multiprocessing.Pool, None Returns either a pool that can be used for mapping function calls. Each process attached to the pool has been initialized with - the bilby.core.sampler._SamplingContainer. + the :code:`bilby.core.sampler.base_sampler._sampling_convenience_dump`. Examples ======== @@ -158,7 +158,7 @@ def bilby_pool( pool: schwimmbad.MPIPool, multiprocessing.Pool, None Returns either a pool that can be used for mapping function calls. Each process attached to the pool has been initialized with - the bilby.core.sampler._SamplingContainer. + the :code:`bilby.core.sampler.base_sampler._sampling_convenience_dump`. Examples ======== From 4f92ce113fca364b0c17df9979eed11a390c019d Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Mon, 26 Jan 2026 15:38:06 +0000 Subject: [PATCH 18/20] Address review comments --- bilby/core/result.py | 27 +++++++++++++-------------- bilby/core/utils/parallel.py | 3 ++- bilby/gw/conversion.py | 17 ++++++++--------- test/core/result_test.py | 15 +++++++++++++-- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index ea7dc7551..6703084d9 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -240,22 +240,21 @@ def get_weights_for_reweighting( # Helper function to compute likelihoods in parallel def eval_pool(this_logl): - from .utils.parallel import create_pool, close_pool + from .utils.parallel import bilby_pool chunksize = max(100, n // (2 * npool)) - my_pool = create_pool(likelihood=this_logl, npool=npool) - if my_pool is None: - map_fn = map - else: - map_fn = partial(my_pool.imap, chunksize=chunksize) - likelihood_fn = partial(_safe_likelihood_call, this_logl) - - log_l = list(tqdm( - map_fn(likelihood_fn, dict_samples[starting_index:]), - desc='Computing likelihoods', - total=n, - )) - close_pool(my_pool) + with 
bilby_pool(likelihood=this_logl, npool=npool) as my_pool: + if my_pool is None: + map_fn = map + else: + map_fn = partial(my_pool.imap, chunksize=chunksize) + likelihood_fn = partial(_safe_likelihood_call, this_logl) + + log_l = list(tqdm( + map_fn(likelihood_fn, dict_samples[starting_index:]), + desc='Computing likelihoods', + total=n, + )) return log_l if old_likelihood is None: diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index ef8d03b45..c6e684da8 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -119,7 +119,8 @@ def close_pool(pool): @contextmanager def bilby_pool( - likelihood, priors, + likelihood=None, + priors=None, use_ratio=None, search_parameter_keys=None, npool=None, diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 72330a30c..bcb349c6a 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -2330,17 +2330,16 @@ def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10, continue if _pool is not None: - subset_samples = _pool.map( - _compute_per_detector_log_likelihoods, - fill_args[ii: ii + block] - ) + map_fn = _pool.map else: + map_fn = map from ..core.sampler.base_sampler import _sampling_convenience_dump _sampling_convenience_dump.likelihood = likelihood - subset_samples = [ - list(_compute_per_detector_log_likelihoods(xx)) - for xx in fill_args[ii: ii + block] - ] + + subset_samples = list(map_fn( + _compute_per_detector_log_likelihoods, + fill_args[ii: ii + block], + )) cached_samples_dict[ii] = subset_samples @@ -2462,7 +2461,7 @@ def generate_posterior_samples_from_marginalized_likelihood( if _pool is not None: subset_samples = _pool.map(fill_sample, fill_args[ii: ii + block]) else: - subset_samples = [list(fill_sample(xx)) for xx in fill_args[ii: ii + block]] + subset_samples = list(map(fill_sample, fill_args[ii: ii + block])) cached_samples_dict[ii] = subset_samples diff --git a/test/core/result_test.py b/test/core/result_test.py index 
23ba8e6b5..8d172a0d1 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -880,7 +880,7 @@ def setUp(self): log_evidence=-np.log(10), ) - def _run_reweighting(self, sigma): + def _run_reweighting(self, sigma, npool=None): likelihood_1 = SimpleGaussianLikelihood() likelihood_2 = SimpleGaussianLikelihood(sigma=sigma) original_ln_likelihoods = list() @@ -891,7 +891,11 @@ def _run_reweighting(self, sigma): self.result.posterior["log_likelihood"] = original_ln_likelihoods self.original_ln_likelihoods = original_ln_likelihoods return bilby.core.result.reweight( - self.result, likelihood_1, likelihood_2, verbose_output=True + self.result, + likelihood_1, + likelihood_2, + verbose_output=True, + npool=npool, ) def test_reweight_same_likelihood_weights_1(self): @@ -901,6 +905,13 @@ def test_reweight_same_likelihood_weights_1(self): _, weights, _, _, _, _ = self._run_reweighting(sigma=1) self.assertLess(min(abs(weights - 1)), 1e-10) + def test_reweight_same_likelihood_weights_1_with_pool(self): + """ + When the likelihoods are the same, the weights should be 1. 
+ """ + _, weights, _, _, _, _ = self._run_reweighting(sigma=1, npool=2) + self.assertLess(min(abs(weights - 1)), 1e-10) + @pytest.mark.flaky(reruns=3) def test_reweight_different_likelihood_weights_correct(self): """ From eba6eecdf4ea3c2dc1064c68de8189754c0c6190 Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 18 Feb 2026 13:15:03 -0500 Subject: [PATCH 19/20] TYPO: Fix typo in parameter description comments --- bilby/core/utils/parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bilby/core/utils/parallel.py b/bilby/core/utils/parallel.py index c6e684da8..1d3d9be32 100644 --- a/bilby/core/utils/parallel.py +++ b/bilby/core/utils/parallel.py @@ -25,7 +25,7 @@ def create_pool( use_ratio: bool, None Whether to evaluate the log_likelihood_ratio search_parameter_keys: list[str], None - The names fo parameters being sampled over + The names for parameters being sampled over npool: int, None The number of processes to use for multiprocessing. If a user pool is not provided and this is either :code:`1` or :code:`None`, @@ -141,7 +141,7 @@ def bilby_pool( use_ratio: bool, None Whether to evaluate the log_likelihood_ratio search_parameter_keys: list[str], None - The names fo parameters being sampled over + The names for parameters being sampled over npool: int, None The number of processes to use for multiprocessing. 
If a user pool is not provided and this is either :code:`1` or :code:`None`, From 82bcf1e9bce6e2ca6b59257ac5ec7b06df7f99ff Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Wed, 18 Feb 2026 15:48:46 -0500 Subject: [PATCH 20/20] BUG: move definition of chunk size in reweighting --- bilby/core/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index d6c212a09..0247d5a8e 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -242,11 +242,11 @@ def get_weights_for_reweighting( def eval_pool(this_logl): from .utils.parallel import bilby_pool - chunksize = max(100, n // (2 * npool)) with bilby_pool(likelihood=this_logl, npool=npool) as my_pool: if my_pool is None: map_fn = map else: + chunksize = max(100, n // (2 * npool)) map_fn = partial(my_pool.imap, chunksize=chunksize) likelihood_fn = partial(_safe_likelihood_call, this_logl)