|
| 1 | +"""Gymnasium environment for Dynamic Algorithm Selection on COCO-BBOB. |
| 2 | +
|
| 3 | +Each episode corresponds to one optimization run on a single BBOB problem. |
| 4 | +At every timestep the agent picks which sub-optimizer to run next; the |
| 5 | +optimizer then runs until the next checkpoint (uniform or exponentially spaced). |
| 6 | +
|
| 7 | +Observation space : Box(-inf, +inf, shape=(state_dim,)) – normalized externally |
| 8 | + via stable-baselines3's VecNormalize. |
| 9 | +Action space : Discrete(n_optimizers) |
| 10 | +Reward : Fitness improvement, scaled and shaped by reward_option. |
| 11 | +""" |
| 12 | + |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +import numpy as np |
| 16 | +import gymnasium as gym |
| 17 | +from gymnasium import spaces |
| 18 | + |
| 19 | +from das.env.observation import compute_observation, observation_dim |
| 20 | +from das.env.reward import compute_reward |
| 21 | +from das.optimizers.base import get_checkpoints |
| 22 | + |
| 23 | + |
class DASEnv(gym.Env):
    """Gymnasium environment for dynamic algorithm selection (DAS).

    One episode is one optimization run on a single problem fetched from
    ``suite``.  At each step the agent's action selects a sub-optimizer class
    from ``optimizers``; that optimizer is instantiated, warm-started with the
    state left behind by the previous one, and run up to the next checkpoint.

    Parameters
    ----------
    problem_ids:
        BBOB problem IDs to cycle through (one per episode, round-robin).
    suite:
        cocoex Suite object to fetch problems from.
    optimizers:
        Ordered list of sub-optimizer classes (defines the action space).
    fe_multiplier:
        Budget = fe_multiplier * problem_dimension.
    n_checkpoints:
        Number of optimizer-selection steps per episode.
    checkpoint_division_base (cdb):
        cdb=1.0 → uniform checkpoints; cdb>1.0 → exponentially growing intervals.
    reward_option:
        1=log-scaled, 2=linear, 3=sparse, 4=binary (see das/env/reward.py).
    n_individuals:
        Population size shared across all sub-optimizers.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        problem_ids: list[str],
        suite,
        optimizers: list,
        fe_multiplier: int = 10_000,
        n_checkpoints: int = 10,
        checkpoint_division_base: float = 1.0,
        reward_option: int = 1,
        n_individuals: int = 100,
    ):
        super().__init__()
        self.problem_ids = problem_ids
        self.suite = suite
        self.optimizers = optimizers
        self.fe_multiplier = fe_multiplier
        self.n_checkpoints = n_checkpoints
        self.cdb = checkpoint_division_base
        self.reward_option = reward_option
        self.n_individuals = n_individuals

        n_actions = len(optimizers)
        obs_dim = observation_dim(n_actions)

        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)
        self.action_space = spaces.Discrete(n_actions)

        # Episode state – reset() initialises these
        self._problem = None
        self._problem_idx = 0
        self._max_fe = 0
        self._n_fe = 0
        self._checkpoints: np.ndarray | None = None
        self._checkpoint_idx = 0
        # Set once step() reports termination; cleared again by reset().
        self._episode_over = False

        self._optimizer_state: dict = {}  # passed between sub-optimizers for warm-starting
        self._x_history: np.ndarray | None = None
        self._y_history: np.ndarray | None = None

        self._best_y = float("inf")
        self._best_x: np.ndarray | None = None
        self._worst_y = -np.inf
        self._initial_range: tuple[float, float] = (float("inf"), -np.inf)
        self._stagnation_count = 0
        self._choices_history: list[int] = []

    # ------------------------------------------------------------------ #
    #  Gymnasium interface                                                 #
    # ------------------------------------------------------------------ #

    def reset(self, seed=None, options=None):
        """Start a new episode on the next problem in ``problem_ids``.

        Problems are visited round-robin.  ``options`` is accepted for
        Gymnasium API compatibility but is not used.

        Returns
        -------
        tuple
            ``(observation, info)`` where ``info`` holds ``"problem_id"`` and
            ``"dimension"``.

        Raises
        ------
        ValueError
            If ``problem_ids`` is empty.
        """
        super().reset(seed=seed)

        if not self.problem_ids:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise ValueError("problem_ids must contain at least one problem ID")

        problem_id = self.problem_ids[self._problem_idx % len(self.problem_ids)]
        self._problem_idx += 1

        self._problem = self.suite.get_problem(problem_id)
        dim = self._problem.dimension
        self._max_fe = self.fe_multiplier * dim
        self._checkpoints = get_checkpoints(self.n_checkpoints, self._max_fe, self.n_individuals, self.cdb)

        # Reset episode bookkeeping
        self._n_fe = 0
        self._checkpoint_idx = 0
        self._episode_over = False
        self._optimizer_state = {}
        self._x_history = None
        self._y_history = None
        self._best_y = float("inf")
        self._best_x = None
        self._worst_y = -np.inf
        self._initial_range = (float("inf"), -np.inf)
        self._stagnation_count = 0
        self._choices_history = []

        obs = self._build_observation()
        info = {"problem_id": problem_id, "dimension": dim}
        return obs, info

    def step(self, action: int):
        """Run the sub-optimizer selected by ``action`` up to the next checkpoint.

        Returns
        -------
        tuple
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False``.  ``terminated`` becomes ``True``
            once all checkpoints have been consumed or the evaluation budget
            is spent.  ``info`` holds ``"best_y"``, ``"n_fe"`` and
            ``"checkpoint"``.

        Raises
        ------
        RuntimeError
            If called before ``reset()`` or after the episode has terminated.
        """
        # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
        if self._problem is None or self._checkpoints is None:
            raise RuntimeError("Call reset() before step()")
        if self._episode_over:
            raise RuntimeError("Episode has terminated; call reset() before step()")

        # The action may arrive as a NumPy integer scalar; keep a plain int so
        # the choices history matches its ``list[int]`` annotation.
        action = int(action)

        target_fe = int(self._checkpoints[self._checkpoint_idx])
        prev_best_y = self._best_y

        result = self._run_optimizer(action, target_fe)

        self._update_episode_state(result, prev_best_y)
        self._choices_history.append(action)
        self._checkpoint_idx += 1

        terminated = (
            self._checkpoint_idx >= self.n_checkpoints
            or self._n_fe >= self._max_fe
        )
        self._episode_over = terminated

        reward = compute_reward(
            self._best_y,
            prev_best_y,
            self._initial_range,
            option=self.reward_option,
            is_final=terminated,
        )

        obs = self._build_observation()
        info = {
            "best_y": self._best_y,
            "n_fe": self._n_fe,
            "checkpoint": self._checkpoint_idx,
        }
        return obs, reward, terminated, False, info

    # ------------------------------------------------------------------ #
    #  Internal helpers                                                    #
    # ------------------------------------------------------------------ #

    def _run_optimizer(self, action: int, target_fe: int) -> dict:
        """Instantiate the selected sub-optimizer and run it to target_fe.

        The optimizer starts from the current evaluation count and receives
        the best point found so far plus the warm-start state saved from the
        previous step.  Afterwards the warm-start state is refreshed for the
        next step.

        Returns the optimizer's result dict.
        """
        optimizer_class = self.optimizers[action]
        problem_config = {
            "fitness_function": self._problem,
            "ndim_problem": self._problem.dimension,
            "lower_boundary": self._problem.lower_bounds,
            "upper_boundary": self._problem.upper_bounds,
        }
        options = {
            "max_function_evaluations": self._max_fe,
            "target_fe": target_fe,
            "n_individuals": self.n_individuals,
            "best_so_far_y": self._best_y,
            "verbose": False,
        }
        optimizer = optimizer_class(problem_config, options)
        optimizer.n_function_evaluations = self._n_fe

        optimizer.set_data(
            best_x=self._best_x,
            # Before the first evaluation there is no best value yet.
            best_y=self._best_y if self._best_y < float("inf") else None,
            **self._optimizer_state,
        )
        result = optimizer.optimize()
        # result may be (result_dict, agent_state) tuple in subclasses; normalise
        if isinstance(result, tuple):
            result = result[0]

        # Update warm-start state for next step
        new_state = optimizer.get_data()
        if new_state:
            self._optimizer_state = new_state
        elif len(optimizer.x_history) > 0:
            # Fallback: carry x/y from the tail of the population history.
            # If there is no history either, the previous state is kept.
            self._optimizer_state = {
                "x": np.array(optimizer.x_history[-self.n_individuals :]),
                "y": np.array(optimizer.y_history[-self.n_individuals :]),
            }

        return result

    def _update_episode_state(self, result: dict, prev_best_y: float):
        """Fold one optimizer result into the episode bookkeeping.

        Updates best/worst fitness, the initial fitness range (first step
        only), the stagnation counter, the evaluation count and the
        accumulated x/y history.

        Parameters
        ----------
        result:
            Result dict of the optimizer run; missing keys fall back to
            neutral defaults.
        prev_best_y:
            Best fitness before this step, used to detect stagnation.
        """
        new_best_y: float = result.get("best_so_far_y", float("inf"))
        new_best_x: np.ndarray | None = result.get("best_so_far_x")
        worst_y: float = result.get("worst_so_far_y", -np.inf)

        if new_best_y < self._best_y:
            self._best_y = new_best_y
            self._best_x = new_best_x

        if worst_y > self._worst_y:
            self._worst_y = worst_y

        # Set initial range on first step (lower bound still at its inf sentinel).
        # The 1e-5 offset keeps the range strictly positive in width.
        if self._initial_range[0] == float("inf"):
            self._initial_range = (new_best_y, max(worst_y, new_best_y + 1e-5))

        # Stagnation counter: evaluations spent since the last improvement
        x_hist: np.ndarray | None = result.get("x_history")
        y_hist: np.ndarray | None = result.get("y_history")
        n_fe_step = len(y_hist) if y_hist is not None else 0

        if new_best_y >= prev_best_y:
            self._stagnation_count += n_fe_step
        else:
            self._stagnation_count = 0

        # Prefer the optimizer's own count; otherwise infer it from the history length.
        self._n_fe = result.get("n_function_evaluations", self._n_fe + n_fe_step)

        # Accumulate population history for ELA
        if x_hist is not None and len(x_hist) > 0:
            self._x_history = x_hist if self._x_history is None else np.concatenate([self._x_history, x_hist])
            self._y_history = y_hist if self._y_history is None else np.concatenate([self._y_history, y_hist])

    def _build_observation(self) -> np.ndarray:
        """Build the observation for the current episode state.

        Delegates to ``compute_observation``.  ``max_fe`` is clamped to at
        least 1 and the dimension defaults to 1 when no problem is loaded.
        """
        return compute_observation(
            x_history=self._x_history,
            y_history=self._y_history,
            choices_history=self._choices_history,
            n_actions=len(self.optimizers),
            n_checkpoints=self.n_checkpoints,
            n_fe=self._n_fe,
            max_fe=max(self._max_fe, 1),
            stagnation_count=self._stagnation_count,
            ndim_problem=self._problem.dimension if self._problem is not None else 1,
        )
0 commit comments