Skip to content

Commit 2a1c641

Browse files
committed
feat: add ela_recompute_every parameter for ELA feature recomputation control
1 parent e12e711 commit 2a1c641

4 files changed

Lines changed: 18 additions & 5 deletions

File tree

das/env/das_env.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@
2626
from das.env.reward import compute_reward
2727
from das.optimizers.base import get_checkpoints
2828

29-
# Recompute ELA every ~500 new population samples. pflacco runs regression,
30-
# nearest-neighbour search, and IC calculations on every call — running it
31-
# every step would dominate wall-clock time for long training runs.
32-
_ELA_RECOMPUTE_THRESHOLD = MAX_HISTORY_SAMPLE // 5
3329

3430

3531
class DASEnv(gym.Env):
@@ -68,6 +64,7 @@ def __init__(
6864
reward_option: int = 1,
6965
n_individuals: int = 100,
7066
seed: int | None = None,
67+
ela_recompute_every: int = MAX_HISTORY_SAMPLE // 5 # ~500,
7168
):
7269
super().__init__()
7370
self.problem_ids = problem_ids
@@ -79,6 +76,7 @@ def __init__(
7976
self.reward_option = reward_option
8077
self.n_individuals = n_individuals
8178
self._seed = seed
79+
self._ela_recompute_every = max(1, ela_recompute_every)
8280

8381
n_actions = len(optimizers)
8482
obs_dim = observation_dim(n_actions)
@@ -293,7 +291,10 @@ def _build_observation(self) -> np.ndarray:
293291
# _ela_cache starts as zeros (correct before 50 samples) and is reset
294292
# each episode, so stale features from a previous episode never leak in.
295293
current_len = len(self._x_history) if self._x_history is not None else 0
296-
if current_len >= 50 and current_len - self._ela_cache_len >= _ELA_RECOMPUTE_THRESHOLD:
294+
if current_len >= 50 and (
295+
self._ela_cache_len == 0
296+
or current_len - self._ela_cache_len >= self._ela_recompute_every
297+
):
297298
self._ela_cache = compute_ela_features(self._x_history, self._y_history)
298299
self._ela_cache_len = current_len
299300

das/training/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _init():
3131
reward_option=cfg["reward_option"],
3232
n_individuals=cfg["n_individuals"],
3333
seed=cfg.get("seed"),
34+
ela_recompute_every=cfg.get("ela_recompute_every", 500),
3435
)
3536

3637
return _init

das/training/ppo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def run_ppo(args) -> None:
138138
"reward_option": args.reward_option,
139139
"n_individuals": args.n_individuals,
140140
"seed": args.seed,
141+
"ela_recompute_every": args.ela_recompute_every,
141142
}
142143

143144
print(f"Portfolio : {args.portfolio}")

train.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ def _add_shared_args(p: argparse.ArgumentParser) -> None:
7575
)
7676
p.add_argument("--n-individuals", type=int, default=100, help="Population size")
7777
p.add_argument("--seed", type=int, default=42)
78+
p.add_argument(
79+
"--ela-recompute-every",
80+
type=int,
81+
default=500,
82+
help=(
83+
"Recompute ELA features every N new population samples. "
84+
"Set to 1 to recompute on every step (slow but maximally fresh). "
85+
"Default: 500."
86+
),
87+
)
7888

7989

8090
def _parse_args() -> argparse.Namespace:

0 commit comments

Comments
 (0)