diff --git a/fastcore/parallel.py b/fastcore/parallel.py
index 75a0035d..c73478ea 100644
--- a/fastcore/parallel.py
+++ b/fastcore/parallel.py
@@ -30,7 +30,11 @@ def g(_obj_td, *args, **kwargs):
             _obj_td.result = res
         @wraps(f)
         def _f(*args, **kwargs):
-            res = (Thread,Process)[process](target=g, args=args, kwargs=kwargs)
+            if process:
+                Proc = get_context('fork').Process if sys.platform == 'darwin' else Process
+            else:
+                Proc = Thread
+            res = Proc(target=g, args=args, kwargs=kwargs)
             res._args = (res,)+res._args
             res.start()
             return res
@@ -123,7 +127,9 @@ def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None
     kwpool = {}
     if threadpool: pool = ThreadPoolExecutor
     else:
-        if not method and sys.platform == 'darwin': method='fork'
+        if not method and sys.platform == 'darwin':
+            # Use fork only if function is defined in __main__ (notebooks/REPL), otherwise use spawn
+            method = 'fork' if getattr(f, '__module__', None) == '__main__' else 'spawn'
         if method: kwpool['mp_context'] = get_context(method)
         pool = ProcessPoolExecutor
     with pool(n_workers, pause=pause, **kwpool) as ex:
@@ -158,7 +164,8 @@ async def limited_task(item):
 # %% ../nbs/03a_parallel.ipynb
 def run_procs(f, f_done, args):
     "Call `f` for each item in `args` in parallel, yielding `f_done`"
-    processes = L(args).map(Process, args=arg0, target=f)
+    Proc = get_context('fork').Process if sys.platform == 'darwin' else Process
+    processes = L(args).map(Proc, args=arg0, target=f)
     for o in processes: o.start()
     yield from f_done()
     processes.map(Self.join())
diff --git a/nbs/03a_parallel.ipynb b/nbs/03a_parallel.ipynb
index 285e528f..e894abc4 100644
--- a/nbs/03a_parallel.ipynb
+++ b/nbs/03a_parallel.ipynb
@@ -66,7 +66,11 @@
     "            _obj_td.result = res\n",
     "        @wraps(f)\n",
     "        def _f(*args, **kwargs):\n",
-    "            res = (Thread,Process)[process](target=g, args=args, kwargs=kwargs)\n",
+    "            if process:\n",
+    "                Proc = get_context('fork').Process if sys.platform == 'darwin' else Process\n",
+    "            else:\n",
+    "                Proc = Thread\n",
+    "            res = Proc(target=g, args=args, kwargs=kwargs)\n",
     "            res._args = (res,)+res._args\n",
     "            res.start()\n",
     "            return res\n",
@@ -414,7 +418,9 @@
     "    kwpool = {}\n",
     "    if threadpool: pool = ThreadPoolExecutor\n",
     "    else:\n",
-    "        if not method and sys.platform == 'darwin': method='fork'\n",
+    "        if not method and sys.platform == 'darwin':\n",
+    "            # Use fork only if function is defined in __main__ (notebooks/REPL), otherwise use spawn\n",
+    "            method = 'fork' if getattr(f, '__module__', None) == '__main__' else 'spawn'\n",
     "        if method: kwpool['mp_context'] = get_context(method)\n",
     "        pool = ProcessPoolExecutor\n",
     "    with pool(n_workers, pause=pause, **kwpool) as ex:\n",
@@ -587,7 +593,8 @@
     "#| export\n",
     "def run_procs(f, f_done, args):\n",
     "    \"Call `f` for each item in `args` in parallel, yielding `f_done`\"\n",
-    "    processes = L(args).map(Process, args=arg0, target=f)\n",
+    "    Proc = get_context('fork').Process if sys.platform == 'darwin' else Process\n",
+    "    processes = L(args).map(Proc, args=arg0, target=f)\n",
     "    for o in processes: o.start()\n",
     "    yield from f_done()\n",
     "    processes.map(Self.join())"
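
Below is a minimal sketch (not part of the patch) of how the revised start-method selection could be exercised on macOS; the `slow_square` helper is hypothetical, and the behaviour described assumes the patched `parallel` above.

    # Sketch: exercising the patched start-method selection on macOS.
    # `slow_square` is a hypothetical helper, not part of the patch.
    import sys
    from fastcore.parallel import parallel

    def slow_square(x): return x * x   # lives in __main__ when run as a script/notebook

    if __name__ == '__main__':
        # On darwin the patched `parallel` picks 'fork' here, because
        # slow_square.__module__ == '__main__'; a function imported from an
        # installed module would fall back to 'spawn' instead.
        print(parallel(slow_square, range(5), n_workers=2))
        # The caller can still force a start method explicitly via `method`:
        print(parallel(slow_square, range(5), n_workers=2, method='spawn'))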