55import importlib
66import itertools
77import logging
8- import os
98import pickle
109import subprocess
1110import time
1615import matplotlib .pyplot as plt
1716import numpy as np
1817import pandas as pd
18+ from joblib import Parallel , delayed
1919
2020import simopt .curve_utils as curve_utils
2121import simopt .directory as directory
@@ -444,7 +444,6 @@ def run(self, n_macroreps: int) -> None:
444444 """
445445 # Local Imports
446446 from functools import partial
447- from multiprocessing import Pool
448447
449448 # Value checking
450449 if n_macroreps <= 0 :
@@ -480,27 +479,20 @@ def run(self, n_macroreps: int) -> None:
480479
481480 logging .debug ("Starting macroreplications" )
482481
483- num_processes = min (n_macroreps , os .cpu_count () or 1 )
484- with Pool (num_processes ) as process_pool :
485- # Start the macroreplications in parallel (async)
486- run_multithread_partial = partial (
487- self .run_multithread , solver = self .solver , problem = self .problem
488- )
489- for num_completed , (
490- mrep ,
491- recommended_xs ,
492- intermediate_budgets ,
493- timing ,
494- ) in enumerate (
495- process_pool .imap_unordered (run_multithread_partial , range (n_macroreps ))
496- ):
497- self .all_recommended_xs [mrep ] = recommended_xs
498- self .all_intermediate_budgets [mrep ] = intermediate_budgets
499- self .timings [mrep ] = timing
500- self .num_completed = num_completed + 1
501-
502- runtime = round (time .time () - function_start , 3 )
503- logging .info (f"Finished running { n_macroreps } mreps in { runtime } seconds." )
482+ # Start the macroreplications in parallel (async)
483+ run_multithread_partial = partial (
484+ self .run_multithread , solver = self .solver , problem = self .problem
485+ )
486+ results = Parallel ()(
487+ delayed (run_multithread_partial )(i ) for i in range (n_macroreps )
488+ )
489+ for mrep , recommended_xs , intermediate_budgets , timing in results :
490+ self .all_recommended_xs [mrep ] = recommended_xs
491+ self .all_intermediate_budgets [mrep ] = intermediate_budgets
492+ self .timings [mrep ] = timing
493+
494+ runtime = round (time .time () - function_start , 3 )
495+ logging .info (f"Finished running { n_macroreps } mreps in { runtime } seconds." )
504496
505497 self .has_run = True
506498 self .has_postreplicated = False
@@ -611,9 +603,6 @@ def post_replicate(
611603 Raises:
612604 ValueError: If `n_postreps` is not positive.
613605 """
614- # Local Imports
615- from multiprocessing import Pool
616-
617606 # Value checking
618607 if n_postreps <= 0 :
619608 error_msg = "Number of postreplications must be positive."
@@ -638,25 +627,22 @@ def post_replicate(
638627 function_start = time .time ()
639628
640629 logging .info ("Starting postreplications" )
641- num_processes = min (self .n_macroreps , os .cpu_count () or 1 )
642- with Pool (num_processes ) as process_pool :
643- for num_completed , (mrep , post_rep , timing ) in enumerate (
644- process_pool .imap_unordered (
645- self .post_replicate_multithread , range (self .n_macroreps )
646- )
647- ):
648- self .all_post_replicates [mrep ] = post_rep
649- self .timings [mrep ] = timing
650- self .num_completed = num_completed + 1
630+ results = Parallel ()(
631+ delayed (self .post_replicate_multithread )(mrep )
632+ for mrep in range (self .n_macroreps )
633+ )
634+ for mrep , post_rep , timing in results :
635+ self .all_post_replicates [mrep ] = post_rep
636+ self .timings [mrep ] = timing
651637
652- # Store estimated objective for each macrorep for each budget.
653- self .all_est_objectives = [
654- [
655- float (np .mean (self .all_post_replicates [mrep ][budget_index ]))
656- for budget_index in range (len (self .all_intermediate_budgets [mrep ]))
657- ]
658- for mrep in range (self .n_macroreps )
638+ # Store estimated objective for each macrorep for each budget.
639+ self .all_est_objectives = [
640+ [
641+ float (np .mean (self .all_post_replicates [mrep ][budget_index ]))
642+ for budget_index in range (len (self .all_intermediate_budgets [mrep ]))
659643 ]
644+ for mrep in range (self .n_macroreps )
645+ ]
660646
661647 runtime = round (time .time () - function_start , 3 )
662648 logging .info (
0 commit comments