|
2 | 2 |
|
3 | 3 | import dataclasses |
4 | 4 | import inspect |
| 5 | +import os |
5 | 6 | import sys |
6 | | -from typing import get_args, get_origin |
| 7 | +from typing import Any, get_args, get_origin |
7 | 8 |
|
8 | 9 | import gymnasium |
9 | 10 | from gymnasium import spaces |
@@ -88,6 +89,11 @@ def test_problem_impl(problem_class: type[Problem]) -> None: |
88 | 89 | print(f"Done testing {problem_class.__name__}.") |
89 | 90 |
|
90 | 91 |
|
| 92 | +def problem_id(problem_class: type[Problem]) -> str: |
| 93 | + id_, _ = problem_class.__module__.removeprefix("engibench.problems.").split(".", 1) |
| 94 | + return id_ |
| 95 | + |
| 96 | + |
91 | 97 | @pytest.mark.parametrize( |
92 | 98 | "problem_class", |
93 | 99 | PYTHON_PROBLEMS + (CONTAINER_PROBLEMS if sys.platform.startswith("linux") else []), |
@@ -121,15 +127,16 @@ def test_python_problem_impl(problem_class: type[Problem]) -> None: |
121 | 127 | # Test optimization outputs |
122 | 128 | print(f"Optimizing {problem_class.__name__}...") |
123 | 129 | # Skip optimization test for power electronics, airfoil, and heat conduction problems |
124 | | - if ( |
125 | | - problem_class.__module__.startswith("engibench.problems.power_electronics") |
126 | | - or problem_class.__module__.startswith("engibench.problems.airfoil") |
127 | | - or problem_class.__module__.startswith("engibench.problems.heatconduction") |
128 | | - ): |
| 130 | + if problem_id(problem_class) in {"power_electronics", "airfoil", "heatconduction"}: |
129 | 131 | print(f"Skipping optimization test for {problem_class.__name__}") |
130 | 132 | return |
131 | 133 | problem.reset(seed=1) |
132 | | - optimal_design, history = problem.optimize(starting_point=design) |
| 134 | + default_max_iter = 20 |
| 135 | + max_iter = os.environ.get("ENGIBENCH_MAX_ITER", default_max_iter) |
| 136 | + max_iter_config: dict[str, Any] = { |
| 137 | + key: max_iter for key in ("max_iter", "num_optimization_steps") if hasattr(problem_class.Config, key) |
| 138 | + } |
| 139 | + optimal_design, history = problem.optimize(starting_point=design, config=max_iter_config) |
133 | 140 | if isinstance(problem.design_space, spaces.Box): |
134 | 141 | assert np.all(optimal_design >= problem.design_space.low), ( |
135 | 142 | f"Problem {problem_class.__name__}: The optimal design should be within the design space." |
|
0 commit comments