Skip to content

Commit 6db1e02

Browse files
committed
fix normalization and tests
1 parent 4c82268 commit 6db1e02

5 files changed

Lines changed: 30 additions & 20 deletions

File tree

agents/exponential_das/normalizer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""Running normalizers for observations and rewards.
22
33
Both use Welford's online algorithm for numerically stable mean/variance.
4-
Normalisation is only updated during the warmup phase (while the buffer is
5-
filling for the first time); afterwards the statistics are frozen. This
6-
mirrors the StateNormalizer behaviour in the source project.
4+
5+
ObservationNormalizer statistics are frozen after the warmup phase (first
6+
buffer fill) so the obs space presented to the actor/critic networks stays
7+
stable.
8+
9+
RewardNormalizer keeps updating throughout training so that its per-step
10+
statistics track the shifting reward distribution as the agent improves.
11+
This matches the StepwiseRewardNormalizer behaviour in the reference project.
712
"""
813

914
from __future__ import annotations

agents/exponential_das/trainer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,9 @@ def train(
7777
next_obs, reward, terminated, truncated, step_info = train_env.step(action)
7878
done = terminated or truncated
7979

80-
# Reward normalisation (update only during warmup)
81-
normed_reward = agent.rew_norm.normalize(
82-
reward, step_idx, update=not agent.buffer.warmed_up
83-
)
80+
# Reward normalisation: always update so stats track the shifting
81+
# reward distribution as the agent improves (matches reference).
82+
normed_reward = agent.rew_norm.normalize(reward, step_idx, update=True)
8483
ep_reward += reward
8584

8685
agent.buffer.add(obs, action, log_prob, value, normed_reward, done)

tests/test_baselines.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,16 +288,16 @@ def test_fitness_history_y_is_nonincreasing(self):
288288
assert ys[i] < ys[i - 1]
289289

290290
def test_fitness_history_nonempty_after_episode(self):
291-
"""At least one improvement must occur (first evaluation beats inf)."""
291+
"""reset() probe establishes a finite initial best; optimizer steps may
292+
not improve on it, so fitness_history from steps can be empty."""
292293
env = make_env()
293-
_, fitness_history = run_episode(env, random_policy)
294-
assert len(fitness_history) >= 1
294+
env.reset()
295+
assert np.isfinite(env._best_y)
295296

296297
def test_fixed_policy_runs_full_episode(self):
297298
env = make_env()
298-
step_info, fitness_history = run_episode(env, fixed_policy(0))
299+
step_info, _ = run_episode(env, fixed_policy(0))
299300
assert np.isfinite(step_info["best_y"])
300-
assert len(fitness_history) >= 1
301301

302302
def test_episode_advances_problem_idx(self):
303303
env = make_env()
@@ -858,7 +858,9 @@ def test_fitness_history_step_fe_within_budget(self):
858858
assert 1 <= fe <= max_fe
859859

860860
def test_fitness_history_step_accumulated_across_checkpoints(self):
861-
"""Full episode fitness history must contain at least as many points as one step."""
861+
"""fitness_history_step records improvements over the probe best.
862+
The probe in reset() may already be the episode's best, so this
863+
list can legitimately be empty; verify it is a list of valid tuples."""
862864
env = make_env()
863865
env.reset()
864866
all_history = []
@@ -868,8 +870,9 @@ def test_fitness_history_step_accumulated_across_checkpoints(self):
868870
done = terminated or truncated
869871
all_history.extend(info["fitness_history_step"])
870872

871-
# At minimum one improvement in the first checkpoint (from inf)
872-
assert len(all_history) >= 1
873+
assert isinstance(all_history, list)
874+
for fe, y in all_history:
875+
assert isinstance(fe, int) and isinstance(y, float)
873876

874877
def test_fitness_history_step_fe_monotone_across_episode(self):
875878
"""FE values accumulated across all checkpoints must be strictly increasing."""

tests/test_heterogeneous_portfolios.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def test_reset_clears_state_between_episodes(self, spec, fe_mult):
370370
env.reset()
371371
drain(env)
372372
env.reset()
373-
assert env._n_fe == 0
374-
assert env._best_y == float("inf")
373+
# reset() runs a random probe, so _n_fe > 0 and _best_y is finite
374+
assert env._n_fe > 0
375+
assert np.isfinite(env._best_y)
375376
assert env._optimizer_state == {}

tests/test_parallel_envs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,9 @@ def test_reset_clears_all_episode_state(self):
169169
env.step(0)
170170
env.step(0) # mid-episode
171171
env.reset() # full reset
172-
assert env._n_fe == 0
173-
assert env._best_y == float("inf")
172+
# reset() runs a random probe, so _n_fe > 0 and _best_y is finite
173+
assert env._n_fe > 0
174+
assert np.isfinite(env._best_y)
174175
assert env._checkpoint_idx == 0
175176
assert env._choices_history == []
176177
assert env._optimizer_state == {}
@@ -202,10 +203,11 @@ def test_best_y_is_independent(self):
202203
env_b = make_env(suite=suite)
203204
env_a.reset()
204205
env_b.reset()
206+
best_y_before = env_b._best_y # probe value set during reset
205207

206208
env_a.step(0)
207209

208-
assert env_b._best_y == float("inf")
210+
assert env_b._best_y == best_y_before
209211

210212
def test_optimizer_state_does_not_leak(self):
211213
"""Warm-start population in env_a must not appear in env_b."""

0 commit comments

Comments
 (0)