1616import gymnasium as gym
1717from gymnasium import spaces
1818
19- from das .env .observation import (
20- compute_observation ,
21- observation_dim ,
22- compute_ela_features ,
23- MAX_HISTORY_SAMPLE ,
24- ELA_DIM ,
25- )
19+ from das .env .observation import (compute_observation , observation_dim , MAX_HISTORY_SAMPLE )
2620from das .env .reward import compute_reward
2721from das .optimizers .base import get_checkpoints
2822
29- # Recompute ELA every ~500 new population samples. pflacco runs regression,
30- # nearest-neighbour search, and IC calculations on every call — running it
31- # every step would dominate wall-clock time for long training runs.
32- _ELA_RECOMPUTE_THRESHOLD = MAX_HISTORY_SAMPLE // 5
33-
3423
3524class DASEnv (gym .Env ):
3625 """DAS environment.
@@ -121,11 +110,6 @@ def __init__(
121110 self ._stagnation_count = 0
122111 self ._choices_history : list [int ] = []
123112
124- # ELA features are expensive; cache the last computed vector and refresh
125- # lazily once _ELA_RECOMPUTE_THRESHOLD new samples have arrived.
126- self ._ela_cache : np .ndarray = np .zeros (ELA_DIM , dtype = np .float32 )
127- self ._ela_cache_len : int = 0
128-
129113 # ------------------------------------------------------------------ #
130114 # Gymnasium interface #
131115 # ------------------------------------------------------------------ #
@@ -156,8 +140,6 @@ def reset(self, seed=None, options=None):
156140 self ._initial_range = (float ("inf" ), - np .inf )
157141 self ._stagnation_count = 0
158142 self ._choices_history = []
159- self ._ela_cache = np .zeros (ELA_DIM , dtype = np .float32 )
160- self ._ela_cache_len = 0
161143
162144 obs = self ._build_observation ()
163145 info = {"problem_id" : problem_id , "dimension" : dim }
@@ -312,13 +294,6 @@ def _update_episode_state(self, result: dict, prev_best_y: float):
312294 )
313295
314296 def _build_observation (self ) -> np .ndarray :
315- # Recompute ELA only when enough new samples have arrived.
316- # _ela_cache starts as zeros (correct before 50 samples) and is reset
317- # each episode, so stale features from a previous episode never leak in.
318- current_len = len (self ._x_history ) if self ._x_history is not None else 0
319- if current_len >= 50 and current_len - self ._ela_cache_len >= _ELA_RECOMPUTE_THRESHOLD :
320- self ._ela_cache = compute_ela_features (self ._x_history , self ._y_history )
321- self ._ela_cache_len = current_len
322297
323298 return compute_observation (
324299 x_history = self ._x_history ,
@@ -330,5 +305,4 @@ def _build_observation(self) -> np.ndarray:
330305 max_fe = max (self ._max_fe , 1 ),
331306 stagnation_count = self ._stagnation_count ,
332307 ndim_problem = self ._problem .dimension if self ._problem is not None else 1 ,
333- ela = self ._ela_cache ,
334308 )
0 commit comments