diff --git a/pyproject.toml b/pyproject.toml index 3e7ec4ad..4a301d1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,9 @@ exclude = ["build*"] [tool.setuptools.package-data] "simopt.data_farming" = ["*.npz"] +[tool.uv.sources] +simopt-extensions = { git = "https://github.com/cenwangumass/simopt-extensions", rev = "2bcf002" } + [project] name = "simoptlib" version = "1.2.2.dev0" @@ -53,6 +56,9 @@ dev = [ ] docs = ["sphinx>=8.2.3", "sphinx-autoapi>=3.6.1", "sphinx-rtd-theme>=3.0.2"] notebooks = ["ipykernel>=7.1.0"] +ext = [ + "simopt-extensions", +] [project.urls] "Homepage" = "https://github.com/simopt-admin/simopt" diff --git a/simopt/models/_ext.py b/simopt/models/_ext.py new file mode 100644 index 00000000..17a751f3 --- /dev/null +++ b/simopt/models/_ext.py @@ -0,0 +1,56 @@ +import importlib +import os +from collections.abc import Callable +from types import ModuleType + +from simopt.model import Model + + +def find_model_patches(module: ModuleType) -> list[Callable]: + patches: list[Callable] = [] + for attr in dir(module): + if not attr.startswith("patch_model"): + continue + candidate = getattr(module, attr) + if callable(candidate): + patches.append(candidate) + return patches + + +def patch(model_class: type[Model], patch_function: Callable) -> None: + # Ask the extension what class to patch and with what method + class_name, method = patch_function() + # Patch the method into the model_class if it matches + full_name = model_class.__module__ + "." + model_class.__qualname__ + if full_name == class_name: + model_class.replicate = method + + +def load_module(module_name: str, model_class: type[Model]) -> None: + # Import the specified library + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise ImportError(f"SimOpt failed to load extension '{module_name}'") from e + + # Find all patch_model* functions in the library + patches = find_model_patches(module) + if not patches: + raise ImportError(f"'{module_name}' does not have any 'patch_model*' functions") + + # Apply each patch to the model_class + for p in patches: + patch(model_class, p) + + +def patch_model(model_class: type[Model]) -> None: + env_var = os.environ.get("SIMOPT_EXT") + if not env_var: + return + + # Assume that the user has specified a comma-separated list of libraries to import. + # For example, SIMOPT_EXT="simopt_extension_a,simopt_extension_b" + for part in env_var.split(","): + module_name = part.strip() + if module_name: + load_module(module_name, model_class) diff --git a/simopt/models/amusementpark.py b/simopt/models/amusementpark.py index 1edc8503..c7a08f95 100644 --- a/simopt/models/amusementpark.py +++ b/simopt/models/amusementpark.py @@ -17,6 +17,7 @@ VariableType, ) from simopt.input_models import Exp, Gamma, WeightedChoice +from simopt.models._ext import patch_model INF = float("inf") @@ -544,3 +545,6 @@ def get_random_solution(self, rand_sol_rng: MRG32k3a) -> tuple: # noqa: D102 n_elements=num_elements, summation=summation, with_zero=False ) return tuple(vector) + + +patch_model(AmusementPark)