From df56c3f3e284137f0c7b693235bab531b02fa425 Mon Sep 17 00:00:00 2001 From: jamesheald Date: Tue, 9 Sep 2025 14:26:39 +0100 Subject: [PATCH] policy_params_fn --- brax/training/agents/sac/train.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index 00eebd902..079a3c8b7 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -136,6 +136,7 @@ def train( sac_networks.SACNetworks ] = sac_networks.make_sac_networks, progress_fn: Callable[[int, Metrics], None] = lambda *args: None, + policy_params_fn: Callable[..., None] = lambda *args: None, eval_env: Optional[envs.Env] = None, randomization_fn: Optional[ Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]] @@ -557,6 +558,12 @@ def training_epoch_with_timing( training_walltime = time.time() - t current_step = 0 + + params = _unpmap( + (training_state.normalizer_params, training_state.policy_params) + ) + policy_params_fn(current_step, make_policy, params) + for _ in range(num_evals_after_init): logging.info('step %s', current_step) @@ -572,10 +579,11 @@ def training_epoch_with_timing( # Eval and logging if process_id == 0: - if checkpoint_logdir: - params = _unpmap( + params = _unpmap( (training_state.normalizer_params, training_state.policy_params) ) + policy_params_fn(current_step, make_policy, params) + if checkpoint_logdir: ckpt_config = checkpoint.network_config( observation_size=obs_size, action_size=env.action_size,