|
| 1 | +import multiprocessing |
1 | 2 | import os |
2 | 3 | import pprint |
3 | 4 | import subprocess |
4 | 5 | import sys |
| 6 | +import threading |
| 7 | +import time |
5 | 8 | from tempfile import NamedTemporaryFile |
6 | | -from typing import Optional |
| 9 | +from typing import Literal, Optional, cast |
7 | 10 |
|
8 | 11 | import click |
9 | 12 |
|
|
15 | 18 |
|
16 | 19 | logger = log.get_logger("zetta_utils") |
17 | 20 |
|
# Closed set of mode names accepted wherever a load mode is selected.
LoadMode = Literal["all", "inference", "training", "try"]

# Package that holds one preload module per load mode.
_PRELOAD_PACKAGE = "zetta_utils.builder.preload"

# Dotted path of the module handed to the forkserver preload for each mode.
# Note that the "try" mode maps to ``try_load`` rather than to its own name.
_PRELOAD_MODULES: dict[LoadMode, str] = {
    "all": f"{_PRELOAD_PACKAGE}.all",
    "inference": f"{_PRELOAD_PACKAGE}.inference",
    "training": f"{_PRELOAD_PACKAGE}.training",
    "try": f"{_PRELOAD_PACKAGE}.try_load",
}
| 29 | + |
| 30 | + |
def _noop() -> None:
    """Do nothing; serves as the target of the forkserver warm-up process."""
    return None
| 33 | + |
| 34 | + |
def initialize_forkserver(load_mode: LoadMode = "all") -> None:
    """Initialize forkserver with preloaded modules for the given load mode.

    Registers the module that ``_PRELOAD_MODULES`` maps ``load_mode`` to as the
    forkserver preload, then runs a no-op process through the ``forkserver``
    context and waits for it to finish. The elapsed wall-clock time is logged.

    :param load_mode: Key into ``_PRELOAD_MODULES`` selecting the preload module.
    :raises KeyError: If ``load_mode`` is not a key of ``_PRELOAD_MODULES``.
    """
    preload_module = _PRELOAD_MODULES[load_mode]
    logger.info(f"Configuring forkserver with preload module: {preload_module}")

    total_start = time.perf_counter()
    # Per the multiprocessing docs, the preload list only takes effect if it is
    # set before the forkserver process has been launched.
    multiprocessing.set_forkserver_preload([preload_module])
    ctx = multiprocessing.get_context("forkserver")

    # Running a throwaway process makes the forkserver start now instead of at
    # the first real process launch.
    proc = ctx.Process(target=_noop)
    proc.start()
    proc.join()

    total_elapsed = time.perf_counter() - total_start

    # After join() without a timeout the exit code is set; anything other than
    # 0 means the warm-up process did not finish cleanly, so do not report
    # success. This only logs: callers are never interrupted by a failed warm-up.
    if proc.exitcode != 0:
        logger.warning(
            f"Forkserver warm-up process exited with code {proc.exitcode} "
            f"after {total_elapsed:.2f}s (mode: {load_mode})"
        )
        return

    logger.info(f"Forkserver initialized in {total_elapsed:.2f}s (mode: {load_mode})")
| 49 | + |
18 | 50 |
|
19 | 51 | @click.group() |
20 | 52 | @click.option("-v", "--verbose", count=True, default=2) |
@@ -102,19 +134,27 @@ def run( |
102 | 134 | ): |
103 | 135 | """Perform ``zetta_utils.builder.build`` action on file contents.""" |
104 | 136 | ctx = click.get_current_context() |
105 | | - load_mode = ctx.obj.get("load_mode", "all") if ctx and ctx.obj else "all" |
| 137 | + load_mode = cast(LoadMode, ctx.obj.get("load_mode", "all") if ctx and ctx.obj else "all") |
106 | 138 |
|
107 | | - # Load modules first |
| 139 | + # Start forkserver init in background while main process loads modules |
| 140 | + forkserver_thread = threading.Thread( |
| 141 | + target=initialize_forkserver, args=(load_mode,), name="forkserver_init" |
| 142 | + ) |
| 143 | + forkserver_thread.start() |
| 144 | + |
| 145 | + # Load modules in main process (runs in parallel with forkserver init) |
108 | 146 | if load_mode == "all": |
109 | 147 | zetta_utils.load_all_modules() |
110 | 148 | elif load_mode == "inference": # pragma: no cover |
111 | 149 | zetta_utils.load_inference_modules() |
112 | 150 | elif load_mode == "try": # pragma: no cover |
113 | 151 | zetta_utils.try_load_train_inference() |
114 | 152 | else: # pragma: no cover |
115 | | - assert load_mode == "training" |
116 | 153 | zetta_utils.load_training_modules() |
117 | 154 |
|
| 155 | + # Wait for forkserver to be ready before proceeding |
| 156 | + forkserver_thread.join() |
| 157 | + |
118 | 158 | from zetta_utils import builder, parsing # pylint: disable=import-outside-toplevel |
119 | 159 | from zetta_utils.run import ( # pylint: disable=import-outside-toplevel |
120 | 160 | run_ctx_manager, |
|
0 commit comments