Skip to content

Commit 16d207d

Browse files
committed
Implement more flexible extension loading
1 parent 8c1cc8f commit 16d207d

3 files changed

Lines changed: 61 additions & 20 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ exclude = ["build*"]
1212
"simopt.data_farming" = ["*.npz"]
1313

1414
[tool.uv.sources]
15-
simopt-extensions = { git = "https://github.com/cenwangumass/simopt-extensions", rev = "b830a95" }
15+
simopt-extensions = { git = "https://github.com/cenwangumass/simopt-extensions", rev = "2bcf002" }
1616

1717
[project]
1818
name = "simoptlib"

simopt/models/_ext.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import importlib
2+
import os
3+
from collections.abc import Callable
4+
from types import ModuleType
5+
6+
from simopt.model import Model
7+
8+
9+
def find_model_patches(module: ModuleType) -> list[Callable]:
10+
patches: list[Callable] = []
11+
for attr in dir(module):
12+
if not attr.startswith("patch_model"):
13+
continue
14+
candidate = getattr(module, attr)
15+
if callable(candidate):
16+
patches.append(candidate)
17+
return patches
18+
19+
20+
def patch(model_class: type[Model], patch_function: Callable) -> None:
21+
# Ask the extension what class to patch and with what method
22+
class_name, method = patch_function()
23+
# Patch the method into the model_class if it matches
24+
full_name = model_class.__module__ + "." + model_class.__qualname__
25+
if full_name == class_name:
26+
model_class.replicate = method
27+
28+
29+
def load_module(module_name: str, model_class: type[Model]) -> None:
30+
# Import the specified library
31+
try:
32+
module = importlib.import_module(module_name)
33+
except ImportError as e:
34+
raise ImportError(f"SimOpt failed to load extension '{module_name}'") from e
35+
36+
# Find all patch_model* functions in the library
37+
patches = find_model_patches(module)
38+
if not patches:
39+
raise ImportError(f"'{module_name}' does not have any 'patch_model*' functions")
40+
41+
# Apply each patch to the model_class
42+
for p in patches:
43+
patch(model_class, p)
44+
45+
46+
def patch_model(model_class: type[Model]) -> None:
47+
env_var = os.environ.get("SIMOPT_EXT")
48+
if not env_var:
49+
return
50+
51+
# Assume that the user has specified a comma-separated list of libraries to import.
52+
# For example, SIMOPT_EXT="simopt_extension_a,simopt_extension_b"
53+
for part in env_var.split(","):
54+
module_name = part.strip()
55+
if module_name:
56+
load_module(module_name, model_class)

simopt/models/amusementpark.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
VariableType,
1818
)
1919
from simopt.input_models import Exp, Gamma, WeightedChoice
20+
from simopt.models._ext import patch_model
2021

2122
INF = float("inf")
2223

@@ -488,25 +489,6 @@ def set_completion(i: int, new_time: float) -> None:
488489
return responses, {}
489490

490491

491-
def _patch() -> None:
492-
import os
493-
494-
if os.getenv("SIMOPT_EXT") != "1":
495-
return
496-
497-
try:
498-
from simopt_extensions.amusement_park import replicate
499-
except ImportError as e:
500-
raise ImportError(
501-
"SIMOPT_EXT=1 is set but simopt_extensions not installed"
502-
) from e
503-
504-
AmusementPark.replicate = replicate
505-
506-
507-
_patch()
508-
509-
510492
class AmusementParkMinDepart(Problem):
511493
"""Class to make amusement park simulation-optimization problems."""
512494

@@ -563,3 +545,6 @@ def get_random_solution(self, rand_sol_rng: MRG32k3a) -> tuple: # noqa: D102
563545
n_elements=num_elements, summation=summation, with_zero=False
564546
)
565547
return tuple(vector)
548+
549+
550+
patch_model(AmusementPark)

0 commit comments

Comments
 (0)