Skip to content

Commit d28a518

Browse files
committed
test(problems): limit number of iterations in optimize
The number can be overridden by using the environment variable ENGIBENCH_MAX_ITER
1 parent 2f961bf commit d28a518

1 file changed

Lines changed: 14 additions & 7 deletions

File tree

tests/test_problem_implementations.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import dataclasses
44
import inspect
5+
import os
56
import sys
6-
from typing import get_args, get_origin
7+
from typing import Any, get_args, get_origin
78

89
import gymnasium
910
from gymnasium import spaces
@@ -88,6 +89,11 @@ def test_problem_impl(problem_class: type[Problem]) -> None:
8889
print(f"Done testing {problem_class.__name__}.")
8990

9091

92+
def problem_id(problem_class: type[Problem]) -> str:
93+
id_, _ = problem_class.__module__.removeprefix("engibench.problems.").split(".", 1)
94+
return id_
95+
96+
9197
@pytest.mark.parametrize(
9298
"problem_class",
9399
PYTHON_PROBLEMS + (CONTAINER_PROBLEMS if sys.platform.startswith("linux") else []),
@@ -121,15 +127,16 @@ def test_python_problem_impl(problem_class: type[Problem]) -> None:
121127
# Test optimization outputs
122128
print(f"Optimizing {problem_class.__name__}...")
123129
# Skip optimization test for power electronics, airfoil, and heat conduction problems
124-
if (
125-
problem_class.__module__.startswith("engibench.problems.power_electronics")
126-
or problem_class.__module__.startswith("engibench.problems.airfoil")
127-
or problem_class.__module__.startswith("engibench.problems.heatconduction")
128-
):
130+
if problem_id(problem_class) in {"power_electronics", "airfoil", "heatconduction"}:
129131
print(f"Skipping optimization test for {problem_class.__name__}")
130132
return
131133
problem.reset(seed=1)
132-
optimal_design, history = problem.optimize(starting_point=design)
134+
default_max_iter = 20
135+
max_iter = os.environ.get("ENGIBENCH_MAX_ITER", default_max_iter)
136+
max_iter_config: dict[str, Any] = {
137+
key: max_iter for key in ("max_iter", "num_optimization_steps") if hasattr(problem_class.Config, key)
138+
}
139+
optimal_design, history = problem.optimize(starting_point=design, config=max_iter_config)
133140
if isinstance(problem.design_space, spaces.Box):
134141
assert np.all(optimal_design >= problem.design_space.low), (
135142
f"Problem {problem_class.__name__}: The optimal design should be within the design space."

0 commit comments

Comments
 (0)