diff --git a/hci/user-in-the-box/uitb/rl/__init__.py b/hci/user-in-the-box/uitb/rl/__init__.py new file mode 100644 index 00000000..f361d68e --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/__init__.py @@ -0,0 +1 @@ +from .sb3.PPO import PPO \ No newline at end of file diff --git a/hci/user-in-the-box/uitb/rl/base.py b/hci/user-in-the-box/uitb/rl/base.py new file mode 100644 index 00000000..8297fa62 --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/base.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +import os +import shutil +import inspect +import pathlib + +from ..utils.functions import parent_path + +class BaseRLModel(ABC): + + def __init__(self, **kwargs): + pass + + @abstractmethod + def learn(self, wandb_callback): + pass + + @classmethod + def clone(cls, simulator_folder, package_name): + + # Create 'rl' folder + dst = os.path.join(simulator_folder, package_name, "rl") + os.makedirs(dst, exist_ok=True) + + # Copy the rl library folder + src = parent_path(inspect.getfile(cls)) + shutil.copytree(src, os.path.join(dst, src.stem), dirs_exist_ok=True) + + # Copy this file + base_file = pathlib.Path(__file__) + shutil.copyfile(base_file, os.path.join(dst, base_file.name)) + + # Copy the file with encoders + encoder_file = os.path.join(base_file.parent, "encoders.py") + shutil.copyfile(encoder_file, os.path.join(dst, "encoders.py")) + + # Create an __init__.py file with the relevant import + modules = cls.__module__.split(".") + with open(os.path.join(dst, "__init__.py"), "w") as file: + file.write("from ." + ".".join(modules[2:]) + " import " + cls.__name__) \ No newline at end of file diff --git a/hci/user-in-the-box/uitb/rl/encoders.py b/hci/user-in-the-box/uitb/rl/encoders.py new file mode 100644 index 00000000..118232b9 --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/encoders.py @@ -0,0 +1,69 @@ +from torch import nn +import torch +from typing import final + + +class BaseEncoder: + """ Defines an encoder. Note that in stable baselines training we only use the model definitions given here. The + trained encoder parameters are not written back into these objects after training; instead, the parameters saved and + loaded by stable baselines are used. In other words, these encoders are only used to initialise the encoders for + stable baselines, not during or after training. """ + + def __init__(self, observation_shape, **kwargs): + self._observation_shape = observation_shape + + # Define a PyTorch model (e.g. using torch.nn.Sequential) + self._model = None + + # We assume all encoders output a vector with self.out_features elements in it + self.out_features = None + + @final + @property + def model(self): + return self._model + +class Identity(BaseEncoder): + """ Define an identity encoder. Used when no encoder has been defined. Can only be used for one-dimensional + observations. """ + + def __init__(self, observation_shape): + super().__init__(observation_shape) + if len(observation_shape) > 1: + raise RuntimeError("You must not use the Identity encoder for higher dimensional observations. 
Use an encoder" + "that maps the high dimensional observations into one dimensional vectors.") + self._model = torch.nn.Identity() + self.out_features = observation_shape[0] + +class SmallCNN(BaseEncoder): + + def __init__(self, observation_shape, out_features): + super().__init__(observation_shape) + cnn = nn.Sequential( + nn.Conv2d(in_channels=observation_shape[0], out_channels=8, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2)), + nn.LeakyReLU(), + nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2)), + nn.LeakyReLU(), + nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2)), + nn.LeakyReLU(), + nn.Flatten()) + + # Compute shape by doing one forward pass + with torch.no_grad(): + n_flatten = cnn(torch.zeros(observation_shape)[None]).shape[1] + + self._model = nn.Sequential( + cnn, + nn.Linear(in_features=n_flatten, out_features=out_features), + nn.LeakyReLU()) + self.out_features = out_features + + +class OneLayer(BaseEncoder): + + def __init__(self, observation_shape, out_features): + super().__init__(observation_shape) + self.out_features = out_features + self._model = nn.Sequential( + nn.Linear(self._observation_shape[0], out_features), + nn.LeakyReLU()) diff --git a/hci/user-in-the-box/uitb/rl/sb3/PPO.py b/hci/user-in-the-box/uitb/rl/sb3/PPO.py new file mode 100644 index 00000000..2efaa00c --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/PPO.py @@ -0,0 +1,356 @@ +import os +import importlib +import numpy as np +import pathlib + +from stable_baselines3 import PPO as PPO_sb3 +from stable_baselines3.common.vec_env import SubprocVecEnv +# from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.callbacks import CheckpointCallback #, EvalCallback + +from typing import TypeVar +import sys, time +from stable_baselines3.common.type_aliases import MaybeCallback +from stable_baselines3.common.utils import safe_mean +SelfPPO = TypeVar("SelfPPO", bound="PPO") + +from typing import Any, Dict, Optional, SupportsFloat, Tuple +import gymnasium as gym +from gymnasium.core import ActType, ObsType +from collections import defaultdict + +from typing import Callable, Optional, Type, Union +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv +from stable_baselines3.common.vec_env.patch_gym import _patch_env + +from ..base import BaseRLModel +from .callbacks import EvalCallback + + +class PPO(BaseRLModel): + + def __init__(self, simulator, checkpoint_path=None, wandb_id=None, info_keywords=()): + super().__init__() + + rl_config = self.load_config(simulator) + run_parameters = simulator.run_parameters + simulator_folder = simulator.simulator_folder + + # Get total timesteps + self.total_timesteps = rl_config["total_timesteps"] + + # Combine info keywords to be logged from config and from passed kwarg. 
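+    # (Duplicate entries are removed by the set comprehension below, so the same keyword may be listed both in the
+    # config's run_parameters and in the info_keywords argument without being logged twice.)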
+ # Note: Each entry of info_keywords needs to be of type Tuple[str, str], + # with the variable name as first and the episode operation (e.g., "sum", "mean", "final" (default)) as second string + self.info_keywords = tuple({tuple(k) for k in (run_parameters.get("info_keywords", []) + list(info_keywords))}) + + # Initialise parallel envs + self.n_envs = rl_config["num_workers"] + Monitor = Monitor_customops #use Monitor_customops instead of Monitor class in make_vec_env + parallel_envs = make_vec_env(simulator.__class__, n_envs=self.n_envs, + seed=run_parameters.get("random_seed", None), vec_env_cls=SubprocVecEnv, + monitor_kwargs={"info_keywords": self.info_keywords}, + env_kwargs={"simulator_folder": simulator_folder}) + + if checkpoint_path is not None: + # Resume training + self.model = PPO_sb3_customlogs.load(checkpoint_path, parallel_envs, verbose=1, #policy_kwargs=rl_config["policy_kwargs"], + tensorboard_log=simulator_folder, n_steps=rl_config["nsteps"], + batch_size=rl_config["batch_size"], target_kl=rl_config["target_kl"], + learning_rate=rl_config["lr"], device=rl_config["device"]) + self.training_resumed = True + else: + # Add feature and stateful information encoders to policy_kwargs + encoders = simulator.perception.encoders + if simulator.task.get_stateful_information_space_params()["shape"] != (0,): + #TODO: define stateful_information (and encoder) that can be used as default, if no stateful information is provided (zero-size array do not work with sb3 currently...) + encoders["stateful_information"] = simulator.task.stateful_information_encoder + rl_config["policy_kwargs"]["features_extractor_kwargs"] = {"encoders": encoders} + rl_config["policy_kwargs"]["wandb_id"] = wandb_id + + # Initialise model + self.model = PPO_sb3_customlogs(rl_config["policy_type"], parallel_envs, verbose=1, policy_kwargs=rl_config["policy_kwargs"], + tensorboard_log=simulator_folder, n_steps=rl_config["nsteps"], + batch_size=rl_config["batch_size"], target_kl=rl_config["target_kl"], + learning_rate=rl_config["lr"], device=rl_config["device"]) + self.training_resumed = False + + if "policy_init" in rl_config: + params = os.path.join(pathlib.Path(__file__).parent, rl_config["policy_init"]) + self.model.policy.load_from_vector(np.load(params)) + + # Create a checkpoint callback + save_freq = rl_config["save_freq"] // self.n_envs + checkpoint_folder = os.path.join(simulator_folder, 'checkpoints') + self.checkpoint_callback = CheckpointCallback(save_freq=save_freq, + save_path=checkpoint_folder, + name_prefix='model', + save_replay_buffer=True, + save_vecnormalize=True) + + # Get callbacks as a list + self.callbacks = [*simulator.callbacks.values()] + + # Create an evaluation env (only used if eval_callback=True is passed to learn()) + self.eval_env = simulator.__class__(**{"simulator_folder": simulator_folder}) + + def load_config(self, simulator): + config = simulator.config["rl"] + + # Need to translate strings into classes + config["policy_type"] = simulator.get_class("rl.sb3", config["policy_type"]) + + if "activation_fn" in config["policy_kwargs"]: + mods = config["policy_kwargs"]["activation_fn"].split(".") + config["policy_kwargs"]["activation_fn"] = getattr(importlib.import_module(".".join(mods[:-1])), mods[-1]) + + config["policy_kwargs"]["features_extractor_class"] = \ + simulator.get_class("rl.sb3", config["policy_kwargs"]["features_extractor_class"]) + + if "lr" in config: + if isinstance(config["lr"], dict): + config["lr"] = simulator.get_class("rl.sb3", 
config["lr"]["function"])(**config["lr"]["kwargs"]) + + return config + + def learn(self, wandb_callback, with_evaluation=False, eval_freq=400000, n_eval_episodes=5, eval_info_keywords=()): + if with_evaluation: + self.eval_env = Monitor(self.eval_env, info_keywords=eval_info_keywords) + self.eval_freq = eval_freq // self.n_envs + self.eval_callback = EvalCallback(self.eval_env, eval_freq=self.eval_freq, n_eval_episodes=n_eval_episodes, info_keywords=eval_info_keywords) + + self.model.learn(total_timesteps=self.total_timesteps, + callback=[wandb_callback, self.checkpoint_callback, self.eval_callback, *self.callbacks], + info_keywords=self.info_keywords, + reset_num_timesteps=not self.training_resumed) + else: + self.model.learn(total_timesteps=self.total_timesteps, + callback=[wandb_callback, self.checkpoint_callback, *self.callbacks], + info_keywords=self.info_keywords, + reset_num_timesteps=not self.training_resumed) + +class PPO_sb3_customlogs(PPO_sb3): + def learn( + self: SelfPPO, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + tb_log_name: str = "PPO", + info_keywords : tuple = (), + reset_num_timesteps: bool = True, + progress_bar: bool = False, + ) -> SelfPPO: + iteration = 0 + + total_timesteps, callback = self._setup_learn( + total_timesteps, + callback, + reset_num_timesteps, + tb_log_name, + progress_bar, + ) + + callback.on_training_start(locals(), globals()) + + assert self.env is not None + + while self.num_timesteps < total_timesteps: + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) + + if continue_training is False: + break + + iteration += 1 + self._update_current_progress_remaining(self.num_timesteps, total_timesteps) + + # Display training infos + if log_interval is not None and iteration % log_interval == 0: + assert self.ep_info_buffer is not None + time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) + fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) + self.logger.record("time/iterations", iteration, exclude="tensorboard") + if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: + self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + for keyword, operation in info_keywords: + self.logger.record(f"rollout/ep_{keyword}_{operation}", safe_mean([ep_info[keyword] for ep_info in self.ep_info_buffer if keyword in ep_info])) + self.logger.record("time/fps", fps) + self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + self.train() + + callback.on_training_end() + + return self + +class Monitor_customops(Monitor): + """ + Modified monitor wrapper for Gym environments, which allows to accumulate logged values per episode (e.g., store sum or mean of a logged variable per episode). + To this end, info_keywords is a tuple containing (str, str) tuples, with variable name as first string and episode operation (e.g., "sum", "mean", or "final" (default)) as second string. 
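+    For illustration only (the key names below are hypothetical and must exist in the env's info dict):
+    info_keywords=(("target_hits", "sum"), ("dist_to_target", "mean"), ("is_success", "final")) would log the
+    per-episode sum of info["target_hits"], the per-episode mean of info["dist_to_target"], and the value of
+    info["is_success"] at the final step of each episode.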
+ + :param env: The environment + :param filename: the location to save a log file, can be None for no log + :param allow_early_resets: allows the reset of the environment before it is done + :param reset_keywords: extra keywords for the reset call, + if extra parameters are needed at reset + :param info_keywords: extra information to log, from the information return of env.step() [see note above] + :param override_existing: appends to file if ``filename`` exists, otherwise + override existing files (default) + """ + def __init__( + self, + env: gym.Env, + filename: Optional[str] = None, + allow_early_resets: bool = True, + reset_keywords: Tuple[str, ...] = (), + info_keywords: Tuple[Tuple[str, str], ...] = (), + override_existing: bool = True, + ): + super().__init__(env=env, filename=filename, allow_early_resets=allow_early_resets, reset_keywords=reset_keywords, info_keywords=info_keywords, override_existing=override_existing) + + + def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]: + """ + Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True + + :param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords + :return: the first observation of the environment + """ + if not self.allow_early_resets and not self.needs_reset: + raise RuntimeError( + "Tried to reset an environment before done. If you want to allow early resets, " + "wrap your env with Monitor_customops(env, path, allow_early_resets=True)" + ) + self.info_keywords_acc_valuedict = defaultdict(list) + return super().reset(**kwargs) + + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + """ + Step the environment with the given action + + :param action: the action + :return: observation, reward, terminated, truncated, information + """ + if self.needs_reset: + raise RuntimeError("Tried to step environment that needs reset") + observation, reward, terminated, truncated, info = self.env.step(action) + self.rewards.append(float(reward)) + for key, op in self.info_keywords: + if op in ["sum", "mean"]: + self.info_keywords_acc_valuedict[key].append(float(info[key])) + if terminated or truncated: + self.needs_reset = True + ep_rew = sum(self.rewards) + ep_len = len(self.rewards) + ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)} + for key, op in self.info_keywords: + if op == "sum": + ep_info[key] = sum(self.info_keywords_acc_valuedict[key]) + elif op == "mean": + ep_info[key] = safe_mean(self.info_keywords_acc_valuedict[key]) + else: + ep_info[key] = info[key] + self.episode_returns.append(ep_rew) + self.episode_lengths.append(ep_len) + self.episode_times.append(time.time() - self.t_start) + ep_info.update(self.current_reset_info) + if self.results_writer: + self.results_writer.write_row(ep_info) + info["episode"] = ep_info + self.total_steps += 1 + return observation, reward, terminated, truncated, info + +def make_vec_env( + env_id: Union[str, Callable[..., gym.Env]], + n_envs: int = 1, + seed: Optional[int] = None, + start_index: int = 0, + monitor_dir: Optional[str] = None, + wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None, + env_kwargs: Optional[Dict[str, Any]] = None, + vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None, + vec_env_kwargs: Optional[Dict[str, Any]] = None, + monitor_kwargs: Optional[Dict[str, Any]] = None, + wrapper_kwargs: Optional[Dict[str, Any]] = None, +) -> VecEnv: + """ + Create a wrapped, 
monitored ``VecEnv``. + By default it uses a ``DummyVecEnv`` which is usually faster + than a ``SubprocVecEnv``. + + :param env_id: either the env ID, the env class or a callable returning an env + :param n_envs: the number of environments you wish to have in parallel + :param seed: the initial seed for the random number generator + :param start_index: start rank index + :param monitor_dir: Path to a folder where the monitor files will be saved. + If None, no file will be written, however, the env will still be wrapped + in a Monitor_customops wrapper to provide additional information about training. + :param wrapper_class: Additional wrapper to use on the environment. + This can also be a function with single argument that wraps the environment in many things. + Note: the wrapper specified by this parameter will be applied after the ``Monitor_customops`` wrapper. + if some cases (e.g. with TimeLimit wrapper) this can lead to undesired behavior. + See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894 + :param env_kwargs: Optional keyword argument to pass to the env constructor + :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None. + :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor. + :param monitor_kwargs: Keyword arguments to pass to the ``Monitor_customops`` class constructor. + :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. + :return: The wrapped environment + """ + env_kwargs = env_kwargs or {} + vec_env_kwargs = vec_env_kwargs or {} + monitor_kwargs = monitor_kwargs or {} + wrapper_kwargs = wrapper_kwargs or {} + assert vec_env_kwargs is not None # for mypy + + def make_env(rank: int) -> Callable[[], gym.Env]: + def _init() -> gym.Env: + # For type checker: + assert monitor_kwargs is not None + assert wrapper_kwargs is not None + assert env_kwargs is not None + + if isinstance(env_id, str): + # if the render mode was not specified, we set it to `rgb_array` as default. 
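+                # env_kwargs may still override this; environments whose constructor does not accept a render_mode
+                # keyword are handled by the TypeError fallback below, which retries without the added default.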
kwargs = {"render_mode": "rgb_array"} + kwargs.update(env_kwargs) + try: + env = gym.make(env_id, **kwargs) # type: ignore[arg-type] + except TypeError: + env = gym.make(env_id, **env_kwargs) + else: + env = env_id(**env_kwargs) + # Patch to support gym 0.21/0.26 and gymnasium + env = _patch_env(env) + + if seed is not None: + # Note: here we only seed the action space + # We will seed the env at the next reset + env.action_space.seed(seed + rank) + # Wrap the env in a Monitor wrapper + # to have additional training information + monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None + # Create the monitor folder if needed + if monitor_path is not None and monitor_dir is not None: + os.makedirs(monitor_dir, exist_ok=True) + env = Monitor_customops(env, filename=monitor_path, **monitor_kwargs) + # Optionally, wrap the environment with the provided wrapper + if wrapper_class is not None: + env = wrapper_class(env, **wrapper_kwargs) + return env + + return _init + + # No custom VecEnv is passed + if vec_env_cls is None: + # Default: use a DummyVecEnv + vec_env_cls = DummyVecEnv + + vec_env = vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) + # Prepare the seeds for the first reset + vec_env.seed(seed) + return vec_env + diff --git a/hci/user-in-the-box/uitb/rl/sb3/RecurrentPPO.py b/hci/user-in-the-box/uitb/rl/sb3/RecurrentPPO.py new file mode 100644 index 00000000..2f47e5ae --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/RecurrentPPO.py @@ -0,0 +1,40 @@ +import os +import numpy as np + +from sb3_contrib import RecurrentPPO as RecurrentPPO_sb3 +from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.callbacks import CheckpointCallback + +from uitb.rl.base import BaseRLModel + + +class RecurrentPPO(BaseRLModel): + + def __init__(self, config, run_folder): + super().__init__() + self.config = config + + # Initialise parallel envs + parallel_envs = make_vec_env(config["env_name"], n_envs=config["num_workers"], seed=0, + vec_env_cls=SubprocVecEnv, env_kwargs=config["env_kwargs"], + vec_env_kwargs={'start_method': config["start_method"]}) + + # Initialise model + self.model = RecurrentPPO_sb3(config["policy_type"], parallel_envs, verbose=1, + policy_kwargs=config["policy_kwargs"], tensorboard_log=run_folder, + n_steps=config["nsteps"], batch_size=config["batch_size"], + target_kl=config["target_kl"], learning_rate=config["lr"], device=config["device"]) + + + save_freq = config["save_freq"] // config["num_workers"] + checkpoint_folder = os.path.join(run_folder, 'checkpoints') + self.checkpoint_callback = CheckpointCallback(save_freq=save_freq, + save_path=checkpoint_folder, + name_prefix='model', + save_replay_buffer=True, + save_vecnormalize=True) + + def learn(self, wandb_callback): + self.model.learn(total_timesteps=self.config["total_timesteps"], + callback=[wandb_callback, self.checkpoint_callback]) \ No newline at end of file diff --git a/hci/user-in-the-box/uitb/rl/sb3/__init__.py b/hci/user-in-the-box/uitb/rl/sb3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/hci/user-in-the-box/uitb/rl/sb3/callbacks.py b/hci/user-in-the-box/uitb/rl/sb3/callbacks.py new file mode 100644 index 00000000..87c4fd0c --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/callbacks.py @@ -0,0 +1,346 @@ +import numpy as np +import torch + +from typing import Any, Dict, Optional +import os +import warnings + +from 
stable_baselines3.common.callbacks import BaseCallback, EventCallback +#from stable_baselines3.common.evaluation import evaluate_policy +#from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization +from stable_baselines3.common.vec_env import VecEnv, sync_envs_normalization +from .dummy_vec_env import DummyVecEnv +from .evaluation import evaluate_policy + +class LinearStdDecayCallback(BaseCallback): + """ + Linearly decaying standard deviation + + :param initial_log_value: Log initial standard deviation value + :param threshold: Threshold for progress remaining until decay begins + :param min_value: Minimum value for standard deviation + :param verbose: (int) Verbosity level 0: not output 1: info 2: debug + """ + def __init__(self, initial_log_value, threshold, min_value, verbose=0): + super(LinearStdDecayCallback, self).__init__(verbose) + self.initial_value = np.exp(initial_log_value) + self.threshold = threshold + self.min_value = min_value + + def _on_rollout_start(self) -> None: + progress_remaining = self.model._current_progress_remaining + if progress_remaining > self.threshold: + pass + else: + new_std = self.min_value + (progress_remaining/self.threshold) * (self.initial_value-self.min_value) + self.model.policy.log_std.data = torch.tensor(np.log(new_std)).float() + + def _on_training_start(self) -> None: + pass + + def _on_step(self) -> bool: + return True + + def _on_rollout_end(self) -> None: + pass + + def _on_training_end(self) -> None: + pass + + +class LinearCurriculum(BaseCallback): + """ + A callback to implement linear curriculum for one parameter + + :param verbose: (int) Verbosity level 0: not output 1: info 2: debug + """ + + def __init__(self, name, start_value, end_value, end_timestep, start_timestep=0, verbose=0): + super().__init__(verbose) + self.name = name + self.variable = start_value + self.start_value = start_value + self.end_value = end_value + self.start_timestep = start_timestep + self.end_timestep = end_timestep + self.coeff = (end_value - start_value) / (end_timestep - start_timestep) + + def value(self): + return self.variable + + def update(self, num_timesteps): + if num_timesteps <= self.start_timestep: + self.variable = self.start_value + elif self.end_timestep >= num_timesteps > self.start_timestep: + self.variable = self.start_value + self.coeff * (num_timesteps - self.start_timestep) + else: + self.variable = self.end_value + + def _on_training_start(self) -> None: + pass + + def _on_rollout_start(self) -> None: + self.training_env.env_method("callback", self.name, self.num_timesteps) + + def _on_step(self) -> bool: + return True + + def _on_rollout_end(self) -> None: + pass + + def _on_training_end(self) -> None: + pass + + +# class CustomTrainLogCallback(BaseCallback): +# """ +# Log custom values from training envs at each time step. 
+# """ + +# def __init__(self, name, verbose: int = 0): +# super().__init__(verbose=verbose) +# self.name = name + +# def _on_rollout_start(self) -> None: +# # Get list of logging variables +# # self._log_variables = self.training_env.get_attr("task")[0]._info["log_dict"].keys() +# self._log_variables = self.training_env.env_method("get_logdict_keys")[0] +# # print(f"LOG VARIABLES: {self._log_variables}") + +# def _on_step(self) -> bool: +# # Check that the `_info` local variable is defined +# for log_var_key in self._log_variables: +# log_var_val = self.training_env.env_method("get_logdict_value", key=log_var_key) +# # print(f"rollout/{log_var_key}: {log_var_val}") +# log_var_val = np.mean(log_var_val) #TODO: does not work correctly (sometimes None values are returned...) + +# self.logger.record(f"rollout/{log_var_key}", log_var_val) + +# return True + +# def update(self, num_timesteps): +# pass + + +class EvalCallback(EventCallback): + """ + A custom callback for evaluating an agent that derives from ``EventCallback``. + .. warning:: + When using multiple environments, each call to ``env.step()`` + will effectively correspond to ``n_envs`` steps. + To account for that, you can use ``eval_freq = max(eval_freq // n_envs, 1)`` + :param eval_env: The environment used for initialization + :param callback_on_new_best: Callback to trigger + when there is a new best model according to the ``mean_reward`` + :param callback_after_eval: Callback to trigger after every evaluation + :param n_eval_episodes: The number of episodes to test the agent + :param eval_freq: Evaluate the agent every ``eval_freq`` call of the callback. + :param best_model_save_path: Path to a folder where the best model + according to performance on the eval env will be saved. + :param deterministic: Whether the evaluation should + use a stochastic or deterministic actions. 
+ :param info_keywords: extra information to log, from the information return of env.step() + :param render: Whether to render or not the environment during evaluation + :param verbose: (int) Verbosity level 0: no output 1: info 2: debug + :param warn: Passed to ``evaluate_policy`` (warns if ``eval_env`` has not been + wrapped with a Monitor wrapper) + """ + + def __init__(self, eval_env, + callback_on_new_best: Optional[BaseCallback] = None, + callback_after_eval: Optional[BaseCallback] = None, + n_eval_episodes: int = 5, + eval_freq: int = 10000, + best_model_save_path: Optional[str] = None, + deterministic: bool = True, + info_keywords: tuple = (), + render: bool = False, + verbose: int = 1, + warn: bool = True): + super().__init__(callback_after_eval, verbose=verbose) + + self.callback_on_new_best = callback_on_new_best + if self.callback_on_new_best is not None: + # Give access to the parent + self.callback_on_new_best.parent = self + + self.n_eval_episodes = n_eval_episodes + self.eval_freq = eval_freq + self.best_mean_reward = -np.inf + self.last_mean_reward = -np.inf + self.deterministic = deterministic + self.info_keywords = info_keywords + self.render = render + self.warn = warn + + # Convert to VecEnv for consistency + if not isinstance(eval_env, VecEnv): + eval_env = DummyVecEnv([lambda: eval_env]) + + self.eval_env = eval_env + self.best_model_save_path = best_model_save_path + self._is_success_buffer = [] + + def _init_callback(self) -> None: + # Does not work in some corner cases, where the wrapper is not the same + if not isinstance(self.training_env, type(self.eval_env)): + warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}") + + # Create folders if needed + if self.best_model_save_path is not None: + os.makedirs(self.best_model_save_path, exist_ok=True) + + # Init callback called on new best model + if self.callback_on_new_best is not None: + self.callback_on_new_best.init_callback(self.model) + + def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None: + """ + Callback passed to the ``evaluate_policy`` function + in order to log the success rate (when applicable), + for instance when using HER. + :param locals_: + :param globals_: + """ + info = locals_["info"] + + if locals_["terminated"] or locals_["truncated"]: + maybe_is_success = info.get("is_success") + if maybe_is_success is not None: + self._is_success_buffer.append(maybe_is_success) + + def _on_training_start(self) -> None: + pass + + def _on_rollout_start(self) -> None: + pass + + def _on_step(self) -> bool: + + continue_training = True + + if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: + + # Sync training and eval env if there is VecNormalize + if self.model.get_vec_normalize_env() is not None: + try: + sync_envs_normalization(self.training_env, self.eval_env) + except AttributeError as e: + raise AssertionError( + "Training and eval env are not wrapped the same way, " + "see https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html#evalcallback " + "and warning above." 
+ ) from e + + # Reset success rate buffer + self._is_success_buffer = [] + + episode_rewards, episode_lengths, episode_customlogs = evaluate_policy( + self.model, + self.eval_env, + n_eval_episodes=self.n_eval_episodes, + render=self.render, + deterministic=self.deterministic, + info_keywords=self.info_keywords, + return_episode_rewards=True, + warn=self.warn, + callback=self._log_success_callback, + ) + + mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) + mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) + mean_episode_customlogs, std_episode_customlogs = {k: np.mean(v) for k, v in episode_customlogs.items()}, {k: np.std(v) for k, v in episode_customlogs.items()} + self.last_mean_reward = mean_reward + + if self.verbose >= 1: + print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") + print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") + for key in episode_customlogs: + print(f"{key}: {mean_episode_customlogs[key]:.2f} +/- {std_episode_customlogs[key]:.2f}") + # Add to current Logger + self.logger.record("eval/mean_reward", float(mean_reward)) + self.logger.record("eval/mean_ep_length", mean_ep_length) + for key in episode_customlogs: + self.logger.record(f"eval/mean_{key}", mean_episode_customlogs[key]) + +# # Run a few episodes to evaluate progress with deterministic actions +# det_info = self.evaluate(deterministic=True) + +# # Log evaluations +# self.logger.record("evaluate/deterministic/ep_rew_mean", det_info[0]) +# self.logger.record("evaluate/deterministic/ep_len_mean", det_info[1]) +# self.logger.record("evaluate/deterministic/ep_targets_hit_mean", det_info[2]) + + if len(self._is_success_buffer) > 0: + success_rate = np.mean(self._is_success_buffer) + if self.verbose >= 1: + print(f"Success rate: {100 * success_rate:.2f}%") + self.logger.record("eval/success_rate", success_rate) + + + # # Run a few more episodes to evaluate progress without deterministic actions + # if self.stochastic_evals: + # sto_info = self.evaluate(deterministic=False) + # self.logger.record("evaluate/stochastic/ep_rew_mean", sto_info[0]) + # self.logger.record("evaluate/stochastic/ep_len_mean", sto_info[1]) + # self.logger.record("evaluate/stochastic/ep_targets_hit_mean", sto_info[2]) + + # Dump log so the evaluation results are printed with the correct timestep + self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") + self.logger.dump(step=self.num_timesteps) + + # mean_reward = det_info[0] + # self.last_mean_reward = mean_reward + if mean_reward > self.best_mean_reward: + if self.verbose >= 1: + print("New best mean reward!") + if self.best_model_save_path is not None: + self.model.save(os.path.join(self.best_model_save_path, "best_model")) + self.best_mean_reward = mean_reward + # Trigger callback on new best model, if needed + if self.callback_on_new_best is not None: + continue_training = self.callback_on_new_best.on_step() + + # Trigger callback after every evaluation, if needed + if self.callback is not None: + continue_training = continue_training and self._on_event() + + return continue_training + + def _on_rollout_end(self) -> None: + pass + + def _on_training_end(self) -> None: + pass + +# def evaluate(self, deterministic): +# rewards = np.zeros((self.n_eval_episodes,)) +# episode_lengths = np.zeros((self.n_eval_episodes,)) +# # episode_returns = np.zeros((self.n_eval_episodes,)) +# # episode_times = np.zeros((self.n_eval_episodes,)) 
+# targets_hit = np.zeros((self.n_eval_episodes,)) + +# for i in range(self.n_eval_episodes): +# obs = self.eval_env.reset() +# terminated = False +# truncated = False +# while not terminated and not truncated: +# action, _ = self.model.predict(obs, deterministic=deterministic) +# obs, r, terminated, truncated, info = self.eval_env.step(action) +# rewards[i] += r +# input(info) +# episode_lengths[i] = info["episode"]["l"] #self.eval_env.steps +# # episode_returns[i] = info["episode"]["r"] #self.eval_env.steps +# # episode_times[i] = info["episode"]["t"] #self.eval_env.steps +# targets_hit[i] = self.eval_env.trial_idx +# # assert np.allclose(episode_returns[i], rewards[i]) + +# return np.mean(rewards), np.mean(episode_lengths), np.mean(targets_hit) + + def update_child_locals(self, locals_: Dict[str, Any]) -> None: + """ + Update the references to the local variables. + :param locals_: the local variables during rollout collection + """ + if self.callback: + self.callback.update_locals(locals_) \ No newline at end of file diff --git a/hci/user-in-the-box/uitb/rl/sb3/dummy_vec_env.py b/hci/user-in-the-box/uitb/rl/sb3/dummy_vec_env.py new file mode 100644 index 00000000..ca88ecf2 --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/dummy_vec_env.py @@ -0,0 +1,128 @@ +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Callable, List, Optional, Sequence, Type, Union + +import gymnasium as gym +import numpy as np + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn +from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info + + +class DummyVecEnv(VecEnv): + """ + Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current + Python process. This is useful for computationally simple environment such as ``cartpole-v1``, + as the overhead of multiprocess or multithread outweighs the environment computation time. + This can also be used for RL methods that + require a vectorized environment, but that you want a single environments to train with. + + :param env_fns: a list of functions + that return environments to vectorize + """ + + def __init__(self, env_fns: List[Callable[[], gym.Env]]): + self.envs = [fn() for fn in env_fns] + env = self.envs[0] + VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) + obs_space = env.observation_space + self.keys, shapes, dtypes = obs_space_info(obs_space) + + self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]) + self.buf_terminateds = np.zeros((self.num_envs,), dtype=bool) + self.buf_truncateds = np.zeros((self.num_envs,), dtype=bool) + self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) + self.buf_infos = [{} for _ in range(self.num_envs)] + self.actions = None + self.metadata = env.metadata + + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions + + def step_wait(self) -> VecEnvStepReturn: + for env_idx in range(self.num_envs): + obs, self.buf_rews[env_idx], self.buf_terminateds[env_idx], self.buf_truncateds[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( + self.actions[env_idx] + ) + if self.buf_terminateds[env_idx] or self.buf_truncateds[env_idx]: + # save final observation where user can get it, then reset + self.buf_infos[env_idx]["terminal_observation"] = obs + obs, _ = self.envs[env_idx].reset() #TODO: store info returned by reset()? 
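+                # Following the SB3 VecEnv convention, the final observation of the finished episode stays available in
+                # buf_infos[env_idx]["terminal_observation"], while obs now holds the first observation of the next episode.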
+ self._save_obs(env_idx, obs) + return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_terminateds), np.copy(self.buf_truncateds), deepcopy(self.buf_infos)) + + def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + if seed is None: + seed = np.random.randint(0, 2**32 - 1) + seeds = [] + for idx, env in enumerate(self.envs): + seeds.append(env.seed(seed + idx)) + return seeds + + def reset(self) -> VecEnvObs: + for env_idx in range(self.num_envs): + obs, _ = self.envs[env_idx].reset() + self._save_obs(env_idx, obs) + return self._obs_from_buf() + + def close(self) -> None: + for env in self.envs: + env.close() + + def get_images(self) -> Sequence[np.ndarray]: + return [env.render(mode="rgb_array") for env in self.envs] + + def render(self, mode: str = "human") -> Optional[np.ndarray]: + """ + Gym environment rendering. If there are multiple environments then + they are tiled together in one image via ``BaseVecEnv.render()``. + Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the + underlying environment. + + Therefore, some arguments such as ``mode`` will have values that are valid + only when ``num_envs == 1``. + + :param mode: The rendering type. + """ + if self.num_envs == 1: + return self.envs[0].render(mode=mode) + else: + return super().render(mode=mode) + + def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: + for key in self.keys: + if key is None: + self.buf_obs[key][env_idx] = obs + else: + self.buf_obs[key][env_idx] = obs[key] + + def _obs_from_buf(self) -> VecEnvObs: + return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs)) + + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + """Return attribute from vectorized environment (see base class).""" + target_envs = self._get_target_envs(indices) + return [getattr(env_i, attr_name) for env_i in target_envs] + + def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: + """Set attribute inside vectorized environments (see base class).""" + target_envs = self._get_target_envs(indices) + for env_i in target_envs: + setattr(env_i, attr_name, value) + + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + """Call instance methods of vectorized environments.""" + target_envs = self._get_target_envs(indices) + return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs] + + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + """Check if worker environments are wrapped with a given wrapper""" + target_envs = self._get_target_envs(indices) + # Import here to avoid a circular import + from stable_baselines3.common import env_util + + return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs] + + def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]: + indices = self._get_indices(indices) + return [self.envs[i] for i in indices] diff --git a/hci/user-in-the-box/uitb/rl/sb3/evaluation.py b/hci/user-in-the-box/uitb/rl/sb3/evaluation.py new file mode 100644 index 00000000..15a7ece6 --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/evaluation.py @@ -0,0 +1,148 @@ +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import gymnasium as gym +import numpy as np + +from stable_baselines3.common import base_class +# from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, 
is_vecenv_wrapped +from stable_baselines3.common.vec_env import VecEnv, VecMonitor, is_vecenv_wrapped +from .dummy_vec_env import DummyVecEnv + +def evaluate_policy( + model: "base_class.BaseAlgorithm", + env: Union[gym.Env, VecEnv], + n_eval_episodes: int = 10, + deterministic: bool = True, + info_keywords: tuple = (), + render: bool = False, + callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None, + reward_threshold: Optional[float] = None, + return_episode_rewards: bool = False, + warn: bool = True, + ) -> Union[Tuple[float, float, Dict[str, float], Dict[str, float]], Tuple[List[float], List[int], Dict[str, List[float]]]]: + """ + Runs policy for ``n_eval_episodes`` episodes and returns average reward. + If a vector env is passed in, this divides the episodes to evaluate onto the + different elements of the vector env. This static division of work is done to + remove bias. See https://github.com/DLR-RM/stable-baselines3/issues/402 for more + details and discussion. + .. note:: + If environment has not been wrapped with ``Monitor`` wrapper, reward and + episode lengths are counted as they appear with ``env.step`` calls. If + the environment contains wrappers that modify rewards or episode lengths + (e.g. reward scaling, early episode reset), these will affect the evaluation + results as well. You can avoid this by wrapping environment with ``Monitor`` + wrapper before anything else. + :param model: The RL agent you want to evaluate. + :param env: The gym environment or ``VecEnv`` environment. + :param n_eval_episodes: Number of episodes to evaluate the agent + :param deterministic: Whether to use deterministic or stochastic actions + :param info_keywords: extra per-episode information keys to collect from the ``Monitor`` episode info and return as custom logs + :param render: Whether to render the environment or not + :param callback: callback function to do additional checks, + called after each step. Gets locals() and globals() passed as parameters. + :param reward_threshold: Minimum expected reward per episode, + this will raise an error if the performance is not met + :param return_episode_rewards: If True, a list of rewards and episode lengths + per episode will be returned instead of the mean. + :param warn: If True (default), warns user about lack of a Monitor wrapper in the + evaluation environment. + :return: Mean reward per episode, std of reward per episode, and the mean and std of each custom log. + Returns ([float], [int], dict) when ``return_episode_rewards`` is True, first + list containing per-episode rewards, second containing per-episode lengths + (in number of steps) and the dict containing per-episode custom logs. + """ + is_monitor_wrapped = False + # Avoid circular import + from stable_baselines3.common.monitor import Monitor + if not isinstance(env, VecEnv): + env = DummyVecEnv([lambda: env]) + + is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] + + if not is_monitor_wrapped and warn: + warnings.warn( + "Evaluation environment is not wrapped with a ``Monitor`` wrapper. " + "This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. " + "Consider wrapping environment first with ``Monitor`` wrapper.", + UserWarning, + ) + + if not is_monitor_wrapped and len(info_keywords) > 0: + warnings.warn( + f"Cannot store custom logs {info_keywords}, since evaluation environment is not wrapped with a ``Monitor`` wrapper. 
" + "Consider wrapping environment first with ``Monitor`` wrapper.", + UserWarning, + ) + + n_envs = env.num_envs + episode_rewards = [] + episode_lengths = [] + episode_customlogs = {k: [] for k in info_keywords} + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(n_eval_episodes + i) // n_envs for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + observations = env.reset() + states = None + episode_starts = np.ones((env.num_envs,), dtype=bool) + while (episode_counts < episode_count_targets).any(): + actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic) + observations, rewards, terminateds, truncateds, infos = env.step(actions) + current_rewards += rewards + current_lengths += 1 + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + + # unpack values so that the callback can access the local variables + reward = rewards[i] + terminated = terminateds[i] + truncated = truncateds[i] + info = infos[i] + episode_starts[i] = terminated or truncated + + if callback is not None: + callback(locals(), globals()) + + if terminateds[i] or truncateds[i]: + if is_monitor_wrapped: + # Atari wrapper can send a "done" signal when + # the agent loses a life, but it does not correspond + # to the true end of episode + if "episode" in info.keys(): + # Do not trust "done" with episode endings. + # Monitor wrapper includes "episode" key in info if environment + # has been wrapped with it. Use those rewards instead. + episode_rewards.append(info["episode"]["r"]) + episode_lengths.append(info["episode"]["l"]) + for key in info_keywords: + episode_customlogs[key].append(info["episode"][key]) + # Only increment at the real end of an episode + episode_counts[i] += 1 + else: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_counts[i] += 1 + current_rewards[i] = 0 + current_lengths[i] = 0 + + if render: + env.render() + + mean_reward = np.mean(episode_rewards) + std_reward = np.std(episode_rewards) + if reward_threshold is not None: + assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" + + + if return_episode_rewards: + return episode_rewards, episode_lengths, episode_customlogs + else: + mean_episode_customlogs = {k: np.mean(v) for k, v in episode_customlogs.items()} + std_episode_customlogs = {k: np.std(v) for k, v in episode_customlogs.items()} + + return mean_reward, std_reward, mean_episode_customlogs, std_episode_customlogs diff --git a/hci/user-in-the-box/uitb/rl/sb3/feature_extractor.py b/hci/user-in-the-box/uitb/rl/sb3/feature_extractor.py new file mode 100644 index 00000000..3d04bd1c --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/feature_extractor.py @@ -0,0 +1,30 @@ +import gymnasium as gym +import torch as th +from torch import nn + +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor + + +class FeatureExtractor(BaseFeaturesExtractor): + def __init__(self, observation_space: gym.spaces.Dict, encoders): + + # Get the encoder models and define features_dim for parent class + total_concat_size = 0 + extractors = dict() + for key, encoder in encoders.items(): + extractors[key] = encoder.model + total_concat_size += encoder.out_features + + # Initialise parent class + 
super().__init__(observation_space, features_dim=total_concat_size) + + # Convert into ModuleDict + self.extractors = nn.ModuleDict(extractors) + + def forward(self, observations) -> th.Tensor: + encoded_tensor_list = [] + # self.extractors contain nn.Modules that do all the processing. + for key, extractor in self.extractors.items(): + encoded_tensor_list.append(extractor(observations[key])) + # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension. + return th.cat(encoded_tensor_list, dim=1) diff --git a/hci/user-in-the-box/uitb/rl/sb3/policies.py b/hci/user-in-the-box/uitb/rl/sb3/policies.py new file mode 100644 index 00000000..5315af9d --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/policies.py @@ -0,0 +1,711 @@ +import collections +import warnings +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gymnasium as gym +import numpy as np +import torch as th +from torch import nn + +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + DiagGaussianDistribution, + Distribution, + MultiCategoricalDistribution, + StateDependentNoiseDistribution, + make_proba_distribution, +) +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, + create_mlp, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.policies import BasePolicy + + +class ActorCriticPolicyStdDecay(BasePolicy): + """ + Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param std_decay_threshold: If a value (0, 1] is given then std is not learned and instead decays linearly + :param std_decay_min: Minimum std value + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + std_decay_threshold: float = 0.0, + std_decay_min: float = 0.1, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + wandb_id: str = None + ): + + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super(ActorCriticPolicyStdDecay, self).__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output, + ) + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == FlattenExtractor: + net_arch = [dict(pi=[64, 64], vf=[64, 64])] + else: + net_arch = [] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + self.features_dim = self.features_extractor.features_dim + + self.normalize_images = normalize_images + assert 0 <= std_decay_threshold <= 1, "std decay threshold must be included in range [0, 1]" + self.std_decay_threshold = std_decay_threshold + self.std_decay_min = std_decay_min + self.log_std_init = log_std_init + dist_kwargs = None + # Keyword arguments for gSDE distribution + if use_sde: + dist_kwargs = { + "full_std": full_std, + "squash_output": squash_output, + "use_expln": use_expln, + "learn_features": sde_net_arch is not None, + } + + self.sde_features_extractor = None + self.sde_net_arch = sde_net_arch + self.use_sde = use_sde + self.dist_kwargs = dist_kwargs + + # Action distribution + self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) + + self._build(lr_schedule) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + squash_output=default_none_kwargs["squash_output"], + full_std=default_none_kwargs["full_std"], + sde_net_arch=default_none_kwargs["sde_net_arch"], + use_expln=default_none_kwargs["use_expln"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + 
features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, n_envs: int = 1) -> None: + """ + Sample new weights for the exploration matrix. + + :param n_envs: + """ + assert isinstance(self.action_dist, + StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" + self.action_dist.sample_weights(self.log_std, batch_size=n_envs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + # Note: If net_arch is None and some features extractor is used, + # net_arch here is an empty list and mlp_extractor does not + # really contain any layers (acts like an identity module). + self.mlp_extractor = MlpExtractor( + self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device + ) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the networks and the optimizer. + + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self._build_mlp_extractor() + + latent_dim_pi = self.mlp_extractor.latent_dim_pi + + # Separate features extractor for gSDE + if self.sde_net_arch is not None: + self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( + self.features_dim, self.sde_net_arch, self.activation_fn + ) + + if isinstance(self.action_dist, DiagGaussianDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, CategoricalDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, BernoulliDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + else: + raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") + + # If we're doing linearly decaying std, then self.log_std must be excluded from gradient calculation graph + if self.std_decay_threshold > 0: + self.log_std.requires_grad_(False) + + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). 
+ module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + latent_pi, latent_vf, latent_sde = self._get_latent(obs) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Get the latent code (i.e., activations of the last layer of each network) + for the different networks. + + :param obs: Observation + :return: Latent codes + for the actor, the value function and for gSDE function + """ + # Preprocess the observation if needed + features = self.extract_features(obs, self.features_extractor) + latent_pi, latent_vf = self.mlp_extractor(features) + + # Features for sde + latent_sde = latent_pi + if self.sde_features_extractor is not None: + latent_sde = self.sde_features_extractor(features) + return latent_pi, latent_vf, latent_sde + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution: + """ + Retrieve action distribution given the latent codes. + + :param latent_pi: Latent code for the actor + :param latent_sde: Latent code for the gSDE exploration function + :return: Action distribution + """ + mean_actions = self.action_net(latent_pi) + + if isinstance(self.action_dist, DiagGaussianDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std) + elif isinstance(self.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) + else: + raise ValueError("Invalid action distribution") + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. 
+ + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + latent_pi, _, latent_sde = self._get_latent(observation) + distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) + return distribution.get_actions(deterministic=deterministic) + + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: + :param actions: + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + latent_pi, latent_vf, latent_sde = self._get_latent(obs) + distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + +class ActorCriticPolicyTanhActions(BasePolicy): + """ + Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + wandb_id: str = None + ): + + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super(ActorCriticPolicyTanhActions, self).__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output + ) + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = [dict(pi=[64, 64], vf=[64, 64])] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + self.features_dim = self.features_extractor.features_dim + + self.normalize_images = normalize_images + self.log_std_init = log_std_init + dist_kwargs = None + # Keyword arguments for gSDE distribution + if use_sde: + dist_kwargs = { + "full_std": full_std, + "squash_output": squash_output, + "use_expln": use_expln, + "learn_features": False, + } + + if sde_net_arch is not None: + warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning) + + self.use_sde = use_sde + self.dist_kwargs = dist_kwargs + + # Action distribution + self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) + + self._build(lr_schedule) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + squash_output=default_none_kwargs["squash_output"], + full_std=default_none_kwargs["full_std"], + use_expln=default_none_kwargs["use_expln"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, n_envs: int = 1) -> None: + """ + Sample new weights for the exploration matrix. 
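+        Only available when the policy was created with gSDE (``use_sde=True``).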
+ + :param n_envs: + """ + assert isinstance(self.action_dist, + StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" + self.action_dist.sample_weights(self.log_std, batch_size=n_envs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + # Note: If net_arch is None and some features extractor is used, + # net_arch here is an empty list and mlp_extractor does not + # really contain any layers (acts like an identity module). + self.mlp_extractor = MlpExtractor( + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the networks and the optimizer. + + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self._build_mlp_extractor() + + latent_dim_pi = self.mlp_extractor.latent_dim_pi + + if isinstance(self.action_dist, DiagGaussianDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, latent_sde_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, (CategoricalDistribution, MultiCategoricalDistribution, BernoulliDistribution)): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + else: + raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") + + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. + # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs, self.features_extractor) + latent_pi, latent_vf = self.mlp_extractor(features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution: + """ + Retrieve action distribution given the latent codes. 
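+        In this variant the raw network output is squashed with ``tanh`` before being
+        used as the mean of the action distribution (or as the logits for discrete actions).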
+ + :param latent_pi: Latent code for the actor + :return: Action distribution + """ + mean_actions = self.action_net(latent_pi) + mean_actions = th.tanh(mean_actions) + + if isinstance(self.action_dist, DiagGaussianDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std) + elif isinstance(self.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi) + else: + raise ValueError("Invalid action distribution") + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + return self.get_distribution(observation).get_actions(deterministic=deterministic) + + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: + :param actions: + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs, self.features_extractor) + latent_pi, latent_vf = self.mlp_extractor(features) + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def get_distribution(self, obs: th.Tensor) -> Distribution: + """ + Get the current policy distribution given the observations. + + :param obs: + :return: the action distribution. + """ + features = self.extract_features(obs, self.features_extractor) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) + + def predict_values(self, obs: th.Tensor) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: + :return: the estimated values. + """ + features = self.extract_features(obs, self.features_extractor) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) + +class MultiInputActorCriticPolicyTanhActions(ActorCriticPolicyTanhActions): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space (Tuple) + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Uses the CombinedExtractor + :param features_extractor_kwargs: Keyword arguments + to pass to the feature extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + wandb_id: str = None + ): + super(MultiInputActorCriticPolicyTanhActions, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) diff --git a/hci/user-in-the-box/uitb/rl/sb3/recurrent_policies.py b/hci/user-in-the-box/uitb/rl/sb3/recurrent_policies.py new file mode 100644 index 00000000..26b70336 --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/recurrent_policies.py @@ -0,0 +1,461 @@ +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gymnasium as gym +import numpy as np +import torch as th +from stable_baselines3.common.distributions import Distribution +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import zip_strict +from torch import nn + +from sb3_contrib.common.recurrent.type_aliases import RNNStates + +from uitb.rl.sb3.policies import ActorCriticPolicyTanhActions + + +class RecurrentActorCriticPolicyTanhActions(ActorCriticPolicyTanhActions): + """ + Recurrent policy class for actor-critic algorithms (has both policy and value 
prediction).
+    To be used with A2C, PPO and the likes.
+
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param lr_schedule: Learning rate schedule (could be constant)
+    :param net_arch: The specification of the policy and value networks.
+    :param activation_fn: Activation function
+    :param ortho_init: Whether to use or not orthogonal initialization
+    :param use_sde: Whether to use State Dependent Exploration or not
+    :param log_std_init: Initial value for the log standard deviation
+    :param full_std: Whether to use (n_features x n_actions) parameters
+        for the std instead of only (n_features,) when using gSDE
+    :param sde_net_arch: Network architecture for extracting features
+        when using gSDE. If None, the latent features from the policy will be used.
+        Pass an empty list to use the states as features.
+    :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure
+        a positive standard deviation (cf paper). It allows to keep variance
+        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
+    :param squash_output: Whether to squash the output using a tanh function,
+        this allows to ensure boundaries when using gSDE.
+    :param features_extractor_class: Features extractor to use.
+    :param features_extractor_kwargs: Keyword arguments
+        to pass to the features extractor.
+    :param normalize_images: Whether to normalize images or not,
+        dividing by 255.0 (True by default)
+    :param optimizer_class: The optimizer to use,
+        ``th.optim.Adam`` by default
+    :param optimizer_kwargs: Additional keyword arguments,
+        excluding the learning rate, to pass to the optimizer
+    :param lstm_hidden_size: Number of hidden units for each LSTM layer.
+    :param n_lstm_layers: Number of LSTM layers.
+    :param shared_lstm: Whether the LSTM is shared between the actor and the critic.
+        By default, only the actor has a recurrent network.
+    :param enable_critic_lstm: Use a separate LSTM for the critic.
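+
+    ``shared_lstm`` and ``enable_critic_lstm`` are mutually exclusive; if neither is set,
+    the critic uses a plain linear projection of the extracted features instead of an LSTM.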
+ """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + shared_lstm: bool = False, + enable_critic_lstm: bool = False, + ): + self.lstm_output_dim = lstm_hidden_size + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + self.shared_lstm = shared_lstm + self.enable_critic_lstm = enable_critic_lstm + self.lstm_actor = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) + self.lstm_shape = (n_lstm_layers, 1, lstm_hidden_size) + self.critic = None + self.lstm_critic = None + assert not ( + self.shared_lstm and self.enable_critic_lstm + ), "You must choose between shared LSTM, seperate or no LSTM for the critic" + + if not (self.shared_lstm or self.enable_critic_lstm): + self.critic = nn.Linear(self.features_dim, lstm_hidden_size) + + if self.enable_critic_lstm: + self.lstm_critic = nn.LSTM(self.features_dim, lstm_hidden_size, num_layers=n_lstm_layers) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. 
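+        Note that here the extractor operates on the LSTM output (``lstm_output_dim``
+        features) rather than directly on the extracted observation features.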
+ """ + self.mlp_extractor = MlpExtractor( + self.lstm_output_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + @staticmethod + def _process_sequence( + features: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + lstm: nn.LSTM, + ) -> Tuple[th.Tensor, th.Tensor]: + # LSTM logic + # (sequence length, n_envs, features dim) (batch size = n envs_old_to_be_removed) + n_envs = lstm_states[0].shape[1] + # Batch to sequence + features_sequence = features.reshape((n_envs, -1, lstm.input_size)).swapaxes(0, 1) + episode_starts = episode_starts.reshape((n_envs, -1)).swapaxes(0, 1) + + lstm_output = [] + # Iterate over the sequence + for features, episode_start in zip_strict(features_sequence, episode_starts): + hidden, lstm_states = lstm( + features.unsqueeze(dim=0), + ( + (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[0], + (1.0 - episode_start).view(1, n_envs, 1) * lstm_states[1], + ), + ) + lstm_output += [hidden] + # Sequence to batch + lstm_output = th.flatten(th.cat(lstm_output).transpose(0, 1), start_dim=0, end_dim=1) + return lstm_output, lstm_states + + def forward( + self, + obs: th.Tensor, + lstm_states: RNNStates, + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, RNNStates]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation. Observation + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + # Preprocess the observation if needed + features = self.extract_features(obs, self.features_extractor) + # latent_pi, latent_vf = self.mlp_extractor(features) + latent_pi, lstm_states_pi = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Re-use LSTM features but do not backpropagate + latent_vf = latent_pi.detach() + lstm_states_vf = (lstm_states_pi[0].detach(), lstm_states_pi[1].detach()) + else: + latent_vf = self.critic(features) + lstm_states_vf = lstm_states_pi + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob, RNNStates(lstm_states_pi, lstm_states_vf) + + def get_distribution( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> Tuple[Distribution, Tuple[th.Tensor, ...]]: + """ + Get the current policy distribution given the observations. + + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the action distribution and new hidden states. 
+ """ + features = self.extract_features(obs, self.features_extractor) + latent_pi, lstm_states = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + return self._get_action_dist_from_latent(latent_pi), lstm_states + + def predict_values( + self, + obs: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + ) -> th.Tensor: + """ + Get the estimated values according to the current policy given the observations. + + :param obs: Observation. + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: the estimated values. + """ + features = self.extract_features(obs, self.features_extractor) + if self.lstm_critic is not None: + latent_vf, lstm_states_vf = self._process_sequence(features, lstm_states, episode_starts, self.lstm_critic) + elif self.shared_lstm: + # Use LSTM from the actor + latent_pi, _ = self._process_sequence(features, lstm_states, episode_starts, self.lstm_actor) + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + return self.value_net(latent_vf) + + def evaluate_actions( + self, + obs: th.Tensor, + actions: th.Tensor, + lstm_states: RNNStates, + episode_starts: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: Observation. + :param actions: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + # Preprocess the observation if needed + features = self.extract_features(obs, self.features_extractor) + latent_pi, _ = self._process_sequence(features, lstm_states.pi, episode_starts, self.lstm_actor) + + if self.lstm_critic is not None: + latent_vf, _ = self._process_sequence(features, lstm_states.vf, episode_starts, self.lstm_critic) + elif self.shared_lstm: + latent_vf = latent_pi.detach() + else: + latent_vf = self.critic(features) + + latent_pi = self.mlp_extractor.forward_actor(latent_pi) + latent_vf = self.mlp_extractor.forward_critic(latent_vf) + + distribution = self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def _predict( + self, + observation: th.Tensor, + lstm_states: Tuple[th.Tensor, th.Tensor], + episode_starts: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, Tuple[th.Tensor, ...]]: + """ + Get the action according to the policy for a given observation. + + :param observation: + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). 
+ :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy and hidden states of the RNN + """ + distribution, lstm_states = self.get_distribution(observation, lstm_states, episode_starts) + return distribution.get_actions(deterministic=deterministic), lstm_states + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[Tuple[np.ndarray, ...]] = None, + episode_start: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + """ + Get the policy action from an observation (and optional hidden state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param lstm_states: The last hidden and memory states for the LSTM. + :param episode_starts: Whether the observations correspond to new episodes + or not (we reset the lstm states in that case). + :param deterministic: Whether or not to return deterministic actions. + :return: the model's action and the next hidden state + (used in recurrent policies) + """ + # Switch to eval mode (this affects batch norm / dropout) + self.set_training_mode(False) + + observation, vectorized_env = self.obs_to_tensor(observation) + + if isinstance(observation, dict): + n_envs = observation[list(observation.keys())[0]].shape[0] + else: + n_envs = observation.shape[0] + # state : (n_layers, n_envs, dim) + if state is None: + # Initialize hidden states to zeros + state = np.concatenate([np.zeros(self.lstm_shape) for _ in range(n_envs)], axis=1) + state = (state, state) + + if episode_start is None: + episode_start = np.array([False for _ in range(n_envs)]) + + with th.no_grad(): + # Convert to PyTorch tensors + states = th.tensor(state[0]).float().to(self.device), th.tensor(state[1]).float().to(self.device) + episode_starts = th.tensor(episode_start).float().to(self.device) + actions, states = self._predict( + observation, lstm_states=states, episode_starts=episode_starts, deterministic=deterministic + ) + states = (states[0].cpu().numpy(), states[1].cpu().numpy()) + + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + # Remove batch dimension if needed + if not vectorized_env: + actions = actions[0] + + return actions, states + + +class RecurrentMultiInputActorCriticPolicyTanhActions(RecurrentActorCriticPolicyTanhActions): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param lstm_hidden_size: Number of hidden units for each LSTM layer. + :param n_lstm_layers: Number of LSTM layers. + :param shared_lstm: Whether the LSTM is shared between the actor and the critic. + By default, only the actor has a recurrent network. + :param enable_critic_lstm: Use a seperate LSTM for the critic. + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + lstm_hidden_size: int = 256, + n_lstm_layers: int = 1, + enable_critic_lstm: bool = False, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + lstm_hidden_size, + n_lstm_layers, + enable_critic_lstm, + ) diff --git a/hci/user-in-the-box/uitb/rl/sb3/schedule.py b/hci/user-in-the-box/uitb/rl/sb3/schedule.py new file mode 100644 index 00000000..0c460534 --- /dev/null +++ b/hci/user-in-the-box/uitb/rl/sb3/schedule.py @@ -0,0 +1,27 @@ +from typing import Callable + +def linear_schedule(initial_value: float, min_value: float, threshold: float = 1.0) -> Callable[[float], float]: + """ + Linear learning rate schedule. Adapted from the example at + https://stable-baselines3.readthedocs.io/en/master/guide/examples.html#learning-rate-schedule + + :param initial_value: Initial learning rate. + :param min_value: Minimum learning rate. 
+  :param threshold: Remaining-progress threshold at which the decay begins; while
+    ``progress_remaining`` is above it, the initial learning rate is kept constant.
+  :return: schedule that computes the
+    current learning rate depending on remaining progress
+  """
+
+  def func(progress_remaining: float) -> float:
+    """
+    Progress will decrease from 1 (beginning of training) to 0.
+
+    :param progress_remaining: Remaining training progress, in [0, 1].
+    :return: current learning rate
+    """
+    if progress_remaining > threshold:
+      return initial_value
+    else:
+      return min_value + (progress_remaining/threshold) * (initial_value - min_value)
+
+  return func
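
As a quick usage sketch (not part of the diff above; the environment id and the numeric values are illustrative assumptions), the schedule can be passed anywhere stable-baselines3 accepts a callable learning rate:

    from stable_baselines3 import PPO
    from uitb.rl.sb3.schedule import linear_schedule

    # Keep the initial rate until progress_remaining drops below 0.7 (i.e. after roughly
    # 30% of training), then decay linearly from 3e-4 down to 1e-5.
    lr = linear_schedule(initial_value=3e-4, min_value=1e-5, threshold=0.7)

    model = PPO("MlpPolicy", "Pendulum-v1", learning_rate=lr)  # illustrative env id
    model.learn(total_timesteps=10_000)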