diff --git a/.github/unittest/linux_libs/scripts_botorch/environment.yml b/.github/unittest/linux_libs/scripts_botorch/environment.yml new file mode 100644 index 00000000000..c6a5013405e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_botorch/environment.yml @@ -0,0 +1,23 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-json-report + - pytest-error-for-skips + - expecttest + - pybind11[global] + - pyyaml + - scipy + - botorch + - gpytorch + - psutil diff --git a/.github/unittest/linux_libs/scripts_botorch/install.sh b/.github/unittest/linux_libs/scripts_botorch/install.sh new file mode 100755 index 00000000000..395e15ea99e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_botorch/install.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION + +set -euxo pipefail + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with cu128" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U 
+  else
+    pip3 install torch --index-url https://download.pytorch.org/whl/cu128 -U
+  fi
+else
+  printf "Failed to install pytorch"
+  exit 1
+fi
+
+# install tensordict
+pip install git+https://github.com/pytorch/tensordict.git --progress-bar off
+
+# smoke test
+python -c "import functorch;import tensordict"
+
+printf "* Installing torchrl\n"
+python -m pip install -e . --no-build-isolation
+
+# smoke test
+python -c "import torchrl"
diff --git a/.github/unittest/linux_libs/scripts_botorch/post_process.sh b/.github/unittest/linux_libs/scripts_botorch/post_process.sh
new file mode 100755
index 00000000000..e97bf2a7b1b
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_botorch/post_process.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+set -e
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
diff --git a/.github/unittest/linux_libs/scripts_botorch/run_test.sh b/.github/unittest/linux_libs/scripts_botorch/run_test.sh
new file mode 100755
index 00000000000..3d732357ef6
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_botorch/run_test.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+
+set -euxo pipefail
+
+eval "$(./conda/bin/conda shell.bash hook)"
+conda activate ./env
+
+export PYTORCH_TEST_WITH_SLOW='1'
+export LAZY_LEGACY_OP=False
+
+python -m torch.utils.collect_env
+git config --global --add safe.directory '*'
+
+root_dir="$(git rev-parse --show-toplevel)"
+env_dir="${root_dir}/env"
+lib_dir="${env_dir}/lib"
+
+export MKL_THREADING_LAYER=GNU
+
+# smoke test
+python -c "import botorch; print('botorch', botorch.__version__)"
+python -c "import gpytorch; print('gpytorch', gpytorch.__version__)"
+
+# JSON report for flaky test tracking
+json_report_dir="${RUNNER_ARTIFACT_DIR:-${root_dir}}"
+json_report_args="--json-report --json-report-file=${json_report_dir}/test-results-botorch.json --json-report-indent=2"
+
+python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_objectives.py ${json_report_args} --instafail -v --durations 200 \
--capture no -k TestGPWorldModel --error-for-skips
+coverage combine -q
+coverage xml -i
+
+# Upload test results with metadata for flaky tracking
+python .github/unittest/helpers/upload_test_results.py || echo "Warning: Failed to process test results for flaky tracking"
diff --git a/.github/unittest/linux_libs/scripts_botorch/setup_env.sh b/.github/unittest/linux_libs/scripts_botorch/setup_env.sh
new file mode 100755
index 00000000000..d7dbd1bb7e6
--- /dev/null
+++ b/.github/unittest/linux_libs/scripts_botorch/setup_env.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+
+set -euxo pipefail
+
+apt-get update && apt-get upgrade -y && apt-get install -y git cmake
+git config --global --add safe.directory '*'
+apt-get install -y wget gcc g++
+
+this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+root_dir="$(git rev-parse --show-toplevel)"
+conda_dir="${root_dir}/conda"
+env_dir="${root_dir}/env"
+
+cd "${root_dir}"
+
+case "$(uname -s)" in
+    Darwin*) os=MacOSX;;
+    *) os=Linux
+esac
+
+# 1. Install conda at ./conda
+if [ ! -d "${conda_dir}" ]; then
+    printf "* Installing conda\n"
+    wget -O miniconda.sh "https://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh"
+    bash ./miniconda.sh -b -f -p "${conda_dir}"
+fi
+eval "$(${conda_dir}/bin/conda shell.bash hook)"
+
+# 2. Create test environment at ./env
+printf "python: ${PYTHON_VERSION}\n"
+if [ ! -d "${env_dir}" ]; then
+    printf "* Creating a test environment\n"
+    conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
+fi
+conda activate "${env_dir}"
+
+# 3.
Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index b1840a11fef..0f413792711 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -93,6 +93,44 @@ jobs: bash .github/unittest/linux_libs/scripts_brax/run_all.sh + unittests-botorch: + strategy: + matrix: + python_version: ["3.10"] + cuda_arch_version: ["12.8"] + if: ${{ github.event_name == 'push' || github.event_name == 'workflow_call' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'Modules') || contains(github.event.pull_request.labels.*.name, 'Objectives') }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "12.8" + docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.10" + export CU_VERSION="12.8" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export TD_GET_DEFAULTS_TO_NONE=1 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_botorch/setup_env.sh + bash .github/unittest/linux_libs/scripts_botorch/install.sh + bash .github/unittest/linux_libs/scripts_botorch/run_test.sh + bash .github/unittest/linux_libs/scripts_botorch/post_process.sh + # unittests-d4rl: # strategy: # matrix: diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst index 
bf6ba8b9a96..66f91682377 100644 --- a/docs/source/reference/envs_api.rst +++ b/docs/source/reference/envs_api.rst @@ -191,6 +191,7 @@ Domain-specific ModelBasedEnvBase model_based.dreamer.DreamerEnv model_based.dreamer.DreamerDecoder + model_based.imagined.ImaginedEnv Helpers ------- diff --git a/docs/source/reference/envs_transforms.rst b/docs/source/reference/envs_transforms.rst index e3f8ab55fab..b345c493131 100644 --- a/docs/source/reference/envs_transforms.rst +++ b/docs/source/reference/envs_transforms.rst @@ -273,6 +273,7 @@ Available Transforms Hash InitTracker LineariseRewards + MeanActionSelector ModuleTransform MultiAction NoopResetEnv diff --git a/docs/source/reference/modules_models.rst b/docs/source/reference/modules_models.rst index be3e74ef0c7..68891fe4e67 100644 --- a/docs/source/reference/modules_models.rst +++ b/docs/source/reference/modules_models.rst @@ -16,3 +16,15 @@ Modules for model-based reinforcement learning, including world models and dynam RSSMPosterior RSSMPrior RSSMRollout + +PILCO +----- + +Components for moment-matching model-based policy search (PILCO). + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + GPWorldModel + RBFController diff --git a/docs/source/reference/objectives_other.rst b/docs/source/reference/objectives_other.rst index 018268ed7f6..b97d3efea50 100644 --- a/docs/source/reference/objectives_other.rst +++ b/docs/source/reference/objectives_other.rst @@ -15,3 +15,4 @@ Additional loss modules for specialized algorithms. 
DreamerActorLoss DreamerModelLoss DreamerValueLoss + ExponentialQuadraticCost diff --git a/pyproject.toml b/pyproject.toml index 9b5b4a22d2f..f83ac2a7b57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,10 @@ marl = ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot; python_version>='3 open_spiel = ["open_spiel>=1.5"] brax = ["jax>=0.7.0; python_version>='3.11'", "brax; python_version>='3.11'"] procgen = ["procgen"] +pilco = [ + "botorch", + "gpytorch", +] # Base LLM dependencies (no inference backend - use llm-vllm or llm-sglang) llm = [ "transformers", diff --git a/setup-and-run.sh b/setup-and-run.sh index f4414cf1b38..c45565327cf 100755 --- a/setup-and-run.sh +++ b/setup-and-run.sh @@ -30,6 +30,7 @@ REPO_DIR="/root/rl" VENV_DIR="/root/torchrl_venv" MODE="isaac" # "isaac" or "dmcontrol" BUILD_ONLY=false +GPUS="" # explicit GPU set, e.g. "3,4,5" EXTRA_ARGS=() # extra Hydra overrides forwarded to the training script # ---- Parse arguments -------------------------------------------------------- @@ -38,6 +39,7 @@ for arg in "$@"; do --build-only) BUILD_ONLY=true ;; --dmcontrol) MODE="dmcontrol" ;; --isaac) MODE="isaac" ;; + --gpus=*) GPUS="${arg#--gpus=}" ;; *) EXTRA_ARGS+=("$arg") ;; esac done @@ -45,15 +47,38 @@ done # Avoid "'': unknown terminal type" in headless containers export TERM="${TERM:-xterm}" +# Resolve GPU set early so we can use it for zombie cleanup +if [[ -n "$GPUS" ]]; then + export CUDA_VISIBLE_DEVICES="$GPUS" +elif [[ "$MODE" == "isaac" ]]; then + export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2}" +fi + echo "============================================================" echo " setup-and-run.sh" echo " mode=$MODE build_only=$BUILD_ONLY" +echo " gpus=${CUDA_VISIBLE_DEVICES:-}" echo " extra_args=${EXTRA_ARGS[*]:-}" echo "============================================================" -# ---- 0) Kill zombie Python processes from previous runs --------------------- -echo "* Killing leftover Python processes..." 
-pkill -9 -f python || true +# ---- 0) Kill zombie Python processes on the SAME GPUs ---------------------- +# Only kill dreamer processes whose CUDA_VISIBLE_DEVICES matches ours, +# so that a second experiment on different GPUs is left untouched. +echo "* Killing leftover dreamer processes on GPUs=${CUDA_VISIBLE_DEVICES:-}..." +if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + # Find dreamer_isaac.py PIDs whose /proc//environ contains our GPU set + for pid in $(pgrep -f "dreamer_isaac.py|dreamer.py" 2>/dev/null || true); do + proc_env=$(tr '\0' '\n' < /proc/$pid/environ 2>/dev/null || true) + proc_gpus=$(echo "$proc_env" | grep '^CUDA_VISIBLE_DEVICES=' | head -1 | cut -d= -f2) + if [[ "$proc_gpus" == "$CUDA_VISIBLE_DEVICES" ]] || [[ -z "$proc_gpus" ]]; then + echo " Killing PID $pid (CUDA_VISIBLE_DEVICES=$proc_gpus)" + kill -9 "$pid" 2>/dev/null || true + fi + done +else + # No GPU constraint — kill all dreamer processes + pkill -9 -f "dreamer_isaac.py|dreamer.py" || true +fi sleep 1 # ---- 1) System dependencies ------------------------------------------------ @@ -203,8 +228,8 @@ echo "============================================================" cd "$REPO_DIR" if [[ "$MODE" == "isaac" ]]; then - # Expose 3 GPUs: GPU 0 = sim, GPU 1 = training, GPU 2 = eval (rendering) - export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2}" + # GPUs already set above: GPU0 = sim, GPU1 = training, GPU2 = eval (rendering) + echo " CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" $PYTHON "sota-implementations/dreamer/dreamer_isaac.py" "${EXTRA_ARGS[@]}" else export MUJOCO_GL=egl diff --git a/sota-check/run_pilco.sh b/sota-check/run_pilco.sh new file mode 100644 index 00000000000..393b2ed7332 --- /dev/null +++ b/sota-check/run_pilco.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=pilco +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/pilco_%j.txt +#SBATCH --error=slurm_errors/pilco_%j.txt + +current_commit=$(git rev-parse 
--short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="pilco" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/pilco/pilco.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? +# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-implementations/pilco/config.yaml b/sota-implementations/pilco/config.yaml new file mode 100644 index 00000000000..93ee8ff126e --- /dev/null +++ b/sota-implementations/pilco/config.yaml @@ -0,0 +1,18 @@ +env: + env_name: InvertedPendulum-v5 + library: gym +device: null +logger: + backend: wandb + project_name: torchrl_pilco + group_name: null + video: True +optim: + policy_lr: 5e-3 +pilco: + horizon: 40 + initial_rollout_length: 200 + max_rollout_length: 350 + epochs: 3 + policy_training_steps: 100 + policy_n_basis: 10 diff --git a/sota-implementations/pilco/pilco.py b/sota-implementations/pilco/pilco.py new file mode 100644 index 00000000000..63fb765f169 --- /dev/null +++ b/sota-implementations/pilco/pilco.py @@ -0,0 +1,159 @@ +import hydra +import tensordict +import torch +from omegaconf import DictConfig + +from tensordict import TensorDict +from tensordict.nn import TensorDictModule +from torchrl._utils import get_available_device +from torchrl.envs import EnvBase, TransformedEnv +from torchrl.envs.model_based import ImaginedEnv +from torchrl.envs.transforms import MeanActionSelector +from torchrl.envs.utils import RandomPolicy +from torchrl.modules.models import GPWorldModel, RBFController +from torchrl.objectives import ExponentialQuadraticCost +from torchrl.record.loggers import generate_exp_name, get_logger, Logger + +from utils import make_env + + +def pilco_loop( + cfg: DictConfig, env: 
EnvBase, logger: Logger | None = None +) -> TensorDictModule: + obs_dim = env.observation_spec["observation"].shape[-1] + action_dim = env.action_spec.shape[-1] + + random_policy = RandomPolicy(action_spec=env.action_spec) + rollout = env.rollout( + max_steps=cfg.pilco.initial_rollout_length, + policy=random_policy, + break_when_all_done=False, + break_when_any_done=False, + ) + + base_policy = ( + RBFController( + input_dim=obs_dim, + output_dim=action_dim, + n_basis=cfg.pilco.policy_n_basis, + max_action=env.action_spec.high, + ) + .to(env.device) + .double() + ) + policy_module = TensorDictModule( + module=base_policy, + in_keys=[("observation", "mean"), ("observation", "var")], + out_keys=[ + ("action", "mean"), + ("action", "var"), + ("action", "cross_covariance"), + ], + ) + optimizer = torch.optim.Adam(policy_module.parameters(), lr=cfg.optim.policy_lr) + + dtype = torch.float64 + initial_observation = TensorDict( + { + ("observation", "mean"): torch.zeros( + obs_dim, device=env.device, dtype=dtype + ), + ("observation", "var"): torch.eye(obs_dim, device=env.device, dtype=dtype) + * 1e-3, + } + ) + + eval_env = TransformedEnv(env, MeanActionSelector()) + + cost_module = ExponentialQuadraticCost(reduction="none").to(env.device) + for epoch in range(cfg.pilco.epochs): + base_world_model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim).to( + env.device + ) + base_world_model.fit(rollout) + base_world_model.eval() + + imagined_env = ImaginedEnv( + world_model_module=base_world_model, + base_env=env, + ) + reset_td = initial_observation.expand(*imagined_env.batch_size) + + for step in range(cfg.pilco.policy_training_steps): + logger_step = (epoch * cfg.pilco.policy_training_steps) + step + optimizer.zero_grad() + + imagined_data = imagined_env.rollout( + max_steps=cfg.pilco.horizon, + policy=policy_module, + tensordict=reset_td, + ) + + loss_td = cost_module(imagined_data) + loss = loss_td.get("loss_cost").sum(dim=-1).mean() + + loss.backward() + 
optimizer.step()
+
+            if logger:
+                logger.log_scalar(
+                    "train/trajectory_cost", loss.item(), step=logger_step
+                )
+
+        test_rollout = eval_env.rollout(
+            max_steps=100,
+            policy=policy_module,
+            break_when_any_done=True,
+        )
+
+        reward = test_rollout["episode_reward"][-1].tolist()
+        steps = test_rollout["step_count"].max().tolist()
+
+        if logger:
+            logger.log_scalar("eval/reward", reward, step=logger_step)
+            logger.log_scalar("eval/steps", steps, step=logger_step)
+
+        test_rollout.set("observation", test_rollout.get(("observation", "mean")))
+        test_rollout.set("action", test_rollout.get(("action", "mean")))
+        test_rollout.set(
+            ("next", "observation"), test_rollout.get(("next", "observation", "mean"))
+        )
+
+        test_rollout = test_rollout.select(
+            *rollout.keys(include_nested=True, leaves_only=True)
+        )
+        rollout = tensordict.cat([rollout, test_rollout], dim=0)
+
+        if len(rollout) > cfg.pilco.max_rollout_length:
+            rollout = rollout[-cfg.pilco.max_rollout_length :]
+
+    return policy_module
+
+
+@hydra.main(config_path="", config_name="config", version_base="1.1")
+def main(cfg: DictConfig) -> None:
+    device = torch.device(cfg.device) if cfg.device else get_available_device()
+
+    env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video)
+
+    if cfg.logger.backend:
+        exp_name = generate_exp_name("PILCO", cfg.env.env_name)
+        logger = get_logger(
+            cfg.logger.backend,
+            logger_name="pilco",
+            experiment_name=exp_name,
+            wandb_kwargs={
+                "config": dict(cfg),
+                "project": cfg.logger.project_name,
+                "group": cfg.logger.group_name,
+            },
+        )
+
+    pilco_loop(cfg, env, logger=logger if cfg.logger.backend else None)
+
+    if not env.is_closed:
+        env.close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/sota-implementations/pilco/utils.py b/sota-implementations/pilco/utils.py
new file mode 100644
index 00000000000..034cc90ca24
--- /dev/null
+++ b/sota-implementations/pilco/utils.py
@@ -0,0 +1,14 @@
+import torch
+from torchrl.envs import GymEnv, RewardSum, StepCounter, TransformedEnv
+
+
+def
make_env( + env_name: str, device: str | torch.device, from_pixels: bool = False +) -> TransformedEnv: + """Creates the transformed environment for PILCO experiments.""" + env = TransformedEnv( + GymEnv(env_name, pixels_only=False, from_pixels=from_pixels, device=device) + ) + env.append_transform(RewardSum()) + env.append_transform(StepCounter()) + return env diff --git a/test/test_objectives.py b/test/test_objectives.py index 39ad38aa2f6..13c4d71eba8 100644 --- a/test/test_objectives.py +++ b/test/test_objectives.py @@ -58,7 +58,8 @@ from torchrl.envs import EnvBase, GymEnv, InitTracker, SerialEnv from torchrl.envs.libs.gym import _has_gym from torchrl.envs.model_based.dreamer import DreamerEnv -from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv +from torchrl.envs.model_based.imagined import ImaginedEnv +from torchrl.envs.transforms import MeanActionSelector, TensorDictPrimer, TransformedEnv from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type from torchrl.modules import ( DistributionalQValueActor, @@ -82,6 +83,7 @@ RSSMRollout, ) from torchrl.modules.models.models import MLP +from torchrl.modules.models.rbf_controller import RBFController from torchrl.modules.tensordict_module.actors import ( Actor, ActorCriticOperator, @@ -105,6 +107,7 @@ DreamerModelLoss, DreamerValueLoss, DTLoss, + ExponentialQuadraticCost, GAILLoss, IQLLoss, KLPENPPOLoss, @@ -175,6 +178,7 @@ FUNCTORCH_ERR = str(err) _has_transformers = bool(importlib.util.find_spec("transformers")) +_has_botorch = bool(importlib.util.find_spec("botorch")) TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) IS_WINDOWS = sys.platform == "win32" @@ -18295,6 +18299,646 @@ def test_make_value_estimator_with_gae_instance(self, device): assert loss_fn.value_type is GAE +class TestRBFController: + @pytest.mark.parametrize("input_dim", [2, 4]) + @pytest.mark.parametrize("output_dim", [1, 3]) + @pytest.mark.parametrize("n_basis", [5, 10]) + 
def test_forward_shapes(self, input_dim, output_dim, n_basis): + max_action = torch.ones(output_dim) + controller = RBFController( + input_dim=input_dim, + output_dim=output_dim, + max_action=max_action, + n_basis=n_basis, + ).double() + + batch_size = 3 + mean = torch.randn(batch_size, input_dim, dtype=torch.float64) + cov = ( + torch.eye(input_dim, dtype=torch.float64) + .unsqueeze(0) + .expand(batch_size, -1, -1) + * 0.1 + ) + + action_mean, action_cov, cross_cov = controller(mean, cov) + + assert action_mean.shape == (batch_size, output_dim) + assert action_cov.shape == (batch_size, output_dim, output_dim) + assert cross_cov.shape == (batch_size, input_dim, output_dim) + + def test_action_covariance_is_symmetric(self): + controller = RBFController( + input_dim=4, output_dim=2, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + cov = torch.eye(4, dtype=torch.float64).unsqueeze(0).expand(2, -1, -1) * 0.1 + + _, action_cov, _ = controller(mean, cov) + + torch.testing.assert_close( + action_cov, action_cov.transpose(-2, -1), atol=1e-6, rtol=1e-5 + ) + + def test_action_covariance_is_positive_semidefinite(self): + controller = RBFController( + input_dim=4, output_dim=2, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + cov = torch.eye(4, dtype=torch.float64).unsqueeze(0).expand(2, -1, -1) * 0.1 + + _, action_cov, _ = controller(mean, cov) + + eigenvalues = torch.linalg.eigvalsh(action_cov) + assert ( + eigenvalues >= -1e-6 + ).all(), f"Negative eigenvalues found: {eigenvalues}" + + @pytest.mark.parametrize("max_action", [0.5, 1.0, 2.0]) + def test_squash_sin_bounds(self, max_action): + mean = torch.randn(10, 3, dtype=torch.float64) + cov = torch.eye(3, dtype=torch.float64).unsqueeze(0).expand(10, -1, -1) * 0.01 + + squashed_mean, squashed_cov, cross_cov = RBFController.squash_sin( + mean, cov, max_action + ) + + assert (squashed_mean.abs() <= max_action + 1e-6).all() + assert 
squashed_cov.shape == (10, 3, 3) + assert cross_cov.shape == (10, 3, 3) + + def test_deterministic_with_zero_variance(self): + controller = RBFController( + input_dim=4, output_dim=1, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + zero_cov = torch.zeros(2, 4, 4, dtype=torch.float64) + + action_mean1, _, _ = controller(mean, zero_cov) + action_mean2, _, _ = controller(mean, zero_cov) + + torch.testing.assert_close(action_mean1, action_mean2) + + def test_gradients_flow(self): + controller = RBFController( + input_dim=4, output_dim=1, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + cov = torch.eye(4, dtype=torch.float64).unsqueeze(0).expand(2, -1, -1) * 0.1 + + action_mean, action_cov, cross_cov = controller(mean, cov) + loss = action_mean.sum() + action_cov.sum() + loss.backward() + + for name, param in controller.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + + def test_as_tensordict_module(self): + controller = RBFController( + input_dim=4, output_dim=1, max_action=1.0, n_basis=5 + ).double() + + module = TensorDictModule( + module=controller, + in_keys=[("observation", "mean"), ("observation", "var")], + out_keys=[ + ("action", "mean"), + ("action", "var"), + ("action", "cross_covariance"), + ], + ) + + td = TensorDict( + { + ("observation", "mean"): torch.randn(2, 4, dtype=torch.float64), + ("observation", "var"): torch.eye(4, dtype=torch.float64) + .unsqueeze(0) + .expand(2, -1, -1) + * 0.1, + }, + batch_size=[2], + ) + + out = module(td) + assert ("action", "mean") in out.keys(True) + assert ("action", "var") in out.keys(True) + assert ("action", "cross_covariance") in out.keys(True) + + +class TestExponentialQuadraticCost: + def test_forward_shapes_default(self): + cost = ExponentialQuadraticCost(reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.randn(2, 5, 4), + ("observation", "var"): torch.eye(4) + .unsqueeze(0) 
+ .unsqueeze(0) + .expand(2, 5, -1, -1) + * 0.1, + }, + batch_size=[2, 5], + ) + + out = cost(td) + loss = out["loss_cost"] + assert loss.shape == (2, 5) + + def test_cost_at_target_is_low(self): + target = torch.zeros(4) + cost = ExponentialQuadraticCost(target=target, reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.zeros(1, 4), + ("observation", "var"): torch.eye(4).unsqueeze(0) * 1e-6, + }, + batch_size=[1], + ) + + out = cost(td) + assert out["loss_cost"].item() < 0.01 + + def test_cost_far_from_target_is_high(self): + target = torch.zeros(4) + cost = ExponentialQuadraticCost(target=target, reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.ones(1, 4) * 10.0, + ("observation", "var"): torch.eye(4).unsqueeze(0) * 0.1, + }, + batch_size=[1], + ) + + out = cost(td) + assert out["loss_cost"].item() > 0.9 + + def test_cost_bounded_zero_one(self): + cost = ExponentialQuadraticCost(reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.randn(10, 4), + ("observation", "var"): torch.eye(4).unsqueeze(0).expand(10, -1, -1) + * 0.1, + }, + batch_size=[10], + ) + + out = cost(td) + loss = out["loss_cost"] + assert (loss >= -1e-6).all() + assert (loss <= 1.0 + 1e-6).all() + + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) + def test_reductions(self, reduction): + cost = ExponentialQuadraticCost(reduction=reduction) + + td = TensorDict( + { + ("observation", "mean"): torch.randn(3, 5, 4), + ("observation", "var"): torch.eye(4) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, 5, -1, -1) + * 0.1, + }, + batch_size=[3, 5], + ) + + out = cost(td) + loss = out["loss_cost"] + + if reduction == "none": + assert loss.shape == (3, 5) + else: + assert loss.shape == () + + def test_custom_weights_and_target(self): + weights = torch.diag(torch.tensor([2.0, 0.5, 1.0, 1.0])) + target = torch.tensor([1.0, 0.0, 0.0, 0.0]) + cost = ExponentialQuadraticCost( + target=target, weights=weights, reduction="none" 
+ ) + + td = TensorDict( + { + ("observation", "mean"): target.unsqueeze(0), + ("observation", "var"): torch.eye(4).unsqueeze(0) * 1e-6, + }, + batch_size=[1], + ) + + out = cost(td) + assert out["loss_cost"].item() < 0.01 + + def test_gradients_flow(self): + cost = ExponentialQuadraticCost(reduction="mean") + + mean = torch.randn(2, 4, requires_grad=True) + var = torch.eye(4).unsqueeze(0).expand(2, -1, -1) * 0.1 + + td = TensorDict( + {("observation", "mean"): mean, ("observation", "var"): var}, + batch_size=[2], + ) + + out = cost(td) + out["loss_cost"].backward() + assert mean.grad is not None + + +class TestImaginedEnv: + @staticmethod + def _make_dummy_world_model(obs_dim, action_dim): + class DummyWM(torch.nn.Module): + def __init__(self, obs_dim): + super().__init__() + self.obs_dim = obs_dim + + def forward(self, action, observation): + mean = observation.get("mean") + var = ( + torch.eye( + self.obs_dim, device=mean.device, dtype=mean.dtype + ).expand(*mean.shape[:-1], -1, -1) + * 0.01 + ) + return mean + 0.1, var + + return TensorDictModule( + DummyWM(obs_dim), + in_keys=["action", "observation"], + out_keys=[("next_observation", "mean"), ("next_observation", "var")], + ) + + @staticmethod + def _make_base_env(obs_dim, action_dim): + class StubEnv(EnvBase): + def __init__(self, obs_dim, action_dim): + super().__init__(batch_size=torch.Size([])) + self.observation_spec = Composite( + observation=Unbounded(shape=(obs_dim,)) + ) + self.action_spec = Unbounded(shape=(action_dim,)) + self.reward_spec = Unbounded(shape=(1,)) + + def _reset(self, tensordict=None): + return TensorDict( + {"observation": torch.zeros(obs_dim)}, + batch_size=self.batch_size, + ) + + def _step(self, tensordict): + return TensorDict( + { + "observation": torch.randn(obs_dim), + "reward": torch.zeros(1), + "done": torch.tensor(False).unsqueeze(0), + "terminated": torch.tensor(False).unsqueeze(0), + }, + batch_size=self.batch_size, + ) + + def _set_seed(self, seed): + pass + + return 
StubEnv(obs_dim, action_dim) + + def test_creation(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv(world_model_module=wm, base_env=base_env) + assert env.batch_size == torch.Size([1]) + + def test_creation_with_batch_size(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv(world_model_module=wm, base_env=base_env, batch_size=[3]) + assert env.batch_size == torch.Size([3]) + + def test_reset_with_observation(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv(world_model_module=wm, base_env=base_env) + + reset_td = TensorDict( + { + ("observation", "mean"): torch.zeros(1, obs_dim), + ("observation", "var"): torch.eye(obs_dim).unsqueeze(0) * 1e-3, + }, + batch_size=[1], + ) + + out = env.reset(reset_td) + assert ("observation", "mean") in out.keys(True) + assert ("observation", "var") in out.keys(True) + + def test_step(self): + obs_dim, action_dim = 4, 1 + next_observation_key = ( + "next_observation" # ("next", "observation") could also be a possibility + ) + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv( + world_model_module=wm, + base_env=base_env, + next_observation_key=next_observation_key, + ) + + td = TensorDict( + { + ("observation", "mean"): torch.zeros(1, obs_dim), + ("observation", "var"): torch.eye(obs_dim).unsqueeze(0) * 1e-3, + ("action", "mean"): torch.zeros(1, action_dim), + ("action", "var"): torch.zeros(1, action_dim, action_dim), + ("action", "cross_covariance"): torch.zeros(1, obs_dim, action_dim), + }, + batch_size=[1], + ) + + out = env.step(td) + next_td = out["next"] + assert ("observation", "mean") in 
next_td.keys(True) + assert ("observation", "var") in next_td.keys(True) + assert "done" in next_td.keys() + assert not next_td["done"].any() + + def test_never_terminates(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv(world_model_module=wm, base_env=base_env) + + td = TensorDict( + {"done": torch.ones(1, 1, dtype=torch.bool)}, + batch_size=[1], + ) + assert not env.any_done(td) + + +class TestMeanActionSelector: + @staticmethod + def _make_base_env(obs_dim, action_dim): + class StubEnv(EnvBase): + def __init__(self, obs_dim, action_dim): + super().__init__(batch_size=torch.Size([])) + self.observation_spec = Composite( + observation=Unbounded(shape=(obs_dim,)) + ) + self.action_spec = Unbounded(shape=(action_dim,)) + self.reward_spec = Unbounded(shape=(1,)) + + def _reset(self, tensordict=None): + return TensorDict( + {"observation": torch.zeros(obs_dim)}, + batch_size=self.batch_size, + ) + + def _step(self, tensordict): + return TensorDict( + { + "observation": torch.randn(obs_dim), + "reward": torch.zeros(1), + "done": torch.tensor(False).unsqueeze(0), + "terminated": torch.tensor(False).unsqueeze(0), + }, + batch_size=self.batch_size, + ) + + def _set_seed(self, seed): + pass + + return StubEnv(obs_dim, action_dim) + + def test_forward_wraps_observation(self): + transform = MeanActionSelector() + obs = torch.randn(4) + td = TensorDict( + {"observation": obs.clone()}, + batch_size=[], + ) + + out = transform._call(td) + assert ("observation", "mean") in out.keys(True) + assert ("observation", "var") in out.keys(True) + assert out["observation", "var"].shape == (4, 4) + torch.testing.assert_close(out["observation", "mean"], obs) + + def test_inverse_extracts_action_mean(self): + transform = MeanActionSelector() + action_mean = torch.randn(2) + td = TensorDict( + { + ("action", "mean"): action_mean, + ("action", "var"): torch.eye(2), + }, + 
@pytest.mark.skipif(not _has_botorch, reason="botorch/gpytorch not installed")
class TestGPWorldModel:
    @staticmethod
    def _make_dataset(obs_dim, action_dim, n_samples=20):
        """Random double-precision transition dataset with small residual dynamics."""
        obs = torch.randn(n_samples, obs_dim).double()
        action = torch.randn(n_samples, action_dim).double()
        next_obs = obs + 0.1 * torch.randn(n_samples, obs_dim).double()
        return TensorDict(
            {
                "observation": obs,
                "action": action,
                ("next", "observation"): next_obs,
            },
            batch_size=[n_samples],
        )

    @staticmethod
    def _make_uncertain_input(obs_dim, action_dim, batch):
        """Gaussian (mean, var) belief input with small diagonal covariances."""
        return TensorDict(
            {
                "observation": {
                    "mean": torch.randn(batch, obs_dim, dtype=torch.float64),
                    "var": torch.eye(obs_dim, dtype=torch.float64)
                    .unsqueeze(0)
                    .expand(batch, -1, -1)
                    * 0.01,
                },
                "action": {
                    "mean": torch.randn(batch, action_dim, dtype=torch.float64),
                    "var": torch.eye(action_dim, dtype=torch.float64)
                    .unsqueeze(0)
                    .expand(batch, -1, -1)
                    * 0.01,
                    "cross_covariance": torch.zeros(
                        batch, obs_dim, action_dim, dtype=torch.float64
                    ),
                },
            },
            batch_size=[batch],
        )

    def test_creation(self):
        from torchrl.modules.models.gp import GPWorldModel

        model = GPWorldModel(obs_dim=4, action_dim=1)
        assert model.obs_dim == 4
        assert model.action_dim == 1
        assert model.state_action_dim == 5

    def test_fit_and_deterministic_forward(self):
        from torchrl.modules.models.gp import GPWorldModel

        obs_dim, action_dim = 2, 1
        model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim)
        model.fit(self._make_dataset(obs_dim, action_dim))
        model.eval()

        td = TensorDict(
            {
                ("observation", "mean"): torch.randn(3, obs_dim),
                ("action", "mean"): torch.randn(3, action_dim),
            },
            batch_size=[3],
        )

        forward_td = model.deterministic_forward(td)

        assert forward_td[("next", "observation", "mean")].shape == (3, obs_dim)
        assert forward_td[("next", "observation", "var")].shape == (3, obs_dim, obs_dim)

    def test_uncertain_forward(self):
        from torchrl.modules.models.gp import GPWorldModel

        obs_dim, action_dim = 2, 1
        model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim)

        # Cast the model to double before fitting, matching the dataset dtype.
        model.double()
        model.fit(self._make_dataset(obs_dim, action_dim))
        model.eval()

        batch = 2
        forward_td = model.uncertain_forward(
            self._make_uncertain_input(obs_dim, action_dim, batch)
        )

        mean = forward_td[("next", "observation", "mean")]
        var = forward_td[("next", "observation", "var")]
        assert mean.shape == (batch, obs_dim)
        assert var.shape == (batch, obs_dim, obs_dim)

        # The predicted covariance must come out symmetric.
        torch.testing.assert_close(var, var.transpose(-2, -1), atol=1e-5, rtol=1e-4)

    def test_forward_dispatch(self):
        from torchrl.modules.models.gp import GPWorldModel

        obs_dim, action_dim = 2, 1
        model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim)
        model.fit(self._make_dataset(obs_dim, action_dim))
        model.eval()

        # Non-zero input variance should route forward() to uncertain_forward().
        forward_td = model(self._make_uncertain_input(obs_dim, action_dim, 2))
        mean = forward_td[("next", "observation", "mean")]
        assert mean.shape == (2, obs_dim)
class ImaginedEnv(ModelBasedEnvBase):
    """Imagination environment for model-based policy search.

    Wraps a learned world model (e.g. a Gaussian Process) as a standard
    TorchRL environment so that imagined rollouts can be collected with
    :meth:`~torchrl.envs.EnvBase.rollout`. Observations carry both mean
    and covariance (under keys ``("observation", "mean")`` and
    ``("observation", "var")``) to support uncertainty-aware
    moment-matching controllers.

    The environment never terminates on its own -- rollout length is
    controlled solely by the ``max_steps`` argument of
    :meth:`~torchrl.envs.EnvBase.rollout`. The ``done`` and ``terminated``
    flags are always ``False``.

    Args:
        world_model_module (TensorDictModule): A :class:`~tensordict.nn.TensorDictModule`
            that takes ``"action"`` and ``"observation"`` entries and produces
            ``("next_observation", "mean")`` and ``("next_observation", "var")``.
        base_env (EnvBase): The real environment whose specs (observation, action,
            reward, done) are copied into this imagined environment.
        batch_size (int, Sequence[int], torch.Size, optional): Override batch size.
            If ``None``, inferred from ``base_env`` (with a minimum of ``[1]``).
        next_observation_key (str or tuple of str, optional): The key where the world
            model writes the predicted next observation. Defaults to
            ``("next", "observation")``.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.envs import GymEnv
        >>> from torchrl.envs.model_based import ImaginedEnv
        >>> from torchrl.envs.utils import RandomPolicy
        >>> base_env = GymEnv("Pendulum-v1")
        >>> obs_dim = base_env.observation_spec["observation"].shape[-1]
        >>> # A toy world model that returns the input mean with identity covariance
        >>> class DummyWorldModel(torch.nn.Module):
        ...     def __init__(self, obs_dim):
        ...         super().__init__()
        ...         self.obs_dim = obs_dim
        ...     def forward(self, action, observation):
        ...         mean = observation.get("mean", observation)
        ...         var = torch.eye(self.obs_dim).expand(*mean.shape[:-1], -1, -1)
        ...         return mean, var
        >>> wm = TensorDictModule(
        ...     DummyWorldModel(obs_dim),
        ...     in_keys=["action", "observation"],
        ...     out_keys=[("next", "observation", "mean"), ("next", "observation", "var")],
        ... )
        >>> imagined_env = ImaginedEnv(wm, base_env, next_observation_key=("next", "observation"))
        >>> # Collect an imagined rollout
        >>> rollout = imagined_env.rollout(max_steps=5, policy=RandomPolicy(imagined_env.action_spec))
    """

    def __init__(
        self,
        world_model_module: TensorDictModule,
        base_env: EnvBase,
        batch_size: int | torch.Size | Sequence[int] | None = None,
        next_observation_key: str | tuple[str, ...] = ("next", "observation"),
        **kwargs,
    ) -> None:
        self.next_observation_key = next_observation_key

        if batch_size is not None:
            batch_size = (
                torch.Size(batch_size)
                if not isinstance(batch_size, torch.Size)
                else batch_size
            )
        elif len(base_env.batch_size) == 0:
            # Model-based rollouts need at least one batch dimension.
            batch_size = torch.Size([1])
        else:
            batch_size = base_env.batch_size

        super().__init__(
            world_model_module,
            device=base_env.device,
            batch_size=batch_size,
            allow_done_after_reset=True,
            **kwargs,
        )

        # Copy the real environment's specs, expanded to our batch size.
        self.observation_spec = base_env.observation_spec.expand(
            self.batch_size
        ).clone()
        self.action_spec = base_env.action_spec.expand(self.batch_size).clone()
        self.reward_spec = base_env.reward_spec.expand(self.batch_size).clone()
        self.done_spec = base_env.done_spec.expand(self.batch_size).clone()

    def any_done(self, tensordict) -> bool:
        """Returns False -- imagination rollouts never terminate.

        Overridden to avoid CUDA sync from ``done.any()`` in the parent class.
        """
        return False

    def maybe_reset(self, tensordict):
        """No-op -- imagination rollouts do not need partial resets.

        Overridden to avoid CUDA sync from done checks in the parent class.
        """
        return tensordict

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        tensordict = self.world_model(tensordict)

        # Imagined transitions carry no reward signal and never terminate.
        reward = torch.zeros(*tensordict.shape, 1, device=self.device)
        done = torch.zeros(*tensordict.shape, 1, dtype=torch.bool, device=self.device)
        out = TensorDict(
            {
                "observation": tensordict.get(self.next_observation_key),
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
            },
            tensordict.shape,
        )
        return out

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        if tensordict is None:
            tensordict = TensorDict({}, batch_size=self.batch_size, device=self.device)

        # If the caller already supplies a full (mean, var) belief, keep it.
        if (
            tensordict.get(("observation", "var"), None) is not None
            and tensordict.get(("observation", "mean"), None) is not None
        ):
            return tensordict.copy()

        obs = tensordict.get("observation", None)
        if obs is None:
            # BUGFIX: ``observation_spec`` was already expanded to
            # ``self.batch_size`` in ``__init__``. ``rand(shape=...)``
            # *prepends* the given shape to the spec's own shape, so passing
            # ``self.batch_size`` again produced a doubly-batched
            # ``[B, B, D]`` sample that broke the 2-D unpack below.
            obs = self.observation_spec.rand().get("observation")
        if obs.ndim == 1:
            obs = obs.expand(self.batch_size[0], -1)

        obs = obs.to(self.device)
        B, D = obs.shape

        out = TensorDict(
            {
                # A freshly reset state is certain: zero covariance.
                ("observation", "mean"): obs,
                ("observation", "var"): torch.zeros(
                    B, D, D, dtype=obs.dtype, device=self.device
                ),
            },
            batch_size=self.batch_size,
            device=self.device,
        )

        out.set("done", torch.zeros(B, 1, dtype=torch.bool, device=self.device))
        out.set(
            "terminated",
            torch.zeros(B, 1, dtype=torch.bool, device=self.device),
        )

        return out
class MeanActionSelector(Transform):
    """Bridges Gaussian belief-space policies with standard environments.

    Gaussian policies used in moment-matching model-based RL (e.g. PILCO)
    consume state *beliefs* -- ``(mean, covariance)`` pairs -- and emit action
    distributions under ``("action", "mean")``, ``("action", "var")``, etc.
    This transform lets such a policy drive an ordinary environment through
    :meth:`~torchrl.envs.EnvBase.rollout`:

    * **Forward** (env output -> policy input): the flat ``"observation"``
      tensor becomes ``("observation", "mean")`` with a zero-covariance
      ``("observation", "var")``, i.e. a deterministic state belief.
    * **Inverse** (policy output -> env input): ``("action", "mean")`` is
      copied out as the flat ``"action"`` the base environment expects.

    Args:
        observation_key (str, optional): The observation key to read from the
            base environment. Defaults to ``"observation"``.
        action_key (str, optional): The action key expected by the base
            environment. Defaults to ``"action"``.

    Examples:
        >>> import torch
        >>> from torchrl.envs import GymEnv, TransformedEnv
        >>> from torchrl.envs.transforms import MeanActionSelector
        >>> base_env = GymEnv("Pendulum-v1")
        >>> env = TransformedEnv(base_env, MeanActionSelector())
        >>> td = env.reset()
        >>> # The policy now sees ("observation", "mean") and ("observation", "var")
        >>> print(td["observation", "mean"].shape)
        >>> print(td["observation", "var"].shape)
    """

    def __init__(
        self,
        observation_key: str = "observation",
        action_key: str = "action",
    ) -> None:
        super().__init__(
            in_keys=[observation_key],
            out_keys=[(observation_key, "mean"), (observation_key, "var")],
            in_keys_inv=[action_key],
            out_keys_inv=[(action_key, "mean")],
        )
        self._observation_key = observation_key
        self._action_key = action_key

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        value = tensordict.get(self._observation_key)

        # Already a nested (mean, var) belief: nothing to wrap.
        if isinstance(value, TensorDictBase):
            return tensordict

        lead, dim = value.shape[:-1], value.shape[-1]

        tensordict.pop(self._observation_key)
        tensordict.set((self._observation_key, "mean"), value)
        # A directly observed state is certain: zero covariance.
        # new_zeros inherits the observation's dtype and device.
        tensordict.set(
            (self._observation_key, "var"),
            value.new_zeros(*lead, dim, dim),
        )

        return tensordict

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        mean = tensordict.get((self._action_key, "mean"), None)
        if mean is None:
            return tensordict
        tensordict.set(self._action_key, mean)
        return tensordict

    def transform_observation_spec(self, observation_spec):
        base = observation_spec[self._observation_key]
        dim = base.shape[-1]
        # Replace the flat observation spec with a (mean, var) composite
        # carrying the same leading shape.
        observation_spec[self._observation_key] = Composite(
            mean=base.clone(),
            var=Unbounded(shape=(*base.shape, dim), dtype=base.dtype),
            shape=base.shape,
        )
        return observation_spec

    def _reset(self, tensordict, tensordict_reset):
        # Reset outputs get the same forward wrapping as step outputs.
        return self._call(tensordict_reset)
"OnlineDTActor", "QMixer", - "VDNMixer", + "RBFController", + "RSSMPosterior", + "RSSMPrior", + "RSSMRollout", "Squeeze2dLayer", "SqueezeLayer", + "VDNMixer", + "reset_noise", ] diff --git a/torchrl/modules/models/gp.py b/torchrl/modules/models/gp.py new file mode 100644 index 00000000000..18ed35f37d4 --- /dev/null +++ b/torchrl/modules/models/gp.py @@ -0,0 +1,637 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# Variable naming follows Deisenroth & Rasmussen (2011), "PILCO: A Model-Based +# and Data-Efficient Approach to Policy Search" (cited inline as "Eq. N"). +# +# Key symbols +# ----------- +# x̃ := [x, u] concatenated state-action input (Eq. 1 / Sec. 2.1) +# Δ := x_t - x_{t-1} transition residual (Sec. 2.1) +# K_a Gram matrix K_{a,ij}=k_a(x̃_i,x̃_j) (Eq. 6) +# β_a := (K_a + σ²_ε I)^{-1}y_a GP weight vector (Eq. 7) +# q_a kernel-mean vector (Eq. 15) +# Q_{ab} cross-kernel matrix (Eqs. 21-22) +# μ̃ / Σ̃ joint state-action mean/cov (Sec. 2.2) +# μ_Δ / Σ_Δ predictive mean/cov of Δ (Eqs. 14, 17-23) +# μ_t / Σ_t next-state mean/cov (Eqs. 10-11) + +import importlib.util + +import torch +import torch.nn as nn +from tensordict import TensorDictBase + +_has_gpytorch = importlib.util.find_spec("gpytorch") is not None +_has_botorch = importlib.util.find_spec("botorch") is not None + + +class GPWorldModel(nn.Module): + """Gaussian Process world model with moment-matching uncertainty propagation. + + Implements the probabilistic dynamics model from PILCO + (Deisenroth & Rasmussen, 2011). One independent GP is fit per state + dimension, each predicting the transition residual + ``Δ = x_t - x_{t-1}`` from the concatenated state-action input + ``x̃ = [x, u]`` (Sec. 2.1). 
+ + :meth:`forward` supports two modes depending on whether the input + observation carries non-zero variance: + + - **Deterministic**: uses the GP posterior mean and variance directly + (Eqs. 7-8). + - **Uncertain** (moment-matching): propagates a Gaussian belief + ``N(μ, Σ)`` through the GP analytically (Eqs. 10-23). + + .. note:: + Requires ``botorch`` and ``gpytorch`` as optional dependencies. + + Args: + obs_dim (int): Dimension D of the observation (state) space. + action_dim (int): Dimension F of the action (control) space. + in_keys (list of NestedKey, optional): Keys to read from the input + :class:`~tensordict.TensorDictBase`. Must contain five entries in + order: action mean, action covariance, state-action + cross-covariance, observation mean, observation covariance. + Defaults to ``[("action", "mean"), ("action", "var"), + ("action", "cross_covariance"), ("observation", "mean"), + ("observation", "var")]``. + out_keys (list of NestedKey, optional): Keys to write the predicted + next-state mean and covariance to. Defaults to + ``[("next", "observation", "mean"), + ("next", "observation", "var")]``. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> model = GPWorldModel(obs_dim=4, action_dim=1) + >>> dataset = TensorDict( + ... { + ... "observation": torch.randn(50, 4), + ... "action": torch.randn(50, 1), + ... ("next", "observation"): torch.randn(50, 4), + ... }, + ... batch_size=[50], + ... ) + >>> model.fit(dataset) + + Reference: + Deisenroth, M. P. & Rasmussen, C. E. (2011). PILCO: A model-based + and data-efficient approach to policy search. *ICML*. + """ + + def __init__( + self, + obs_dim: int, + action_dim: int, + in_keys: list[str | tuple[str, ...]] | None = None, + out_keys: list[str | tuple[str, ...]] | None = None, + ) -> None: + if not _has_botorch or not _has_gpytorch: + raise ImportError( + "botorch and gpytorch are required to use GPWorldModel. " + "Please install them to proceed." 
+ ) + super().__init__() + self.obs_dim = obs_dim # D in the paper + self.action_dim = action_dim # F in the paper + self.state_action_dim = obs_dim + action_dim # D+F, dimension of x̃ (Sec. 2.1) + + self.in_keys = ( + in_keys + if in_keys is not None + else [ + ("action", "mean"), + ("action", "var"), + ("action", "cross_covariance"), + ("observation", "mean"), + ("observation", "var"), + ] + ) + + self.out_keys = ( + out_keys + if out_keys is not None + else [ + ("next", "observation", "mean"), + ("next", "observation", "var"), + ] + ) + + self.model_list = None + + # X̃ = [x̃_1, ..., x̃_n] ∈ R^{n×(D+F)} – training inputs (Sec. 2.1) + self.register_buffer("X_tilde_train", torch.empty(0)) + + # ℓ_a ∈ R^{D+F} – ARD length-scales for each output dimension a (Eq. 6). + # Stored as [D, D+F]; the full matrix Λ_a = diag(ℓ_a²) is never + # materialised — ℓ_a is squared on the fly wherever needed. + # Note: GPyTorch's .lengthscale returns ℓ directly (not ℓ²). + self.register_buffer("ell", torch.zeros(self.obs_dim, self.state_action_dim)) + + # α²_a – signal variance for each output dimension a (Eq. 6); shape [D, 1] + self.register_buffer("alpha_sq", torch.zeros(self.obs_dim, 1)) + + # σ²_{ε_a} – noise variance for each output dimension a (Sec. 2.1); shape [D] + self.register_buffer("sigma_sq_eps", torch.zeros(self.obs_dim)) + + # (K_a + σ²_{ε_a} I)^{-1} – cached inverse Gram matrices (Eq. 7); shape [D, n, n]. + # Registered as buffers so they survive .to(device) and state_dict round-trips. + self.register_buffer("_cached_inv_K_noisy", None) + + # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a – GP weight vectors (Eq. 7); shape [D, n]. + # Registered as a buffer so it survives .to(device) and state_dict round-trips. + self.register_buffer("_cached_beta", None) + + @property + def device(self) -> torch.device: + return self.ell.device + + def fit(self, dataset: TensorDictBase) -> None: + """Fit one GP per state dimension to a dataset of transitions. 
+ + Constructs training inputs ``X̃ = [x, u]`` and targets + ``Δ_a = x_{t,a} - x_{t-1,a}``, then maximises the marginal + log-likelihood to learn SE kernel hyper-parameters + (ℓ_a, α²_a, σ²_{ε_a}) for each output dimension (Sec. 2.1, Eq. 6). + + .. note:: + The dataset is expected to be flat with shape ``[n, *]``. If your + replay buffer returns multi-dimensional batches (e.g. ``[B, T, *]``), + call ``dataset.reshape(-1)`` before passing it here. + + Args: + dataset (TensorDictBase): Transition dataset with keys + ``"observation"`` of shape ``(n, D)``, + ``"action"`` of shape ``(n, F)``, and + ``("next", "observation")`` of shape ``(n, D)``. + """ + from botorch.fit import fit_gpytorch_mll + from botorch.models import ModelListGP, SingleTaskGP + from gpytorch.kernels import RBFKernel, ScaleKernel + from gpytorch.mlls import SumMarginalLogLikelihood + from gpytorch.priors import GammaPrior + + x_t_minus_1 = dataset["observation"] # x_{t-1} ∈ R^{n×D} + u_t_minus_1 = dataset["action"] # u_{t-1} ∈ R^{n×F} + x_t = dataset[("next", "observation")] # x_t ∈ R^{n×D} + + # x̃ = [x_{t-1}, u_{t-1}] ∈ R^{n×(D+F)} – training inputs (Sec. 2.1) + X_tilde_train = ( + torch.cat([x_t_minus_1, u_t_minus_1], dim=-1).detach().to(self.device) + ) + + # Δ ∈ R^{n×D}, Δ_{i,a} = x_{t,a} - x_{t-1,a} – training targets (Sec. 2.1) + Delta_train = (x_t - x_t_minus_1).detach().to(self.device) + + self.X_tilde_train = X_tilde_train + + models = [] + for a in range(self.obs_dim): + # Each GP_a models p(Δ_a | x̃) independently (Sec. 2.1) + Delta_a = Delta_train[:, a].unsqueeze(-1) # y_a ∈ R^{n×1} + + covar_module = ScaleKernel( + # SE kernel k_a(x̃, x̃') with ARD length-scales (one ℓ_{a,i} + # per input dimension, Eq. 6) + RBFKernel( + ard_num_dims=self.state_action_dim, + lengthscale_prior=GammaPrior(1.1, 0.1), + ), + outputscale_prior=GammaPrior(1.5, 0.5), # prior on α²_a (Eq. 
6) + ) + + gp_a = SingleTaskGP( + train_X=X_tilde_train, + train_Y=Delta_a, + covar_module=covar_module, + ) + gp_a.likelihood.noise_covar.register_prior( + "noise_prior", + GammaPrior(1.2, 0.05), + "noise", # prior on σ²_{ε_a} (Sec. 2.1) + ) + + models.append(gp_a) + + self.model_list = ModelListGP(*models).to(self.device) + mll = SumMarginalLogLikelihood(self.model_list.likelihood, self.model_list) + + fit_gpytorch_mll(mll) # evidence maximisation (Sec. 2.1) + self._extract_and_cache_parameters(Delta_train) + + def _extract_and_cache_parameters(self, Delta_train: torch.Tensor) -> None: + # Extract learned hyper-parameters from each GP_a and pre-compute the + # quantities that are fixed after fitting: + # ℓ_a, α²_a, σ²_{ε_a} (Eq. 6 / Sec. 2.1) + # (K_a + σ²_{ε_a} I)^{-1} (Eq. 7) + # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a (Eq. 7) + ell_list, alpha_sq_list, sigma_sq_eps_list = [], [], [] + inv_K_noisy_list, beta_list = [], [] + + n = self.X_tilde_train.shape[0] # number of training points + + for a, gp_a in enumerate(self.model_list.models): + gp_a.eval() + gp_a.likelihood.eval() + + # ℓ_a ∈ R^{D+F} – ARD length-scales for GP_a (Eq. 6). + # GPyTorch's .lengthscale returns ℓ directly (not ℓ²). + ell_a = gp_a.covar_module.base_kernel.lengthscale.squeeze().detach() + + # α²_a – signal variance for GP_a (Eq. 6) + alpha_sq_a = gp_a.covar_module.outputscale.detach() + + # σ²_{ε_a} – noise variance for GP_a (Sec. 2.1) + sigma_sq_eps_a = gp_a.likelihood.noise.squeeze().detach() + + ell_list.append(ell_a) + alpha_sq_list.append(alpha_sq_a) + sigma_sq_eps_list.append(sigma_sq_eps_a) + + # K_{a,ij} = α²_a exp(-½ (x̃_i-x̃_j)^T Λ_a^{-1} (x̃_i-x̃_j)) (Eq. 6) + # Dividing X̃ by ℓ_a gives Λ_a^{-1/2}-scaled inputs for cdist. + X_tilde_scaled = self.X_tilde_train / ell_a + sq_dist = torch.cdist(X_tilde_scaled, X_tilde_scaled, p=2) ** 2 + K_a = alpha_sq_a * torch.exp(-0.5 * sq_dist) + + # K_{a,noisy} = K_a + σ²_{ε_a} I (denominator in Eq. 
# jitter for numerical stability (Eq. 7)
        K_a_noisy = K_a + (sigma_sq_eps_a + 1e-6) * torch.eye(n, device=self.device)

        # Cholesky factor of the noisy kernel matrix; reused for both solves below.
        L_a = torch.linalg.cholesky(K_a_noisy)
        eye_n = torch.eye(n, dtype=L_a.dtype, device=L_a.device)

        # (K_a + σ²_{ε_a} I)^{-1} (Eq. 7)
        inv_K_a_noisy = torch.cholesky_solve(eye_n, L_a)

        # y_a = [Δ_{1,a}, ..., Δ_{n,a}]^T – targets for GP_a (Sec. 2.1)
        y_a = Delta_train[:, a].unsqueeze(-1)

        # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a (Eq. 7)
        beta_a = torch.cholesky_solve(y_a, L_a).squeeze(-1)

        inv_K_noisy_list.append(inv_K_a_noisy)
        beta_list.append(beta_a)

        # Stack the per-output-dimension quantities so predictions can be
        # broadcast over all D GPs at once.
        self.ell = torch.stack(ell_list)  # [D, D+F]
        self.alpha_sq = torch.stack(alpha_sq_list).unsqueeze(-1)  # [D, 1]
        self.sigma_sq_eps = torch.stack(sigma_sq_eps_list)  # [D]
        self._cached_inv_K_noisy = torch.stack(inv_K_noisy_list)  # [D, n, n]
        self._cached_beta = torch.stack(beta_list)  # [D, n]

    def compute_factorizations(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the cached kernel inverses and GP weight vectors.

        Returns:
            tuple[Tensor, Tensor]: A pair ``(inv_K_noisy, beta)`` where
                ``inv_K_noisy`` has shape ``(D, n, n)`` and contains
                ``(K_a + σ²_{ε_a} I)^{-1}`` for each output dimension (Eq. 7),
                and ``beta`` has shape ``(D, n)`` and contains
                ``β_a = (K_a + σ²_{ε_a} I)^{-1} y_a`` (Eq. 7).
        """
        return self._cached_inv_K_noisy, self._cached_beta

    def _gather_gp_hyperparams(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Returns (ell, alpha_sq, sigma_sq_eps) — the SE kernel hyper-parameters
        # for each GP_a (Eq. 6 / Sec. 2.1):
        #   ell:          ℓ_{a,i}, shape [D, D+F] (ℓ, not ℓ²)
        #   alpha_sq:     α²_a,    shape [D, 1]
        #   sigma_sq_eps: σ²_{ε_a}, shape [D]
        return self.ell, self.alpha_sq, self.sigma_sq_eps

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Predict the next-state distribution given the current state and action.

        Routes to :meth:`uncertain_forward` (moment-matching, Eqs. 10-23) when
        the input observation covariance is non-zero, and to
        :meth:`deterministic_forward` (Eqs. 7-8) otherwise.

        Args:
            tensordict (TensorDictBase): Input tensordict containing keys
                defined by ``in_keys``. Observation and action tensors may be
                unbatched ``(D,)`` / ``(F,)`` or batched ``(B, D)`` /
                ``(B, F)``; a leading batch dimension will be added and removed
                automatically for unbatched inputs. The observation covariance,
                when present, must be a full matrix of shape ``(..., D, D)``
                — per-dimension variance vectors are not accepted; use
                :func:`torch.diag_embed` to convert them first.

        Returns:
            TensorDictBase: The same tensordict, updated in-place with the
                predicted next-state mean and covariance written to ``out_keys``.
        """
        u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys

        Sigma_x = tensordict.get(x_var_key, None)
        # Reject variance vectors early: the moment-matching path needs a full
        # (..., D, D) covariance matrix to build the joint state-action block.
        if Sigma_x is not None and Sigma_x.dim() < 2:
            raise ValueError(
                f"Expected observation covariance to have at least 2 dimensions "
                f"(..., D, D), got shape {tuple(Sigma_x.shape)}. "
                "Convert per-dimension variances with torch.diag_embed() first."
            )

        # An (approximately) all-zero covariance means the state is a point
        # estimate, so the cheaper deterministic GP posterior suffices.
        observation_uncertain = Sigma_x is not None and not torch.all(
            torch.isclose(Sigma_x, torch.zeros_like(Sigma_x))
        )

        if observation_uncertain:
            return self.uncertain_forward(tensordict)
        else:
            return self.deterministic_forward(tensordict)

    def uncertain_forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Moment-matching forward pass for a Gaussian input belief (Eqs. 10-23).

        Propagates the joint Gaussian belief
        ``p(x̃_{t-1}) = N(μ̃_{t-1}, Σ̃_{t-1})`` (Sec. 2.2) through the GP
        dynamics model and returns a Gaussian approximation to ``p(x_t)``
        via exact moment matching.

        Args:
            tensordict (TensorDictBase): Input tensordict with keys defined by
                ``in_keys``. Supports unbatched ``(D,)`` inputs or batched
                inputs with a single leading batch dimension ``(B, D)``.

        Returns:
            TensorDictBase: The same tensordict updated with next-state mean
                ``μ_t`` (Eq. 10) and covariance ``Σ_t`` (Eq. 11) at ``out_keys``.
        """
        inv_K_noisy, beta = self.compute_factorizations()
        ell, alpha_sq, sigma_sq_eps = self._gather_gp_hyperparams()
        u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys

        mu_x = tensordict.get(x_mean_key)  # μ_x, shape (B×)D
        Sigma_x = tensordict.get(x_var_key)  # Σ_x, shape (B×)D×D
        mu_u = tensordict.get(u_mean_key)  # μ_u, shape (B×)F
        Sigma_u = tensordict.get(u_var_key)  # Σ_u, shape (B×)F×F
        # NOTE(review): despite the comment below, C_xu is consumed as the
        # "input-output" matrix V with cov[x,u] = Σ_x V (PILCO convention),
        # not as the raw covariance — confirm against the controller output.
        C_xu = tensordict.get(u_cc_key)  # cov[x_{t-1}, u_{t-1}], (B×)D×F (Eq. 12)

        # Support unbatched inputs by temporarily adding a leading batch dimension.
        unbatched = mu_x.dim() == 1
        if unbatched:
            mu_x, Sigma_x, mu_u, Sigma_u, C_xu = (
                mu_x.unsqueeze(0),
                Sigma_x.unsqueeze(0),
                mu_u.unsqueeze(0),
                Sigma_u.unsqueeze(0),
                C_xu.unsqueeze(0),
            )

        device, dtype = mu_x.device, mu_x.dtype
        B = mu_x.shape[0]  # batch size
        n = self.X_tilde_train.shape[0]  # number of training points
        D = self.obs_dim  # state dimension
        DF = self.state_action_dim  # D+F, dimension of x̃

        # ---- Build joint state-action distribution p(x̃_{t-1}) (Sec. 2.2) ----
        # μ̃_{t-1} = [μ_x; μ_u] ∈ R^{B×(D+F)}
        mu_tilde = torch.cat([mu_x, mu_u], dim=-1)

        # Σ̃_{t-1} = [[Σ_x,          Σ_x C_xu ],
        #            [C_xu^T Σ_x^T, Σ_u      ]]  ∈ R^{B×(D+F)×(D+F)}
        Sigma_x_C_xu = Sigma_x @ C_xu  # upper-right block [B, D, F]
        Sigma_tilde = torch.cat(
            [
                torch.cat([Sigma_x, Sigma_x_C_xu], dim=-1),
                torch.cat([Sigma_x_C_xu.transpose(-1, -2), Sigma_u], dim=-1),
            ],
            dim=-2,
        )  # [B, D+F, D+F]

        # ---- Compute q_a (mean-prediction kernel vector, Eq. 15) ----
        # ν_i = x̃_i - μ̃_{t-1} (Eq. 16); shape [B, n, D+F]
        nu = self.X_tilde_train - mu_tilde.unsqueeze(1)

        # diag(1/ℓ_a) = Λ_a^{-1/2} as diagonal matrices; shape [D, D+F, D+F].
        # ell stores ℓ_a (not ℓ²_a); the matrix below is applied twice in each
        # quadratic form, which realises Λ_a^{-1} = diag(1/ℓ_a)² overall.
        inv_Lambda_diag_mats = torch.diag_embed(1.0 / ell).to(
            device=device, dtype=dtype
        )

        # Λ_a^{-1/2} ν_i; shape [B, D, n, D+F]
        inv_Lambda_nu = nu.unsqueeze(1) @ inv_Lambda_diag_mats.unsqueeze(0)

        # R_a = Λ_a^{-1/2} Σ̃_{t-1} Λ_a^{-1/2} + I – symmetric form of the
        # normalising matrix in Eq. 15 (same determinant as Σ̃ Λ_a^{-1} + I);
        # shape [B, D, D+F, D+F]
        R_a = (
            inv_Lambda_diag_mats.unsqueeze(0)
            @ Sigma_tilde.unsqueeze(1)
            @ inv_Lambda_diag_mats.unsqueeze(0)
        )
        R_a = R_a + torch.eye(DF, device=device, dtype=dtype).view(1, 1, DF, DF)

        # Solve R_a t = (Λ_a^{-1/2} ν_i)^T  →  t = R_a^{-1} Λ_a^{-1/2} ν_i^T
        t = torch.linalg.solve(R_a, inv_Lambda_nu.transpose(-2, -1)).transpose(-2, -1)

        # exp(-½ ν_i^T (Σ̃ + Λ_a)^{-1} ν_i) – exponent in Eq. 15; shape [B, D, n]
        scaled_exp = torch.exp(-0.5 * torch.sum(inv_Lambda_nu * t, dim=-1))

        # Scalar prefactor α²_a / sqrt(|Σ̃_{t-1} Λ_a^{-1} + I|) from Eq. 15; shape [B, D]
        det_R_a = torch.linalg.det(R_a)
        c_a = alpha_sq.squeeze(-1).unsqueeze(0) / torch.sqrt(det_R_a)

        # β_a ⊙ q_a (pointwise); shape [B, D, n]
        beta_q_a = scaled_exp * beta.unsqueeze(0)

        # μ^a_Δ = β_a^T q_a (Eq. 14); shape [B, D]
        # NOTE(review): c_a.squeeze(0) is a no-op for B > 1 and drops the batch
        # dim only when B == 1, where broadcasting restores it — result is the
        # same as using c_a directly; confirm and simplify.
        mu_Delta = torch.sum(beta_q_a, dim=-1) * c_a.squeeze(0)

        # ---- Cross-covariance cov[x̃_{t-1}, Δ_t] (used in Eq. 12) ----
        # Derivative of μ_Δ w.r.t. μ̃, contracted with Σ̃ (Deisenroth 2010);
        # t Λ_a^{-1/2} = ν_i^T (Σ̃ + Λ_a)^{-1}; shape [B, D+F, D]
        t_inv_Lambda = t @ inv_Lambda_diag_mats.unsqueeze(0)
        cov_xtilde_Delta = (
            torch.matmul(
                t_inv_Lambda.transpose(-2, -1), beta_q_a.unsqueeze(-1)
            ).squeeze(-1)
            * c_a.unsqueeze(-1)
        ).transpose(-2, -1)

        # ---- Compute Q_{ab} (cross-kernel matrix, Eqs. 21-22) ----
        X_i = self.X_tilde_train.unsqueeze(1)  # [n, 1, D+F]
        X_j = self.X_tilde_train.unsqueeze(0)  # [1, n, D+F]
        diff_ij = X_i - X_j  # x̃_i - x̃_j; [n, n, D+F] (Eq. 22)

        # ell stores ℓ_a; ℓ²_a is the diagonal of Λ_a (Eq. 6)
        ell_sq_a = (ell**2)[:, None, :]  # [D, 1, D+F]
        ell_sq_b = (ell**2)[None, :, :]  # [1, D, D+F]

        # Λ_{ab} = (Λ_a^{-1} + Λ_b^{-1})^{-1}, diagonal entries; [D, D, D+F]
        inv_ell_sq_sum = 1.0 / ell_sq_a + 1.0 / ell_sq_b
        Lambda_ab = 1.0 / inv_ell_sq_sum

        # First exponential in Q_{ab,ij}: kernel product at training inputs (Eq. 22)
        # -½ (x̃_i - x̃_j)^T (Λ_a + Λ_b)^{-1} (x̃_i - x̃_j); shape [D, D, n, n]
        inv_ell_sq_sum_ab = 1.0 / (ell_sq_a + ell_sq_b)
        exp1 = -0.5 * torch.sum(
            diff_ij.unsqueeze(0).unsqueeze(0)
            * inv_ell_sq_sum_ab.unsqueeze(2).unsqueeze(2)
            * diff_ij.unsqueeze(0).unsqueeze(0),
            dim=-1,
        )  # [D, D, n, n]

        # z̄_{ij} = Λ_{ab} (Λ_a^{-1} x̃_i + Λ_b^{-1} x̃_j) – midpoint (Eq. 22);
        # shape [D, D, n, n, D+F]
        z_bar = Lambda_ab.unsqueeze(2).unsqueeze(2) * (
            X_i.unsqueeze(0).unsqueeze(0) / ell_sq_a.unsqueeze(2).unsqueeze(2)
            + X_j.unsqueeze(0).unsqueeze(0) / ell_sq_b.unsqueeze(2).unsqueeze(2)
        )

        # z_{ij} = z̄_{ij} - μ̃_{t-1}; shape [B, D, D, n, n, D+F]
        # (the subtraction materialises a contiguous tensor, so .view is safe)
        z_bar = z_bar.unsqueeze(0).expand(B, -1, -1, -1, -1, -1)
        z_ij = z_bar - mu_tilde[:, None, None, None, None, :]
        z_ij_flat = z_ij.view(B, D, D, n * n, DF)

        # M_{ab} = Σ̃_{t-1} + diag(Λ_{ab}) – matrix in second exp of Eq. 22;
        # shape [B, D, D, D+F, D+F]
        M_ab = Sigma_tilde[:, None, None] + torch.diag_embed(Lambda_ab)

        # Second exponential: -½ z_{ij}^T M_{ab}^{-1} z_{ij}; shape [B, D, D, n, n]
        M_ab_solved = torch.linalg.solve(M_ab, z_ij_flat.transpose(-2, -1)).transpose(
            -2, -1
        )
        exp2 = (-0.5 * torch.sum(z_ij_flat * M_ab_solved, dim=-1)).view(B, D, D, n, n)

        # R_{ab} = Σ̃_{t-1} (Λ_a^{-1} + Λ_b^{-1}) + I – normalising matrix (Eq. 22);
        # shape [B, D, D, D+F, D+F]
        R_ab = Sigma_tilde[:, None, None] @ torch.diag_embed(
            inv_ell_sq_sum
        ) + torch.eye(DF, device=device, dtype=dtype)
        det_R_ab = torch.linalg.det(R_ab)  # [B, D, D]

        # Scalar prefactor α²_a α²_b / sqrt(|R_{ab}|) (Eq. 22); shape [B, D, D]
        c_ab = (alpha_sq.view(1, D, 1) * alpha_sq.view(1, 1, D)) / torch.sqrt(det_R_ab)

        # Q_{ab,ij} (Eq. 22); shape [B, D, D, n, n]
        Q_ab = c_ab.unsqueeze(-1).unsqueeze(-1) * torch.exp(exp1.unsqueeze(0) + exp2)

        # ---- Σ_Δ = predictive covariance of Δ (Eqs. 17-23) ----
        # Off-diagonal entries: σ²_{ab} = β_a^T Q_{ab} β_b - μ^a_Δ μ^b_Δ (Eqs. 18, 20)
        beta_a = beta.view(1, D, 1, n)  # [1, D, 1, n]
        beta_b = beta.view(1, 1, D, n)  # [1, 1, D, n]

        Q_ab_beta_b = torch.matmul(Q_ab, beta_b.unsqueeze(-1)).squeeze(
            -1
        )  # [B, D, D, n]
        Sigma_Delta = (
            torch.matmul(beta_a.unsqueeze(-2), Q_ab_beta_b.unsqueeze(-1))
            .squeeze(-1)
            .squeeze(-1)
        )  # [B, D, D] – β_a^T Q_{ab} β_b (Eq. 20)

        # Diagonal correction E_{x̃}[var_f[Δ_a | x̃]] = α²_a - tr(K_a^{-1} Q_{aa})
        # added to σ²_{aa} (Eqs. 17, 23)
        invK_Q = torch.matmul(
            inv_K_noisy.unsqueeze(0).unsqueeze(2),  # [1, D, 1, n, n]
            Q_ab,  # [B, D, D, n, n]
        )  # [B, D, D, n, n]
        trace_invK_Q = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1)  # [B, D, D]

        diag_idx = torch.arange(D, device=device)
        alpha_sq_b = alpha_sq.squeeze(-1).unsqueeze(0).expand(B, -1)  # [B, D]
        sigma_sq_eps_b = sigma_sq_eps.unsqueeze(0).expand(B, -1)  # [B, D]

        # Add α²_a - tr(K_a^{-1} Q_{aa}) + σ²_{ε_a} to the diagonal (Eqs. 17, 23)
        Sigma_Delta[:, diag_idx, diag_idx] += (
            alpha_sq_b - trace_invK_Q[:, diag_idx, diag_idx] + sigma_sq_eps_b
        )

        # Subtract outer product of means: Σ_Δ -= μ_Δ μ_Δ^T (Eqs. 17-18)
        Sigma_Delta = Sigma_Delta - torch.bmm(
            mu_Delta.unsqueeze(-1), mu_Delta.unsqueeze(-2)
        )
        Sigma_Delta = (
            Sigma_Delta + Sigma_Delta.transpose(-2, -1)
        ) / 2  # enforce symmetry

        # ---- Propagate to next-state belief (Eqs. 10-12) ----
        # cov[x_{t-1}, Δ_t] = cov[x_{t-1}, x̃_{t-1}] · cov_xtilde_Delta (Eq. 12)
        # cov[x_{t-1}, x̃_{t-1}] is the top-D rows of Σ̃_{t-1}: shape [B, D, D+F].
        # Using only Sigma_x_C_xu ([B, D, F]) here would be wrong — it drops
        # the Σ_x block and produces a [B, D, F] @ [B, D+F, D] shape mismatch.
        Sigma_x_rows = Sigma_tilde[:, :D, :]  # [B, D, D+F]
        cov_x_Delta = Sigma_x_rows @ cov_xtilde_Delta  # [B, D, D]

        # μ_t = μ_{t-1} + μ_Δ (Eq. 10)
        mu_t = mu_x + mu_Delta

        # Σ_t = Σ_{t-1} + Σ_Δ + cov[x_{t-1},Δ_t] + cov[Δ_t,x_{t-1}] (Eq. 11)
        Sigma_t = Sigma_x + Sigma_Delta + cov_x_Delta + cov_x_Delta.transpose(-2, -1)
        Sigma_t = (Sigma_t + Sigma_t.transpose(-2, -1)) / 2  # enforce symmetry
        # NOTE(review): this torch.eye omits dtype=dtype (unlike the ones
        # above) — fine for float32 inputs; confirm for float64 tensordicts.
        Sigma_t = Sigma_t + 1e-8 * torch.eye(D, device=device).expand(
            B, -1, -1
        )  # jitter

        if unbatched:
            mu_t = mu_t.squeeze(0)
            Sigma_t = Sigma_t.squeeze(0)

        out_mean_key, out_var_key = self.out_keys
        tensordict.set(out_mean_key, mu_t)
        tensordict.set(out_var_key, Sigma_t)
        return tensordict

    def deterministic_forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Deterministic forward pass using GP posterior mean and variance (Eqs. 7-8).

        Used when the input observation is a point estimate with no uncertainty.
        Returns the GP posterior mean ``m_f(x̃_*)`` (Eq. 7) and per-dimension
        variance ``σ²_f(x̃_*)`` (Eq. 8) for each state dimension.

        Args:
            tensordict (TensorDictBase): Input tensordict with keys defined by
                ``in_keys``. Supports arbitrary leading batch dimensions
                ``(*batch, D)`` / ``(*batch, F)``, as well as unbatched
                ``(D,)`` / ``(F,)`` inputs.

        Returns:
            TensorDictBase: The same tensordict updated with next-state mean
                ``μ_t`` and diagonal covariance ``Σ_t = diag(σ²_Δ)`` at
                ``out_keys``.
        """
        u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys
        mu_x = tensordict.get(x_mean_key)  # x_{t-1}, shape (*batch, D) or (D,)
        mu_u = tensordict.get(u_mean_key)  # u_{t-1}, shape (*batch, F) or (F,)

        batch_shape = mu_x.shape[:-1]  # leading dims; () for unbatched inputs

        # Flatten all leading batch dimensions to a single axis for the GP
        # posterior call, then restore the original shape afterwards.
        x_flat = mu_x.reshape(-1, self.obs_dim)  # [B_flat, D]
        u_flat = mu_u.reshape(-1, self.action_dim)  # [B_flat, F]

        # x̃_* = [x_{t-1}, u_{t-1}] ∈ R^{B_flat×(D+F)} (Sec. 2.1)
        X_tilde_test = torch.cat([x_flat, u_flat], dim=-1)

        # GP posterior mean m_f(x̃_*) (Eq. 7) and std σ_f(x̃_*) (Eq. 8),
        # one independent GP per output dimension (no_grad: inference only).
        mu_Delta_list, sigma_Delta_list = [], []

        with torch.no_grad():
            for gp_a in self.model_list.models:
                posterior_a = gp_a.posterior(X_tilde_test)
                mu_Delta_list.append(posterior_a.mean.squeeze(-1))  # m_f (Eq. 7)
                sigma_Delta_list.append(
                    torch.sqrt(posterior_a.variance).squeeze(-1)  # σ_f (Eq. 8)
                )

        # μ_Δ – predicted residual mean; restore original batch shape
        mu_Delta = torch.stack(mu_Delta_list, dim=-1).view(*batch_shape, self.obs_dim)

        # σ_Δ – predicted residual std; restore original batch shape
        sigma_Delta = torch.stack(sigma_Delta_list, dim=-1).view(
            *batch_shape, self.obs_dim
        )

        # μ_t = x_{t-1} + μ_Δ (deterministic version of Eq. 10)
        mu_t = mu_x + mu_Delta

        # Σ_t = diag(σ²_Δ) – diagonal covariance from independent GP variances (Eq. 8)
        Sigma_t = torch.diag_embed(sigma_Delta**2)

        out_mean_key, out_var_key = self.out_keys
        tensordict.set(out_mean_key, mu_t)
        tensordict.set(out_var_key, Sigma_t)
        return tensordict
diff --git a/torchrl/modules/models/rbf_controller.py b/torchrl/modules/models/rbf_controller.py
new file mode 100644
index 00000000000..5a490bcdf11
--- /dev/null
+++ b/torchrl/modules/models/rbf_controller.py
@@ -0,0 +1,223 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch
import torch.nn as nn


class RBFController(nn.Module):
    """Radial Basis Function controller for moment-matching policy search.

    Implements a policy that maps Gaussian-distributed state beliefs
    ``(mean, covariance)`` to Gaussian-distributed actions using an RBF network
    followed by a sinusoidal squashing function. The moment-matching formulas
    allow analytic gradient computation through the policy during model-based
    optimization (e.g., PILCO).

    The controller uses ``n_basis`` RBF basis functions, each parameterised
    by a centre vector and a shared diagonal lengthscale. The output is a
    weighted sum of basis activations, optionally squashed through
    :meth:`squash_sin` to enforce action bounds.

    Reference: Deisenroth & Rasmussen, "PILCO: A Model-Based and Data-Efficient
    Approach to Policy Search", ICML 2011.

    Args:
        input_dim (int): Dimensionality of the state (observation) space.
        output_dim (int): Dimensionality of the action space.
        max_action (float or Tensor): Element-wise upper bound on action
            magnitude. When provided, actions are squashed through
            :meth:`squash_sin`.
        n_basis (int, optional): Number of RBF basis functions.
            Defaults to ``10``.

    Inputs:
        mean (Tensor): State mean of shape ``(*batch, input_dim)``.
        covariance (Tensor): State covariance of shape
            ``(*batch, input_dim, input_dim)``.

    Returns:
        action_mean (Tensor): Action mean of shape ``(*batch, output_dim)``.
        action_covariance (Tensor): Action covariance of shape
            ``(*batch, output_dim, output_dim)``.
        cross_covariance (Tensor): Input–output cross-covariance of shape
            ``(*batch, input_dim, output_dim)``.

    Examples:
        >>> import torch
        >>> controller = RBFController(input_dim=4, output_dim=1, max_action=2.0, n_basis=5)
        >>> mean = torch.randn(2, 4)
        >>> covariance = torch.eye(4).unsqueeze(0).expand(2, -1, -1) * 0.1
        >>> action_mean, action_cov, cross_cov = controller(mean, covariance)
        >>> action_mean.shape
        torch.Size([2, 1])
        >>> action_cov.shape
        torch.Size([2, 1, 1])
        >>> cross_cov.shape
        torch.Size([2, 4, 1])
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_action: float | torch.Tensor,
        n_basis: int = 10,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.max_action = max_action
        self.n_basis = n_basis

        # Learnable RBF centres, output weights and shared diagonal lengthscale.
        self.centers = nn.Parameter(torch.randn(n_basis, input_dim) * 0.5)
        self.weights = nn.Parameter(torch.randn(n_basis, output_dim) * 0.1)
        self.lengthscales = nn.Parameter(torch.ones(input_dim))
        # Fixed (non-learnable) signal variance of the RBF features.
        self.variance = 1.0

    @staticmethod
    def squash_sin(
        mean: torch.Tensor,
        covariance: torch.Tensor,
        max_action: float | torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Propagates a Gaussian through an element-wise ``max_action * sin(x)`` squashing.

        Computes the exact moments of the transformed distribution using
        the moment-matching identities for sine applied to Gaussian inputs.

        Args:
            mean (Tensor): Input mean, shape ``(*batch, K)``.
            covariance (Tensor): Input covariance, shape ``(*batch, K, K)``.
            max_action (float or Tensor): Per-dimension action bound.

        Returns:
            squashed_mean (Tensor): Output mean, shape ``(*batch, K)``.
            squashed_covariance (Tensor): Output covariance, shape ``(*batch, K, K)``.
            cross_covariance (Tensor): Input–output cross-covariance, shape ``(*batch, K, K)``.
        """
        K = mean.shape[-1]
        device = mean.device
        dtype = mean.dtype

        # Accept a python scalar bound; broadcast a single bound to all dims.
        if not isinstance(max_action, torch.Tensor):
            max_action = torch.tensor(max_action, dtype=dtype, device=device)

        max_action = max_action.view(-1)
        if max_action.shape[0] == 1 and K > 1:
            max_action = max_action.expand(K)

        diag_cov = torch.diagonal(covariance, dim1=-2, dim2=-1)

        # E[M sin(x)] = M exp(-σ²/2) sin(μ) for x ~ N(μ, σ²)
        squashed_mean = max_action * torch.exp(-diag_cov / 2.0) * torch.sin(mean)

        # lq_{ij} = -(σ²_i + σ²_j)/2; q = exp(lq) is the damping factor
        lq = -(diag_cov.unsqueeze(-1) + diag_cov.unsqueeze(-2)) / 2.0
        q = torch.exp(lq)

        mean_diff = mean.unsqueeze(-1) - mean.unsqueeze(-2)
        mean_sum = mean.unsqueeze(-1) + mean.unsqueeze(-2)

        # Exact sine covariance identity for jointly Gaussian inputs.
        squashed_covariance = (torch.exp(lq + covariance) - q) * torch.cos(
            mean_diff
        ) - (torch.exp(lq - covariance) - q) * torch.cos(mean_sum)

        outer_max = max_action.unsqueeze(-2) * max_action.unsqueeze(-1)
        squashed_covariance = outer_max * squashed_covariance / 2.0

        # Diagonal input–output coupling: d E[M sin(x)]/dμ = M exp(-σ²/2) cos(μ)
        cross_covariance = torch.diag_embed(
            max_action * torch.exp(-diag_cov / 2.0) * torch.cos(mean)
        )

        return squashed_mean, squashed_covariance, cross_covariance

    def forward(
        self, mean: torch.Tensor, covariance: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Moment-matched action prediction for a Gaussian state belief.

        Propagates ``N(mean, covariance)`` through the RBF policy and (if
        ``max_action`` is set) the sine squashing, returning the action mean,
        action covariance and state–action cross-covariance (see class docs).
        """
        batch_shape = mean.shape[:-1]
        D = mean.shape[-1]
        N = self.n_basis
        device = mean.device

        # Flatten batch dimensions for computation
        mean_flat = mean.reshape(-1, D)
        covariance_flat = covariance.reshape(-1, D, D)
        B = mean_flat.shape[0]

        inv_lengthscale = torch.diag(1.0 / self.lengthscales)
        inv_lengthscale_batch = inv_lengthscale.unsqueeze(0)

        # inp_{b,k} = c_k - μ_b: centre offsets from the input mean; [B, N, D]
        inp = self.centers.unsqueeze(0) - mean_flat.unsqueeze(1)

        # B_mat = Λ^{-1/2} Σ Λ^{-1/2} + I (symmetric normaliser); [B, D, D]
        B_mat = (
            inv_lengthscale_batch @ covariance_flat @ inv_lengthscale_batch
            + torch.eye(D, device=device, dtype=mean.dtype).unsqueeze(0)
        )

        scaled_inp = inp @ inv_lengthscale

        t = torch.linalg.solve(B_mat, scaled_inp.mT).mT

        # exp(-½ (c_k - μ)^T (Σ + Λ)^{-1} (c_k - μ)); [B, N]
        exp_term = torch.exp(-0.5 * torch.sum(scaled_inp * t, dim=-1))
        # NOTE(review): log_det_sign is unused — B_mat is assumed positive
        # definite so the sign is +1; consider `_, log_det = ...`.
        log_det_sign, log_det = torch.linalg.slogdet(B_mat)

        # Normaliser 1/sqrt(|Σ Λ^{-1} + I|) scaled by the signal variance.
        normalizer = self.variance * torch.exp(-0.5 * log_det)
        phi_mean = normalizer.unsqueeze(-1) * exp_term

        # E[u] = Σ_k w_k E[φ_k(x)]; [B, output_dim]
        action_mean = phi_mean @ self.weights

        # Input–output cross term Σ_k w_k E[φ_k] (Σ+Λ)^{-1}(c_k-μ); [B, D, out]
        t_scaled = t @ inv_lengthscale
        cross_cov = torch.bmm(t_scaled.mT, phi_mean.unsqueeze(-1) * self.weights)

        # Pairwise basis covariance (Eq. A.42–A.45 in Deisenroth thesis)
        centers_i = self.centers.unsqueeze(1)
        centers_j = self.centers.unsqueeze(0)
        diff = centers_i - centers_j
        center_bar = (centers_i + centers_j) / 2.0

        inv_lambda = 1.0 / (self.lengthscales**2)
        # -¼ (c_i - c_j)^T Λ^{-1} (c_i - c_j); [N, N]
        exp1 = -0.25 * torch.sum((diff**2) * inv_lambda, dim=-1)

        lambda_half = torch.diag((self.lengthscales**2) / 2.0)
        B_q = covariance_flat + lambda_half.unsqueeze(0)

        # z_{b,ij} = midpoint of centres minus the input mean; [B, N, N, D]
        z = center_bar.unsqueeze(0) - mean_flat.unsqueeze(1).unsqueeze(1)
        z_flat = z.view(B, N * N, D)

        solved_z_flat = torch.linalg.solve(B_q, z_flat.mT).mT
        exp2 = -0.5 * torch.sum(z_flat * solved_z_flat, dim=-1).view(B, N, N)

        # Prefactor sqrt(|Λ/2|/|Σ + Λ/2|) computed in log space for stability.
        log_det_lambda_half = torch.sum(torch.log((self.lengthscales**2) / 2.0))
        _, log_det_bq = torch.linalg.slogdet(B_q)
        c_q = torch.exp(0.5 * (log_det_lambda_half - log_det_bq))

        # Q_{ij} = E[φ_i(x) φ_j(x)]; [B, N, N]
        Q = (self.variance**2 * c_q.view(B, 1, 1)) * torch.exp(
            exp1.unsqueeze(0) + exp2
        )

        # Var[u] = W^T Q W - E[u] E[u]^T
        W_batch = self.weights.unsqueeze(0).expand(B, N, -1)
        action_cov = torch.bmm(W_batch.mT, torch.bmm(Q, W_batch))

        outer_mean = torch.bmm(action_mean.unsqueeze(-1), action_mean.unsqueeze(1))
        action_cov = action_cov - outer_mean

        # Symmetrise and add jitter so downstream Cholesky calls succeed.
        action_cov = (action_cov + action_cov.mT) / 2.0
        action_cov = (
            action_cov
            + torch.eye(self.output_dim, device=device, dtype=mean.dtype).unsqueeze(0)
            * 1e-6
        )

        if self.max_action is not None:
            # Squash through M sin(·) and chain the cross-covariance through
            # the squashing's diagonal input–output matrix C.
            action_mean, action_cov, C = self.squash_sin(
                action_mean, action_cov, self.max_action
            )
            cross_cov = torch.bmm(cross_cov, C)

        # Reshape back to original batch shape
        action_mean = action_mean.reshape(*batch_shape, self.output_dim)
        action_cov = action_cov.reshape(*batch_shape, self.output_dim, self.output_dim)
        cross_cov = \
            cross_cov.reshape(*batch_shape, D, self.output_dim)

        return action_mean, action_cov, cross_cov
diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py
index 2df2da650ca..f8e47d73519 100644
--- a/torchrl/objectives/__init__.py
+++ b/torchrl/objectives/__init__.py
@@ -18,6 +18,7 @@
 from torchrl.objectives.gail import GAILLoss
 from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss
 from torchrl.objectives.multiagent import QMixerLoss
+from torchrl.objectives.pilco import ExponentialQuadraticCost
 from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss
 from torchrl.objectives.redq import REDQLoss
 from torchrl.objectives.reinforce import ReinforceLoss
@@ -52,6 +53,7 @@
     "DreamerActorLoss",
     "DreamerModelLoss",
     "DreamerValueLoss",
+    "ExponentialQuadraticCost",
     "GAILLoss",
     "HardUpdate",
     "IQLLoss",
diff --git a/torchrl/objectives/pilco.py b/torchrl/objectives/pilco.py
new file mode 100644
index 00000000000..ae523415b43
--- /dev/null
+++ b/torchrl/objectives/pilco.py
@@ -0,0 +1,119 @@
from dataclasses import dataclass

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.objectives.common import LossModule


class ExponentialQuadraticCost(LossModule):
    """Computes the expected saturating cost for a Gaussian-distributed state.

    This serves as a smooth, unimodal approximation of a 0-1 cost over a target area,
    allowing for analytic gradient computation during policy search (e.g., PILCO).
    Calculates E_{x_t}[c(x_t)] over N(m, s) as defined in Eq. (24) and (25) of
    Deisenroth & Rasmussen (2011).

    Args:
        target (torch.Tensor, optional): The target state vector. Defaults to the origin.
        weights (torch.Tensor, optional): The precision matrix mapping state dimensions
            to the cost distance metric. Defaults to the identity matrix.
        reduction (str, optional): Specifies the reduction to apply to the output:
            'mean' | 'sum' | 'none'. Defaults to 'mean'.
    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for configurable tensordict keys."""

        loc: str | tuple[str, ...] = ("observation", "mean")
        scale: str | tuple[str, ...] = ("observation", "var")
        loss_cost: str | tuple[str, ...] = "loss_cost"

    default_keys = _AcceptedKeys

    def __init__(
        self,
        target: torch.Tensor | None = None,
        weights: torch.Tensor | None = None,
        reduction: str = "mean",
    ):
        super().__init__()
        self._tensor_keys = self._AcceptedKeys()
        self.reduction = reduction

        # register_buffer accepts None, so absent target/weights simply move
        # with the module and fall back to identity/origin in forward().
        self.register_buffer("target", target)
        self.register_buffer("weights", weights)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Compute the expected saturating cost from a Gaussian state belief.

        Reads the state mean ``m`` and full covariance ``s`` from the input
        tensordict (keys ``tensor_keys.loc`` / ``tensor_keys.scale``) and
        returns a TensorDict with the reduced cost at ``tensor_keys.loss_cost``.
        """
        m = tensordict.get(self.tensor_keys.loc)
        s = tensordict.get(self.tensor_keys.scale)

        batch_shape = m.shape[:-1]
        D = m.shape[-1]
        device = m.device
        dtype = m.dtype

        # Fall back to identity precision / origin target when not configured.
        weights = (
            self.weights
            if self.weights is not None
            else torch.eye(D, device=device, dtype=dtype)
        )
        target = (
            self.target
            if self.target is not None
            else torch.zeros(D, device=device, dtype=dtype)
        )

        if target.dim() == 1:
            target_shape = (*[1] * len(batch_shape), D)
            target = target.view(*target_shape).expand(*batch_shape, D)

        eye = torch.eye(D, device=device, dtype=dtype)
        eye_batch = eye.view(*[1] * len(batch_shape), D, D)

        # diff: Distance from the current mean to the target (x - x_target)
        diff = (m - target).unsqueeze(-1)

        # L_w, V_w: Eigenvalues and eigenvectors of the precision weight matrix
        # (clamped at zero so the symmetric square root below is real)
        L_w, V_w = torch.linalg.eigh(weights)
        L_w = torch.clamp(L_w, min=0.0)

        # U: symmetric square root W^{1/2} of the cost precision matrix
        U = V_w @ torch.diag_embed(torch.sqrt(L_w)) @ V_w.transpose(-2, -1)

        # A_sym = I + U s U: symmetrised form of (I + s W) — same determinant,
        # but positive definite so a Cholesky factorisation is valid.
        # U is (D, D), s is (*batch_shape, D, D)
        A_sym = eye_batch + torch.matmul(U, torch.matmul(s, U))

        jitter = 1e-5
        A_sym = A_sym + jitter * eye_batch

        # L: Cholesky decomposition of A_sym for numerical stability
        L = torch.linalg.cholesky(A_sym)

        # Determinant and exponential terms for the closed-form expected cost:
        # E[c] = 1 - |I + s W|^{-1/2} exp(-½ (m-z)^T W (I + s W)^{-1} (m-z))
        log_det = 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1)
        det_term = torch.exp(-0.5 * log_det)

        # Mahalanobis distance components scaled by the target weights
        # U @ diff needs broadcasting
        v = torch.matmul(U.view(*[1] * len(batch_shape), D, D), diff)
        tmp = torch.cholesky_solve(v, L)
        quad = torch.matmul(v.transpose(-2, -1), tmp)
        exp_term = (-0.5 * quad).squeeze(-1).squeeze(-1)

        # Expected cost bounded in [0, 1]
        cost = 1.0 - det_term * torch.exp(exp_term)

        if self.reduction == "mean":
            loss = cost.mean()
            out_batch_size = []
        elif self.reduction == "sum":
            loss = cost.sum()
            out_batch_size = []
        elif self.reduction == "none":
            loss = cost
            out_batch_size = batch_shape
        else:
            raise ValueError(f"Unsupported reduction: {self.reduction}")
        return TensorDict({self.tensor_keys.loss_cost: loss}, batch_size=out_batch_size)