Skip to content

Commit b8ee0d3

Browse files
committed
Use joblib to run experiments in parallel
This change was made because multiprocessing does not work inside Jupyter notebooks on Windows.
1 parent 7cdbbcf commit b8ee0d3

3 files changed

Lines changed: 40 additions & 44 deletions

File tree

environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,7 @@ dependencies:
1313
- pillow
1414
- pip
1515
- ruff
16+
- joblib
1617
- pip:
18+
- mrg32k3a>=2.0
1719
- -e .

pyproject.toml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@ classifiers = [
2323
"License :: OSI Approved :: MIT License",
2424
"Operating System :: OS Independent",
2525
]
26-
dependencies = ["numpy", "scipy", "matplotlib", "pandas", "seaborn", "mrg32k3a"]
26+
dependencies = [
27+
"numpy",
28+
"scipy",
29+
"matplotlib",
30+
"pandas",
31+
"seaborn",
32+
"mrg32k3a",
33+
"joblib>=1.5.1",
34+
]
2735

2836
[project.optional-dependencies]
2937
dev = ["sphinx"]

simopt/experiment_base.py

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import importlib
66
import itertools
77
import logging
8-
import os
98
import pickle
109
import subprocess
1110
import time
@@ -16,6 +15,7 @@
1615
import matplotlib.pyplot as plt
1716
import numpy as np
1817
import pandas as pd
18+
from joblib import Parallel, delayed
1919

2020
import simopt.curve_utils as curve_utils
2121
import simopt.directory as directory
@@ -444,7 +444,6 @@ def run(self, n_macroreps: int) -> None:
444444
"""
445445
# Local Imports
446446
from functools import partial
447-
from multiprocessing import Pool
448447

449448
# Value checking
450449
if n_macroreps <= 0:
@@ -480,27 +479,20 @@ def run(self, n_macroreps: int) -> None:
480479

481480
logging.debug("Starting macroreplications")
482481

483-
num_processes = min(n_macroreps, os.cpu_count() or 1)
484-
with Pool(num_processes) as process_pool:
485-
# Start the macroreplications in parallel (async)
486-
run_multithread_partial = partial(
487-
self.run_multithread, solver=self.solver, problem=self.problem
488-
)
489-
for num_completed, (
490-
mrep,
491-
recommended_xs,
492-
intermediate_budgets,
493-
timing,
494-
) in enumerate(
495-
process_pool.imap_unordered(run_multithread_partial, range(n_macroreps))
496-
):
497-
self.all_recommended_xs[mrep] = recommended_xs
498-
self.all_intermediate_budgets[mrep] = intermediate_budgets
499-
self.timings[mrep] = timing
500-
self.num_completed = num_completed + 1
501-
502-
runtime = round(time.time() - function_start, 3)
503-
logging.info(f"Finished running {n_macroreps} mreps in {runtime} seconds.")
482+
# Start the macroreplications in parallel (async)
483+
run_multithread_partial = partial(
484+
self.run_multithread, solver=self.solver, problem=self.problem
485+
)
486+
results = Parallel()(
487+
delayed(run_multithread_partial)(i) for i in range(n_macroreps)
488+
)
489+
for mrep, recommended_xs, intermediate_budgets, timing in results:
490+
self.all_recommended_xs[mrep] = recommended_xs
491+
self.all_intermediate_budgets[mrep] = intermediate_budgets
492+
self.timings[mrep] = timing
493+
494+
runtime = round(time.time() - function_start, 3)
495+
logging.info(f"Finished running {n_macroreps} mreps in {runtime} seconds.")
504496

505497
self.has_run = True
506498
self.has_postreplicated = False
@@ -611,9 +603,6 @@ def post_replicate(
611603
Raises:
612604
ValueError: If `n_postreps` is not positive.
613605
"""
614-
# Local Imports
615-
from multiprocessing import Pool
616-
617606
# Value checking
618607
if n_postreps <= 0:
619608
error_msg = "Number of postreplications must be positive."
@@ -638,25 +627,22 @@ def post_replicate(
638627
function_start = time.time()
639628

640629
logging.info("Starting postreplications")
641-
num_processes = min(self.n_macroreps, os.cpu_count() or 1)
642-
with Pool(num_processes) as process_pool:
643-
for num_completed, (mrep, post_rep, timing) in enumerate(
644-
process_pool.imap_unordered(
645-
self.post_replicate_multithread, range(self.n_macroreps)
646-
)
647-
):
648-
self.all_post_replicates[mrep] = post_rep
649-
self.timings[mrep] = timing
650-
self.num_completed = num_completed + 1
630+
results = Parallel()(
631+
delayed(self.post_replicate_multithread)(mrep)
632+
for mrep in range(self.n_macroreps)
633+
)
634+
for mrep, post_rep, timing in results:
635+
self.all_post_replicates[mrep] = post_rep
636+
self.timings[mrep] = timing
651637

652-
# Store estimated objective for each macrorep for each budget.
653-
self.all_est_objectives = [
654-
[
655-
float(np.mean(self.all_post_replicates[mrep][budget_index]))
656-
for budget_index in range(len(self.all_intermediate_budgets[mrep]))
657-
]
658-
for mrep in range(self.n_macroreps)
638+
# Store estimated objective for each macrorep for each budget.
639+
self.all_est_objectives = [
640+
[
641+
float(np.mean(self.all_post_replicates[mrep][budget_index]))
642+
for budget_index in range(len(self.all_intermediate_budgets[mrep]))
659643
]
644+
for mrep in range(self.n_macroreps)
645+
]
660646

661647
runtime = round(time.time() - function_start, 3)
662648
logging.info(

0 commit comments

Comments
 (0)