Skip to content

Commit 5985c4b

Browse files
committed
test(problems): limit number of iterations in optimize
The number can be overridden by using the environment variable ENGIBENCH_MAX_ITER
1 parent bd15511 commit 5985c4b

1 file changed

Lines changed: 21 additions & 14 deletions

File tree

tests/test_problem_implementations.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import dataclasses
44
import inspect
5+
import os
56
import sys
6-
from typing import get_args, get_origin
7+
from typing import Any, get_args, get_origin
78

89
import gymnasium
910
from gymnasium import spaces
@@ -13,9 +14,6 @@
1314
from engibench import Problem
1415
from engibench.utils.all_problems import BUILTIN_PROBLEMS
1516

16-
PYTHON_PROBLEMS = [p for p in BUILTIN_PROBLEMS.values() if p.container_id is None]
17-
CONTAINER_PROBLEMS = [p for p in BUILTIN_PROBLEMS.values() if p.container_id is not None]
18-
1917

2018
@pytest.mark.parametrize("problem_class", BUILTIN_PROBLEMS.values())
2119
def test_problem_impl(problem_class: type[Problem]) -> None:
@@ -88,10 +86,12 @@ def test_problem_impl(problem_class: type[Problem]) -> None:
8886
print(f"Done testing {problem_class.__name__}.")
8987

9088

91-
@pytest.mark.parametrize(
92-
"problem_class",
93-
PYTHON_PROBLEMS + (CONTAINER_PROBLEMS if sys.platform.startswith("linux") else []),
94-
)
89+
def problem_id(problem_class: type[Problem]) -> str:
90+
id_, _ = problem_class.__module__.removeprefix("engibench.problems.").split(".", 1)
91+
return id_
92+
93+
94+
@pytest.mark.parametrize("problem_class", BUILTIN_PROBLEMS.values())
9595
def test_python_problem_impl(problem_class: type[Problem]) -> None:
9696
"""Check that all problems defined in Python files respect the API.
9797
@@ -100,6 +100,8 @@ def test_python_problem_impl(problem_class: type[Problem]) -> None:
100100
2. The optimization produces valid designs within the design space
101101
3. The optimization history contains valid objective values
102102
"""
103+
if problem_class.container_id is not None and not sys.platform.startswith("linux"):
104+
pytest.skip(f"Skipping containerized problem {problem_class.__name__} on non-linux platform")
103105
print(f"Testing optimization and simulation for {problem_class.__name__}...")
104106
# Initialize problem and get a random design
105107
problem = problem_class(seed=1)
@@ -121,15 +123,20 @@ def test_python_problem_impl(problem_class: type[Problem]) -> None:
121123
# Test optimization outputs
122124
print(f"Optimizing {problem_class.__name__}...")
123125
# Skip optimization test for power electronics, airfoil, and heat conduction problems
124-
if (
125-
problem_class.__module__.startswith("engibench.problems.power_electronics")
126-
or problem_class.__module__.startswith("engibench.problems.airfoil")
127-
or problem_class.__module__.startswith("engibench.problems.heatconduction")
128-
):
126+
if problem_id(problem_class) == "airfoil":
129127
print(f"Skipping optimization test for {problem_class.__name__}")
130128
return
131129
problem.reset(seed=1)
132-
optimal_design, history = problem.optimize(starting_point=design)
130+
default_max_iter = 10
131+
max_iter = os.environ.get("ENGIBENCH_MAX_ITER", default_max_iter)
132+
max_iter_config: dict[str, Any] = {
133+
key: max_iter for key in ("max_iter", "num_optimization_steps") if hasattr(problem_class.Config, key)
134+
}
135+
try:
136+
optimal_design, history = problem.optimize(starting_point=design, config=max_iter_config)
137+
except NotImplementedError:
138+
print("Problem class {problem_class.__name__} does not implement optimize - Skipping optimize")
139+
return
133140
if isinstance(problem.design_space, spaces.Box):
134141
assert np.all(optimal_design >= problem.design_space.low), (
135142
f"Problem {problem_class.__name__}: The optimal design should be within the design space."

0 commit comments

Comments
 (0)