Skip to content

Commit c93efb8

Browse files
committed
ruff fix
1 parent 592a7c7 commit c93efb8

3 files changed

Lines changed: 12 additions & 5 deletions

File tree

agents/rl_das/trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,12 @@ def train(
131 131
# Log per-epoch PPO diagnostics so training instability is visible
132 132
# (e.g. actor_loss explosion, entropy collapse) without manual debugging.
133 133
if epoch_diagnostics:
134-
entry["actor_loss"] = float(np.mean([d["actor_loss"] for d in epoch_diagnostics]))
135-
entry["critic_loss"] = float(np.mean([d["critic_loss"] for d in epoch_diagnostics]))
134+
entry["actor_loss"] = float(
135+
np.mean([d["actor_loss"] for d in epoch_diagnostics])
136+
)
137+
entry["critic_loss"] = float(
138+
np.mean([d["critic_loss"] for d in epoch_diagnostics])
139+
)
136 140
entry["entropy"] = float(np.mean([d["entropy"] for d in epoch_diagnostics]))
137 141

138 142
if epoch % eval_interval == 0:

das/env/das_env.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
16 16
import gymnasium as gym
17 17
from gymnasium import spaces
18 18

19-
from das.env.observation import (compute_observation, observation_dim, MAX_HISTORY_SAMPLE)
19+
from das.env.observation import compute_observation, observation_dim, MAX_HISTORY_SAMPLE
20 20
from das.env.reward import compute_reward
21 21
from das.optimizers.base import get_checkpoints
22 22

@@ -257,7 +257,9 @@ def _update_episode_state(self, result: dict, prev_best_y: float):
257 257
# derive scale from the magnitude of the initial best fitness.
258 258
if self._initial_range[0] == float("inf"):
259 259
safe_worst = (
260-
worst_y if np.isfinite(worst_y) else new_best_y + max(abs(new_best_y), 1.0)
260+
worst_y
261+
if np.isfinite(worst_y)
262+
else new_best_y + max(abs(new_best_y), 1.0)
261 263
)
262 264
self._initial_range = (new_best_y, max(safe_worst, new_best_y + 1e-5))
263 265

das/training/rldas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def run_rl_das(args) -> None:
24 24

25 25
# Local variable — avoid mutating args so the caller's namespace stays predictable.
26 26
k_epoch = (
27-
args.k_epoch if args.k_epoch is not None
27+
args.k_epoch
28+
if args.k_epoch is not None
28 29
else max(1, int(0.3 * args.n_checkpoints))
29 30
)
30 31

0 commit comments

Comments (0)