From 8c1cc8fd9a8f1991f7e054d102a0a6a5675e8dca Mon Sep 17 00:00:00 2001 From: Cen Wang Date: Fri, 26 Sep 2025 09:27:48 -0400 Subject: [PATCH 1/2] Enable Rust extension via environment variable SIMOPT_EXT --- pyproject.toml | 6 ++++++ simopt/models/amusementpark.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 3e7ec4ad..14ff055b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,9 @@ exclude = ["build*"] [tool.setuptools.package-data] "simopt.data_farming" = ["*.npz"] +[tool.uv.sources] +simopt-extensions = { git = "https://github.com/cenwangumass/simopt-extensions", rev = "b830a95" } + [project] name = "simoptlib" version = "1.2.2.dev0" @@ -53,6 +56,9 @@ dev = [ ] docs = ["sphinx>=8.2.3", "sphinx-autoapi>=3.6.1", "sphinx-rtd-theme>=3.0.2"] notebooks = ["ipykernel>=7.1.0"] +ext = [ + "simopt-extensions", +] [project.urls] "Homepage" = "https://github.com/simopt-admin/simopt" diff --git a/simopt/models/amusementpark.py b/simopt/models/amusementpark.py index 1edc8503..00e318b5 100644 --- a/simopt/models/amusementpark.py +++ b/simopt/models/amusementpark.py @@ -488,6 +488,25 @@ def set_completion(i: int, new_time: float) -> None: return responses, {} +def _patch() -> None: + import os + + if os.getenv("SIMOPT_EXT") != "1": + return + + try: + from simopt_extensions.amusement_park import replicate + except ImportError as e: + raise ImportError( + "SIMOPT_EXT=1 is set but simopt_extensions not installed" + ) from e + + AmusementPark.replicate = replicate + + +_patch() + + class AmusementParkMinDepart(Problem): """Class to make amusement park simulation-optimization problems.""" From 16d207dbdd33993d686f2090ebbf03c27a85cfde Mon Sep 17 00:00:00 2001 From: Cen Wang Date: Thu, 27 Nov 2025 19:48:45 -0500 Subject: [PATCH 2/2] Implement more flexible extension loading --- pyproject.toml | 2 +- simopt/models/_ext.py | 56 ++++++++++++++++++++++++++++++++++ simopt/models/amusementpark.py | 23 +++----------- 3 files changed, 61 insertions(+), 20 deletions(-) create mode 100644 simopt/models/_ext.py diff --git a/pyproject.toml b/pyproject.toml index 14ff055b..4a301d1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ exclude = ["build*"] "simopt.data_farming" = ["*.npz"] [tool.uv.sources] -simopt-extensions = { git = "https://github.com/cenwangumass/simopt-extensions", rev = "b830a95" } +simopt-extensions = { git = "https://github.com/cenwangumass/simopt-extensions", rev = "2bcf002" } [project] name = "simoptlib" diff --git a/simopt/models/_ext.py b/simopt/models/_ext.py new file mode 100644 index 00000000..17a751f3 --- /dev/null +++ b/simopt/models/_ext.py @@ -0,0 +1,56 @@ +import importlib +import os +from collections.abc import Callable +from types import ModuleType + +from simopt.model import Model + + +def find_model_patches(module: ModuleType) -> list[Callable]: + patches: list[Callable] = [] + for attr in dir(module): + if not attr.startswith("patch_model"): + continue + candidate = getattr(module, attr) + if callable(candidate): + patches.append(candidate) + return patches + + +def patch(model_class: type[Model], patch_function: Callable) -> None: + # Ask the extension what class to patch and with what method + class_name, method = patch_function() + # Patch the method into the model_class if it matches + full_name = model_class.__module__ + "." + model_class.__qualname__ + if full_name == class_name: + model_class.replicate = method + + +def load_module(module_name: str, model_class: type[Model]) -> None: + # Import the specified library + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise ImportError(f"SimOpt failed to load extension '{module_name}'") from e + + # Find all patch_model* functions in the library + patches = find_model_patches(module) + if not patches: + raise ImportError(f"'{module_name}' does not have any 'patch_model*' functions") + + # Apply each patch to the model_class + for p in patches: + patch(model_class, p) + + +def patch_model(model_class: type[Model]) -> None: + env_var = os.environ.get("SIMOPT_EXT") + if not env_var: + return + + # Assume that the user has specified a comma-separated list of libraries to import. + # For example, SIMOPT_EXT="simopt_extension_a,simopt_extension_b" + for part in env_var.split(","): + module_name = part.strip() + if module_name: + load_module(module_name, model_class) diff --git a/simopt/models/amusementpark.py b/simopt/models/amusementpark.py index 00e318b5..c7a08f95 100644 --- a/simopt/models/amusementpark.py +++ b/simopt/models/amusementpark.py @@ -17,6 +17,7 @@ VariableType, ) from simopt.input_models import Exp, Gamma, WeightedChoice +from simopt.models._ext import patch_model INF = float("inf") @@ -488,25 +489,6 @@ def set_completion(i: int, new_time: float) -> None: return responses, {} -def _patch() -> None: - import os - - if os.getenv("SIMOPT_EXT") != "1": - return - - try: - from simopt_extensions.amusement_park import replicate - except ImportError as e: - raise ImportError( - "SIMOPT_EXT=1 is set but simopt_extensions not installed" - ) from e - - AmusementPark.replicate = replicate - - -_patch() - - class AmusementParkMinDepart(Problem): """Class to make amusement park simulation-optimization problems.""" @@ -563,3 +545,6 @@ def get_random_solution(self, rand_sol_rng: MRG32k3a) -> tuple: # noqa: D102 n_elements=num_elements, summation=summation, with_zero=False ) return tuple(vector) + + +patch_model(AmusementPark)