Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ exclude = ["build*"]
[tool.setuptools.package-data]
"simopt.data_farming" = ["*.npz"]

[tool.uv.sources]
simopt-extensions = { git = "https://github.com/cenwangumass/simopt-extensions", rev = "2bcf002" }

[project]
name = "simoptlib"
version = "1.2.2.dev0"
Expand Down Expand Up @@ -53,6 +56,9 @@ dev = [
]
docs = ["sphinx>=8.2.3", "sphinx-autoapi>=3.6.1", "sphinx-rtd-theme>=3.0.2"]
notebooks = ["ipykernel>=7.1.0"]
ext = [
"simopt-extensions",
]

[project.urls]
"Homepage" = "https://github.com/simopt-admin/simopt"
Expand Down
56 changes: 56 additions & 0 deletions simopt/models/_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import importlib
import os
from collections.abc import Callable
from types import ModuleType

from simopt.model import Model


def find_model_patches(module: ModuleType) -> list[Callable]:
patches: list[Callable] = []
for attr in dir(module):
if not attr.startswith("patch_model"):
continue
candidate = getattr(module, attr)
if callable(candidate):
patches.append(candidate)
return patches


def patch(model_class: type[Model], patch_function: Callable) -> None:
# Ask the extension what class to patch and with what method
class_name, method = patch_function()
# Patch the method into the model_class if it matches
full_name = model_class.__module__ + "." + model_class.__qualname__
if full_name == class_name:
model_class.replicate = method


def load_module(module_name: str, model_class: type[Model]) -> None:
# Import the specified library
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise ImportError(f"SimOpt failed to load extension '{module_name}'") from e

# Find all patch_model* functions in the library
patches = find_model_patches(module)
if not patches:
raise ImportError(f"'{module_name}' does not have any 'patch_model*' functions")

# Apply each patch to the model_class
for p in patches:
patch(model_class, p)


def patch_model(model_class: type[Model]) -> None:
env_var = os.environ.get("SIMOPT_EXT")
if not env_var:
return

# Assume that the user has specified a comma-separated list of libraries to import.
# For example, SIMOPT_EXT="simopt_extension_a,simopt_extension_b"
for part in env_var.split(","):
module_name = part.strip()
if module_name:
load_module(module_name, model_class)
4 changes: 4 additions & 0 deletions simopt/models/amusementpark.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
VariableType,
)
from simopt.input_models import Exp, Gamma, WeightedChoice
from simopt.models._ext import patch_model

INF = float("inf")

Expand Down Expand Up @@ -544,3 +545,6 @@ def get_random_solution(self, rand_sol_rng: MRG32k3a) -> tuple: # noqa: D102
n_elements=num_elements, summation=summation, with_zero=False
)
return tuple(vector)


patch_model(AmusementPark)