diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py
index 56b367c3f..7056b65db 100644
--- a/brax/training/agents/sac/train.py
+++ b/brax/training/agents/sac/train.py
@@ -78,10 +78,11 @@ def _init_training_state(
     alpha_optimizer: optax.GradientTransformation,
     policy_optimizer: optax.GradientTransformation,
     q_optimizer: optax.GradientTransformation,
+    initial_alpha: float = 1.0,
 ) -> TrainingState:
   """Inits the training state and replicates it over devices."""
   key_policy, key_q = jax.random.split(key)
-  log_alpha = jnp.asarray(0.0, dtype=jnp.float32)
+  log_alpha = jnp.asarray(jnp.log(initial_alpha), dtype=jnp.float32)
   alpha_optimizer_state = alpha_optimizer.init(log_alpha)
 
   policy_params = sac_network.policy_network.init(key_policy)
@@ -128,6 +129,7 @@ def train(
     max_devices_per_host: Optional[int] = None,
     reward_scaling: float = 1.0,
     tau: float = 0.005,
+    initial_alpha: float = 1.0,
     min_replay_size: int = 0,
     max_replay_size: Optional[int] = None,
     grad_updates_per_step: int = 1,
@@ -143,7 +145,54 @@ def train(
     checkpoint_logdir: Optional[str] = None,
     restore_checkpoint_path: Optional[str] = None,
 ):
-  """SAC training."""
+  """SAC training.
+
+  Args:
+    environment: the environment to train
+    num_timesteps: the total number of environment steps to use during training
+    episode_length: the length of an environment episode
+    wrap_env: If True, wrap the environment for training. Otherwise use the
+      environment as is.
+    wrap_env_fn: a custom function that wraps the environment for training. If
+      not specified, the environment is wrapped with the default training
+      wrapper.
+    action_repeat: the number of timesteps to repeat an action
+    num_envs: the number of parallel environments to use for rollouts
+      NOTE: `num_envs` must be divisible by the total number of chips since each
+      chip gets `num_envs // total_number_of_chips` environments to roll out
+    num_eval_envs: the number of envs to use for evaluation. Each env will run 1
+      episode, and all envs run in parallel during eval.
+    learning_rate: learning rate for SAC loss
+    discounting: discounting rate
+    seed: random seed
+    batch_size: the batch size for each minibatch SGD step
+    num_evals: the number of evals to run during the entire training run.
+      Increasing the number of evals increases total training time
+    normalize_observations: whether to normalize observations
+    max_devices_per_host: maximum number of chips to use per host process
+    reward_scaling: float scaling for reward
+    tau: interpolation factor in polyak averaging for target networks
+    initial_alpha: initial value for the temperature parameter α
+    min_replay_size: the minimum number of samples in the replay buffer before
+      training starts. This is used to prefill the replay buffer with random
+      samples.
+    max_replay_size: the maximum number of samples in the replay buffer. If None,
+      the replay buffer is sized to hold `num_timesteps` samples
+    grad_updates_per_step: the number of gradient updates to run per actor step
+    deterministic_eval: whether to run the eval with a deterministic policy
+    network_factory: function that generates networks for policy and value
+      functions
+    progress_fn: a user-defined callback function for reporting/plotting metrics
+    eval_env: an optional environment for eval only, defaults to `environment`
+    randomization_fn: a user-defined callback function that generates randomized
+      environments
+    checkpoint_logdir: the path used to save checkpoints. If None, no checkpoints
+      are saved. A checkpoint is saved at each of the `num_evals` evaluations
+    restore_checkpoint_path: the path used to restore previous model params
+
+  Returns:
+    Tuple of (make_policy function, network params, metrics)
+  """
   process_id = jax.process_index()
   local_devices_to_use = jax.local_device_count()
   if max_devices_per_host is not None:
@@ -485,6 +534,7 @@ def training_epoch_with_timing(
       alpha_optimizer=alpha_optimizer,
       policy_optimizer=policy_optimizer,
       q_optimizer=q_optimizer,
+      initial_alpha=initial_alpha,
   )
   del global_key
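Below is a minimal usage sketch (not part of the diff) of how the new `initial_alpha` argument would be passed through the `train` entry point. The environment name and hyperparameter values are illustrative assumptions; setting `initial_alpha=1.0` reproduces the previous behaviour, since log(1.0) == 0.0 matches the old hard-coded log_alpha.

# Minimal usage sketch -- illustrative only, not part of the diff.
# Assumes the standard brax imports; the env name and hyperparameters below
# are placeholder choices, not values taken from this change.
from brax import envs
from brax.training.agents.sac import train as sac

env = envs.get_environment('hopper')

# `initial_alpha` seeds the entropy temperature: the training state stores
# log_alpha = log(initial_alpha), so 1.0 keeps the old default (log_alpha = 0)
# while smaller values start training with a weaker entropy bonus. The alpha
# optimizer still adapts the temperature from this starting point.
make_policy, params, metrics = sac.train(
    environment=env,
    num_timesteps=1_000_000,
    episode_length=1000,
    num_envs=64,
    batch_size=256,
    initial_alpha=0.5,
)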