Skip to content

Commit 3fbd2a6

Browse files
committed
fix: forkserver w/ preloaded modules/registry
1 parent c9d60aa commit 3fbd2a6

13 files changed

Lines changed: 184 additions & 87 deletions

File tree

.coveragerc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
omit =
33
*_tmp.py
44
zetta_utils/log.py
5+
zetta_utils/builder/preload/try_load.py
56
zetta_utils/cli/task_mgmt.py
67
zetta_utils/task_management/subtask_structure.py
78
zetta_utils/task_management/segment.py

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,4 +313,7 @@ warn_unused_ignores = true
313313
known_third_party = "wandb"
314314
profile = "black"
315315
skip = ["specs"]
316-
skip_glob = ["**/__init__.py"]
316+
skip_glob = [
317+
"**/__init__.py",
318+
"zetta_utils/builder/preload/*.py"
319+
]

zetta_utils/__init__.py

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -36,88 +36,21 @@
3636
warnings.filterwarnings("ignore", category=DeprecationWarning)
3737

3838

39-
def _load_core_modules():
40-
"""Load core modules that were previously imported at package level."""
41-
from . import log, typing, parsing, builder, common, constants
42-
from . import geometry, distributions, layer, ng
43-
44-
# Add builder module suppression now that it's loaded
45-
log.add_supress_traceback_module(builder)
46-
47-
48-
def load_all_modules():
49-
_load_core_modules()
50-
load_inference_modules()
51-
load_training_modules()
52-
from . import task_management
39+
def load_all_modules(): # pragma: no cover
40+
import zetta_utils.builder.preload.all
5341

5442

5543
def try_load_train_inference(): # pragma: no cover
56-
try:
57-
_load_core_modules()
58-
except Exception as e: # pylint: disable=broad-exception-caught
59-
logger.exception(e)
60-
61-
try:
62-
load_inference_modules()
63-
64-
except Exception as e: # pylint: disable=broad-exception-caught
65-
logger.exception(e)
66-
67-
try:
68-
load_training_modules()
69-
except Exception as e: # pylint: disable=broad-exception-caught
70-
logger.exception(e)
71-
72-
try:
73-
from . import mazepa_addons
74-
except Exception as e: # pylint: disable=broad-exception-caught
75-
logger.exception(e)
44+
import zetta_utils.builder.preload.try_load
7645

7746

7847
def load_submodules(): # pragma: no cover
7948
from . import internal
8049

8150

82-
def load_inference_modules():
83-
_load_core_modules()
84-
from . import (
85-
augmentations,
86-
convnet,
87-
mazepa,
88-
mazepa_layer_processing,
89-
tensor_ops,
90-
tensor_typing,
91-
tensor_mapping,
92-
)
93-
from .layer import volumetric
94-
from .layer.volumetric import cloudvol
95-
from .message_queues import sqs
96-
97-
from . import mazepa_addons
98-
from . import message_queues
99-
from . import cloud_management
100-
101-
load_submodules()
102-
103-
104-
def load_training_modules():
105-
_load_core_modules()
106-
from . import (
107-
augmentations,
108-
convnet,
109-
mazepa,
110-
tensor_ops,
111-
tensor_typing,
112-
training,
113-
tensor_mapping,
114-
)
115-
from .layer import volumetric, db_layer
116-
from .layer.db_layer import datastore, firestore
117-
from .layer.volumetric import cloudvol
118-
119-
from . import mazepa_addons
120-
from . import message_queues
121-
from . import cloud_management
122-
123-
load_submodules()
51+
def load_inference_modules(): # pragma: no cover
52+
import zetta_utils.builder.preload.inference
53+
54+
55+
def load_training_modules(): # pragma: no cover
56+
import zetta_utils.builder.preload.training

zetta_utils/builder/preload/__init__.py

Whitespace-only changes.

zetta_utils/builder/preload/all.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# pylint: disable=unused-import, wrong-import-position
2+
"""All module imports."""
3+
4+
import time
5+
6+
from zetta_utils import log
7+
8+
_start = time.perf_counter()
9+
10+
from zetta_utils.builder.preload import inference
11+
from zetta_utils.builder.preload import training
12+
from zetta_utils import task_management
13+
14+
_elapsed = time.perf_counter() - _start
15+
log.get_logger("zetta_utils").debug(f"Preload all modules: {_elapsed:.2f}s")

zetta_utils/builder/preload/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# pylint: disable=unused-import, wrong-import-position
2+
"""Core module imports - shared by all load modes."""
3+
4+
import time
5+
6+
_start = time.perf_counter()
7+
8+
from zetta_utils import log, typing, parsing, builder, common, constants
9+
from zetta_utils import geometry, distributions, layer, ng
10+
11+
# Add builder module suppression now that it's loaded
12+
log.add_supress_traceback_module(builder)
13+
14+
_elapsed = time.perf_counter() - _start
15+
log.get_logger("zetta_utils").debug(f"Preload core modules: {_elapsed:.2f}s")

zetta_utils/builder/preload/inference.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# pylint: disable=unused-import, wrong-import-position
2+
"""Inference module imports."""
3+
4+
import time
5+
6+
from zetta_utils import log
7+
8+
_start = time.perf_counter()
9+
10+
# Import core first
11+
from zetta_utils.builder.preload import core
12+
13+
from zetta_utils import (
14+
augmentations,
15+
convnet,
16+
mazepa,
17+
mazepa_layer_processing,
18+
tensor_ops,
19+
tensor_typing,
20+
tensor_mapping,
21+
)
22+
from zetta_utils.layer import volumetric
23+
from zetta_utils.layer.volumetric import cloudvol
24+
from zetta_utils.message_queues import sqs
25+
26+
from zetta_utils import mazepa_addons
27+
from zetta_utils import message_queues
28+
from zetta_utils import cloud_management
29+
30+
from zetta_utils import internal
31+
32+
_elapsed = time.perf_counter() - _start
33+
log.get_logger("zetta_utils").debug(f"Preload inference modules: {_elapsed:.2f}s")

zetta_utils/builder/preload/training.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# pylint: disable=unused-import, wrong-import-position
2+
"""Training module imports."""
3+
4+
import time
5+
6+
from zetta_utils import log
7+
8+
_start = time.perf_counter()
9+
10+
# Import core first
11+
from zetta_utils.builder.preload import core
12+
13+
from zetta_utils import (
14+
augmentations,
15+
convnet,
16+
mazepa,
17+
tensor_ops,
18+
tensor_typing,
19+
training,
20+
tensor_mapping,
21+
)
22+
from zetta_utils.layer import volumetric, db_layer
23+
from zetta_utils.layer.db_layer import datastore, firestore
24+
from zetta_utils.layer.volumetric import cloudvol
25+
26+
from zetta_utils import mazepa_addons
27+
from zetta_utils import message_queues
28+
from zetta_utils import cloud_management
29+
30+
from zetta_utils import internal
31+
32+
_elapsed = time.perf_counter() - _start
33+
log.get_logger("zetta_utils").debug(f"Preload training modules: {_elapsed:.2f}s")

zetta_utils/builder/preload/try_load.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# pylint: disable=unused-import, broad-exception-caught
2+
"""Try-load variant with error handling."""
3+
4+
from zetta_utils import log
5+
6+
logger = log.get_logger("zetta_utils")
7+
8+
try:
9+
from zetta_utils.builder.preload import core
10+
except Exception as e:
11+
logger.exception(e)
12+
13+
try:
14+
from zetta_utils.builder.preload import inference
15+
except Exception as e:
16+
logger.exception(e)
17+
18+
try:
19+
from zetta_utils.builder.preload import training
20+
except Exception as e:
21+
logger.exception(e)
22+
23+
try:
24+
from zetta_utils import mazepa_addons
25+
except Exception as e:
26+
logger.exception(e)

zetta_utils/cli/main.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import multiprocessing
12
import os
23
import pprint
34
import subprocess
45
import sys
6+
import threading
7+
import time
58
from tempfile import NamedTemporaryFile
6-
from typing import Optional
9+
from typing import Literal, Optional, cast
710

811
import click
912

@@ -15,6 +18,35 @@
1518

1619
logger = log.get_logger("zetta_utils")
1720

21+
LoadMode = Literal["all", "inference", "training", "try"]
22+
23+
_PRELOAD_MODULES: dict[LoadMode, str] = {
24+
"all": "zetta_utils.builder.preload.all",
25+
"inference": "zetta_utils.builder.preload.inference",
26+
"training": "zetta_utils.builder.preload.training",
27+
"try": "zetta_utils.builder.preload.try_load",
28+
}
29+
30+
31+
def _noop() -> None:
32+
pass
33+
34+
35+
def initialize_forkserver(load_mode: LoadMode = "all"):
36+
"""Initialize forkserver with preloaded modules for the given load mode."""
37+
preload_module = _PRELOAD_MODULES[load_mode]
38+
logger.info(f"Configuring forkserver with preload module: {preload_module}")
39+
40+
total_start = time.perf_counter()
41+
multiprocessing.set_forkserver_preload([preload_module])
42+
ctx = multiprocessing.get_context("forkserver")
43+
proc = ctx.Process(target=_noop)
44+
proc.start()
45+
proc.join()
46+
47+
total_elapsed = time.perf_counter() - total_start
48+
logger.info(f"Forkserver initialized in {total_elapsed:.2f}s (mode: {load_mode})")
49+
1850

1951
@click.group()
2052
@click.option("-v", "--verbose", count=True, default=2)
@@ -102,19 +134,27 @@ def run(
102134
):
103135
"""Perform ``zetta_utils.builder.build`` action on file contents."""
104136
ctx = click.get_current_context()
105-
load_mode = ctx.obj.get("load_mode", "all") if ctx and ctx.obj else "all"
137+
load_mode = cast(LoadMode, ctx.obj.get("load_mode", "all") if ctx and ctx.obj else "all")
106138

107-
# Load modules first
139+
# Start forkserver init in background while main process loads modules
140+
forkserver_thread = threading.Thread(
141+
target=initialize_forkserver, args=(load_mode,), name="forkserver_init"
142+
)
143+
forkserver_thread.start()
144+
145+
# Load modules in main process (runs in parallel with forkserver init)
108146
if load_mode == "all":
109147
zetta_utils.load_all_modules()
110148
elif load_mode == "inference": # pragma: no cover
111149
zetta_utils.load_inference_modules()
112150
elif load_mode == "try": # pragma: no cover
113151
zetta_utils.try_load_train_inference()
114152
else: # pragma: no cover
115-
assert load_mode == "training"
116153
zetta_utils.load_training_modules()
117154

155+
# Wait for forkserver to be ready before proceeding
156+
forkserver_thread.join()
157+
118158
from zetta_utils import builder, parsing # pylint: disable=import-outside-toplevel
119159
from zetta_utils.run import ( # pylint: disable=import-outside-toplevel
120160
run_ctx_manager,

0 commit comments

Comments
 (0)