2626from das .env .reward import compute_reward
2727from das .optimizers .base import get_checkpoints
2828
29- # Recompute ELA every ~500 new population samples. pflacco runs regression,
30- # nearest-neighbour search, and IC calculations on every call — running it
31- # every step would dominate wall-clock time for long training runs.
32- _ELA_RECOMPUTE_THRESHOLD = MAX_HISTORY_SAMPLE // 5
3329
3430
3531class DASEnv (gym .Env ):
@@ -68,6 +64,7 @@ def __init__(
6864 reward_option : int = 1 ,
6965 n_individuals : int = 100 ,
7066 seed : int | None = None ,
67+ ela_recompute_every : int = MAX_HISTORY_SAMPLE // 5 , # ~500
7168 ):
7269 super ().__init__ ()
7370 self .problem_ids = problem_ids
@@ -79,6 +76,7 @@ def __init__(
7976 self .reward_option = reward_option
8077 self .n_individuals = n_individuals
8178 self ._seed = seed
79+ self ._ela_recompute_every = max (1 , ela_recompute_every )
8280
8381 n_actions = len (optimizers )
8482 obs_dim = observation_dim (n_actions )
@@ -293,7 +291,10 @@ def _build_observation(self) -> np.ndarray:
293291 # _ela_cache starts as zeros (correct before 50 samples) and is reset
294292 # each episode, so stale features from a previous episode never leak in.
295293 current_len = len (self ._x_history ) if self ._x_history is not None else 0
296- if current_len >= 50 and current_len - self ._ela_cache_len >= _ELA_RECOMPUTE_THRESHOLD :
294+ if current_len >= 50 and (
295+ self ._ela_cache_len == 0
296+ or current_len - self ._ela_cache_len >= self ._ela_recompute_every
297+ ):
297298 self ._ela_cache = compute_ela_features (self ._x_history , self ._y_history )
298299 self ._ela_cache_len = current_len
299300
0 commit comments