22
33import dataclasses
44import inspect
5+ import os
56import sys
6- from typing import get_args , get_origin
7+ from typing import Any , get_args , get_origin
78
89import gymnasium
910from gymnasium import spaces
1314from engibench import Problem
1415from engibench .utils .all_problems import BUILTIN_PROBLEMS
1516
16- PYTHON_PROBLEMS = [p for p in BUILTIN_PROBLEMS .values () if p .container_id is None ]
17- CONTAINER_PROBLEMS = [p for p in BUILTIN_PROBLEMS .values () if p .container_id is not None ]
18-
1917
2018@pytest .mark .parametrize ("problem_class" , BUILTIN_PROBLEMS .values ())
2119def test_problem_impl (problem_class : type [Problem ]) -> None :
@@ -88,10 +86,12 @@ def test_problem_impl(problem_class: type[Problem]) -> None:
8886 print (f"Done testing { problem_class .__name__ } ." )
8987
9088
91- @pytest .mark .parametrize (
92- "problem_class" ,
93- PYTHON_PROBLEMS + (CONTAINER_PROBLEMS if sys .platform .startswith ("linux" ) else []),
94- )
89+ def problem_id (problem_class : type [Problem ]) -> str :
90+ id_ , _ = problem_class .__module__ .removeprefix ("engibench.problems." ).split ("." , 1 )
91+ return id_
92+
93+
94+ @pytest .mark .parametrize ("problem_class" , BUILTIN_PROBLEMS .values ())
9595def test_python_problem_impl (problem_class : type [Problem ]) -> None :
9696 """Check that all problems defined in Python files respect the API.
9797
@@ -100,6 +100,8 @@ def test_python_problem_impl(problem_class: type[Problem]) -> None:
100100 2. The optimization produces valid designs within the design space
101101 3. The optimization history contains valid objective values
102102 """
103+ if problem_class .container_id is not None and not sys .platform .startswith ("linux" ):
104+ pytest .skip (f"Skipping containerized problem { problem_class .__name__ } on non-linux platform" )
103105 print (f"Testing optimization and simulation for { problem_class .__name__ } ..." )
104106 # Initialize problem and get a random design
105107 problem = problem_class (seed = 1 )
@@ -121,15 +123,20 @@ def test_python_problem_impl(problem_class: type[Problem]) -> None:
121123 # Test optimization outputs
122124 print (f"Optimizing { problem_class .__name__ } ..." )
123125 # Skip optimization test for power electronics, airfoil, and heat conduction problems
124- if (
125- problem_class .__module__ .startswith ("engibench.problems.power_electronics" )
126- or problem_class .__module__ .startswith ("engibench.problems.airfoil" )
127- or problem_class .__module__ .startswith ("engibench.problems.heatconduction" )
128- ):
126+ if problem_id (problem_class ) == "airfoil" :
129127 print (f"Skipping optimization test for { problem_class .__name__ } " )
130128 return
131129 problem .reset (seed = 1 )
132- optimal_design , history = problem .optimize (starting_point = design )
130+ default_max_iter = 10
131+ max_iter = os .environ .get ("ENGIBENCH_MAX_ITER" , default_max_iter )
132+ max_iter_config : dict [str , Any ] = {
133+ key : max_iter for key in ("max_iter" , "num_optimization_steps" ) if hasattr (problem_class .Config , key )
134+ }
135+ try :
136+ optimal_design , history = problem .optimize (starting_point = design , config = max_iter_config )
137+ except NotImplementedError :
138+ print ("Problem class {problem_class.__name__} does not implement optimize - Skipping optimize" )
139+ return
133140 if isinstance (problem .design_space , spaces .Box ):
134141 assert np .all (optimal_design >= problem .design_space .low ), (
135142 f"Problem { problem_class .__name__ } : The optimal design should be within the design space."
0 commit comments