From 5e2e9a67362de9abfbf33b56349b0d43ecf488f5 Mon Sep 17 00:00:00 2001 From: Pedro Rosa Date: Fri, 27 Feb 2026 15:40:39 -0300 Subject: [PATCH 1/8] [Algorithm] PILCO --- sota-check/run_pilco.sh | 26 ++ sota-implementations/pilco/config.yaml | 18 + sota-implementations/pilco/pilco.py | 190 ++++++++ sota-implementations/pilco/utils.py | 585 +++++++++++++++++++++++++ 4 files changed, 819 insertions(+) create mode 100644 sota-check/run_pilco.sh create mode 100644 sota-implementations/pilco/config.yaml create mode 100644 sota-implementations/pilco/pilco.py create mode 100644 sota-implementations/pilco/utils.py diff --git a/sota-check/run_pilco.sh b/sota-check/run_pilco.sh new file mode 100644 index 00000000000..393b2ed7332 --- /dev/null +++ b/sota-check/run_pilco.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=pilco +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/pilco_%j.txt +#SBATCH --error=slurm_errors/pilco_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="pilco" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/pilco/pilco.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$? 
+# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-implementations/pilco/config.yaml b/sota-implementations/pilco/config.yaml new file mode 100644 index 00000000000..93ee8ff126e --- /dev/null +++ b/sota-implementations/pilco/config.yaml @@ -0,0 +1,18 @@ +env: + env_name: InvertedPendulum-v5 + library: gym +device: null +logger: + backend: wandb + project_name: torchrl_pilco + group_name: null + video: True +optim: + policy_lr: 5e-3 +pilco: + horizon: 40 + initial_rollout_length: 200 + max_rollout_length: 350 + epochs: 3 + policy_training_steps: 100 + policy_n_basis: 10 diff --git a/sota-implementations/pilco/pilco.py b/sota-implementations/pilco/pilco.py new file mode 100644 index 00000000000..da41ec2c41c --- /dev/null +++ b/sota-implementations/pilco/pilco.py @@ -0,0 +1,190 @@ +import hydra +import tensordict +import torch +from omegaconf import DictConfig + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModule +from torchrl._utils import get_available_device +from torchrl.envs import EnvBase +from torchrl.envs.utils import RandomPolicy +from torchrl.record.loggers import generate_exp_name, get_logger, Logger + +from utils import ( + BoTorchGPWorldModel, + ImaginedEnv, + make_env, + pendulum_cost, + RBFController, +) + + +def pilco_loop( + cfg: DictConfig, env: EnvBase, logger: Logger | None = None +) -> TensorDictModule: + obs_dim = env.observation_spec["observation"].shape[-1] + action_dim = env.action_spec.shape[-1] + + random_policy = RandomPolicy(action_spec=env.action_spec) + rollout = env.rollout( + max_steps=cfg.pilco.initial_rollout_length, + policy=random_policy, + break_when_all_done=False, + break_when_any_done=False, + ) + + base_policy = ( + RBFController( + input_dim=obs_dim, + output_dim=action_dim, + n_basis=cfg.pilco.policy_n_basis, + 
max_action=env.action_spec.high, + ) + .to(env.device) + .double() + ) + policy_module = TensorDictModule( + module=base_policy, + in_keys=[("observation", "mean"), ("observation", "var")], + out_keys=[ + ("action", "mean"), + ("action", "var"), + ("action", "cross_covariance"), + ], + ) + optimizer = torch.optim.Adam(policy_module.parameters(), lr=cfg.optim.policy_lr) + + dtype = torch.float64 + initial_observation = TensorDict( + { + ("observation", "mean"): torch.zeros( + obs_dim, device=env.device, dtype=dtype + ), + ("observation", "var"): torch.eye(obs_dim, device=env.device, dtype=dtype) + * 1e-3, + } + ) + + for epoch in range(cfg.pilco.epochs): + base_world_model = BoTorchGPWorldModel( + obs_dim=obs_dim, action_dim=action_dim + ).to(env.device) + base_world_model.fit(rollout) + base_world_model.freeze_and_detach() + + world_model_module = TensorDictModule( + module=base_world_model, + in_keys=["action", "observation"], + out_keys=[("next_observation", "mean"), ("next_observation", "var")], + ) + + imagined_env = ImaginedEnv( + world_model_module=world_model_module, + base_env=env, + ) + reset_td = initial_observation.expand(*imagined_env.batch_size) + + for step in range(cfg.pilco.policy_training_steps): + logger_step = (epoch * cfg.pilco.policy_training_steps) + step + optimizer.zero_grad() + + imagined_data = imagined_env.rollout( + max_steps=cfg.pilco.horizon, + policy=policy_module, + tensordict=reset_td, + ) + + obs = imagined_data["observation"] + cost = pendulum_cost(obs) + loss = cost.mean() + + loss.backward() + optimizer.step() + + if logger: + logger.log_scalar( + "train/trajectory_cost", loss.item(), step=logger_step + ) + + def policy_for_env(td: TensorDictBase) -> TensorDictBase: + obs = td["observation"] + device, dtype = obs.device, obs.dtype + + is_unbatched = obs.ndim == 1 + if is_unbatched: + obs = obs.unsqueeze(0) + + batch_shape = obs.shape[:-1] + D = obs.shape[-1] + + policy_in = TensorDict( + { + "observation": TensorDict( + { + 
"mean": obs, + "var": torch.zeros( + (*batch_shape, D, D), device=device, dtype=dtype + ), + }, + batch_size=batch_shape, + ) + }, + batch_size=batch_shape, + device=device, + ) + + policy_out = policy_module(policy_in) + action_mean = policy_out["action", "mean"] + + if is_unbatched: + action_mean = action_mean.squeeze(0) + + td["action"] = action_mean + return td + + test_rollout = env.rollout( + max_steps=1000, policy=policy_for_env, break_when_any_done=True + ) + + reward = test_rollout["episode_reward"][-1].item() + steps = test_rollout["step_count"].max().item() + + if logger: + logger.log_scalar("eval/reward", reward, step=logger_step) + logger.log_scalar("eval/steps", steps, step=logger_step) + + rollout = tensordict.cat([rollout, test_rollout], dim=0) + + if len(rollout) > cfg.pilco.max_rollout_length: + rollout = rollout[-cfg.pilco.max_rollout_length :] + + return policy_module + + +@hydra.main(config_path="", config_name="config", version_base="1.1") +def main(cfg: DictConfig) -> None: + device = torch.device(cfg.device) if cfg.device else get_available_device() + + env = make_env(cfg.env.env_name, device, from_pixels=cfg.logger.video) + + if cfg.logger.backend: + exp_name = generate_exp_name("PILCO", cfg.env.env_name) + logger = get_logger( + cfg.logger.backend, + logger_name="pilco", + experiment_name=exp_name, + wandb_kwargs={ + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + pilco_loop(cfg, env, logger=logger) + + if not env.is_closed: + env.close() + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/pilco/utils.py b/sota-implementations/pilco/utils.py new file mode 100644 index 00000000000..139526a6241 --- /dev/null +++ b/sota-implementations/pilco/utils.py @@ -0,0 +1,585 @@ +from collections.abc import Sequence + +import torch +import torch.nn as nn +from botorch.fit import fit_gpytorch_mll + +from botorch.models import ModelListGP, SingleTaskGP +from gpytorch.kernels 
import RBFKernel, ScaleKernel +from gpytorch.mlls import SumMarginalLogLikelihood +from gpytorch.priors import GammaPrior + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModule + +from torchrl.envs import ( + EnvBase, + GymEnv, + ModelBasedEnvBase, + RewardSum, + StepCounter, + TransformedEnv, +) + + +def make_env( + env_name: str, device: str | torch.device, from_pixels: bool = False +) -> TransformedEnv: + """Creates the transformed environment.""" + env = TransformedEnv( + GymEnv(env_name, pixels_only=False, from_pixels=from_pixels, device=device) + ) + env.append_transform(RewardSum()) + env.append_transform(StepCounter()) + return env + + +def pendulum_cost( + obs: TensorDictBase, + weights: torch.Tensor | None = None, + target: torch.Tensor | None = None, +) -> torch.Tensor: + """ + obs["mean"]: [B, T, D] + obs["var"] : [B, T, D, D] + """ + m = obs.get("mean") + s = obs.get("var") + + B, T, D = m.shape + device = m.device + dtype = m.dtype + + if weights is None: + diag_vals = torch.tensor([1.0, 1.0, 1.0, 1.0], device=device, dtype=dtype) + weights = torch.diag(diag_vals) + + if target is None: + target = torch.zeros(D, device=device, dtype=dtype) + + if target.dim() == 1: + target = target.view(1, 1, D).expand(B, T, D) + + eye = torch.eye(D, device=device, dtype=dtype).view(1, 1, D, D) + diff = (m - target).unsqueeze(-1) # [B, T, D, 1] + + L_w, V_w = torch.linalg.eigh(weights) + L_w = torch.clamp(L_w, min=0.0) + U = V_w @ torch.diag_embed(torch.sqrt(L_w)) @ V_w.transpose(-2, -1) + + A_sym = eye + torch.matmul(U, torch.matmul(s, U)) + + jitter = 1e-5 + A_sym = A_sym + jitter * eye + + L = torch.linalg.cholesky(A_sym) + + log_det = 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1) + det_term = torch.exp(-0.5 * log_det) + + v = torch.matmul(U, diff) + tmp = torch.cholesky_solve(v, L) + quad = torch.matmul(v.transpose(-2, -1), tmp) + exp_term = (-0.5 * quad).squeeze(-1).squeeze(-1) + + return (1.0 - det_term * 
torch.exp(exp_term)).sum(dim=1) + + +class BoTorchGPWorldModel(nn.Module): + def __init__(self, obs_dim: int, action_dim: int) -> None: + super().__init__() + self.obs_dim = obs_dim + self.action_dim = action_dim + self.input_dim = obs_dim + action_dim + + self.model_list: ModelListGP | None = None + + self.register_buffer("X_train", torch.empty(0)) + self.register_buffer("lengthscales", torch.zeros(self.obs_dim, self.input_dim)) + self.register_buffer("variances", torch.zeros(self.obs_dim, 1)) + self.register_buffer("noises", torch.zeros(self.obs_dim)) + self._cached_inv_K: torch.Tensor | None = None + self._cached_beta: torch.Tensor | None = None + + @property + def device(self) -> torch.device: + return self.lengthscales.device + + def fit(self, dataset: TensorDictBase) -> None: + obs = dataset["observation"] + action = dataset["action"] + next_obs = dataset[("next", "observation")] + + X_train = torch.cat([obs, action], dim=-1).detach().to(self.device) + y_train = (next_obs - obs).detach().to(self.device) + self.X_train = X_train + + models = [] + for i in range(self.obs_dim): + train_x = X_train + train_y = y_train[:, i].unsqueeze(-1) + + covar_module = ScaleKernel( + RBFKernel( + ard_num_dims=self.input_dim, lengthscale_prior=GammaPrior(1.1, 0.1) + ), + outputscale_prior=GammaPrior(1.5, 0.5), + ) + + gp = SingleTaskGP( + train_X=train_x, train_Y=train_y, covar_module=covar_module + ) + gp.likelihood.noise_covar.register_prior( + "noise_prior", GammaPrior(1.2, 0.05), "noise" + ) + + models.append(gp) + + self.model_list = ModelListGP(*models).to(self.device) + mll = SumMarginalLogLikelihood(self.model_list.likelihood, self.model_list) + + fit_gpytorch_mll(mll) + self._extract_parameters(y_train) + + def _extract_parameters(self, y_train: torch.Tensor) -> None: + lengthscales, variances, noises, inv_Ks, betas = [], [], [], [], [] + + for i, gp in enumerate(self.model_list.models): + gp.eval() + gp.likelihood.eval() + + ls = 
gp.covar_module.base_kernel.lengthscale.squeeze().detach() + var = gp.covar_module.outputscale.detach() + noise = gp.likelihood.noise.squeeze().detach() + + lengthscales.append(ls) + variances.append(var) + noises.append(noise) + + X_scaled = self.X_train / ls + dist = torch.cdist(X_scaled, X_scaled, p=2) ** 2 + K = var * torch.exp(-0.5 * dist) + + K_noisy = K + (noise + 1e-6) * torch.eye( + self.X_train.size(0), device=self.device + ) + + L = torch.linalg.cholesky(K_noisy) + eye = torch.eye(L.size(0), dtype=L.dtype, device=L.device) + inv_K = torch.cholesky_solve(eye, L) + + y = y_train[:, i].unsqueeze(-1) + beta = torch.cholesky_solve(y, L).squeeze(-1) + + inv_Ks.append(inv_K) + betas.append(beta) + + self.lengthscales = torch.stack(lengthscales) + self.variances = torch.stack(variances).unsqueeze(-1) + self.noises = torch.stack(noises) + + self._cached_inv_K = torch.stack(inv_Ks) + self._cached_beta = torch.stack(betas) + + def compute_factorizations(self) -> tuple[torch.Tensor, torch.Tensor]: + return self._cached_inv_K, self._cached_beta + + def _gather_gp_params(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.lengthscales, self.variances, self.noises + + def forward( + self, action: TensorDictBase, observation: TensorDictBase + ) -> tuple[torch.Tensor, torch.Tensor]: + observation_uncertain = False + x_var = observation.get("var") + if x_var is not None: + observation_uncertain = not torch.all( + torch.isclose(x_var, torch.zeros_like(x_var)) + ) + if observation_uncertain: + return self.uncertain_forward(action, observation) + else: + return self.deterministic_forward(action, observation) + + def freeze_and_detach(self) -> None: + pass + + def uncertain_forward( + self, action: TensorDictBase, obs: TensorDictBase + ) -> tuple[torch.Tensor, torch.Tensor]: + inv_K, beta = self.compute_factorizations() + lengthscales, variances, noises = self._gather_gp_params() + + m_x, s_x = obs.get("mean"), obs.get("var") + m_u, s_u, c_xu = ( + 
action.get("mean"), + action.get("var"), + action.get("cross_covariance"), + ) + + device, dtype = m_x.device, m_x.dtype + + joint_mean = torch.cat([m_x, m_u], dim=-1) + + s_ = s_x @ c_xu + upper = torch.cat([s_x, s_], dim=-1) + lower = torch.cat([s_.transpose(-1, -2), s_u], dim=-1) + + joint_var = torch.cat([upper, lower], dim=-2) + + X_train = self.X_train + num_train_pts = X_train.shape[0] + batch_size = joint_mean.shape[0] + + inp = X_train - joint_mean.unsqueeze(1) + + inv_L = torch.diag_embed(1.0 / lengthscales).to(dtype=dtype, device=device) + inv_N = inp.unsqueeze(1) @ inv_L.unsqueeze(0) + + B_mat = inv_L.unsqueeze(0) @ joint_var.unsqueeze(1) @ inv_L.unsqueeze(0) + B_mat = B_mat + torch.eye( + self.input_dim, dtype=m_x.dtype, device=m_x.device + ).view(1, 1, self.input_dim, self.input_dim) + + t = torch.linalg.solve(B_mat, inv_N.transpose(-2, -1)).transpose(-2, -1) + + scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2) + lb = scaled_exp * beta.unsqueeze(0) + + det_B = torch.linalg.det(B_mat) + c = variances.squeeze(1).unsqueeze(0) / torch.sqrt(det_B) + + pred_mean = torch.sum(lb, dim=-1) * c.squeeze(0) + + t_inv_L = t @ inv_L.unsqueeze(0) + + cross_cov_E_D = torch.matmul( + t_inv_L.transpose(-2, -1), lb.unsqueeze(-1) + ).squeeze(-1) * c.unsqueeze(-1) + cross_cov = cross_cov_E_D.transpose(-2, -1) + + pred_cov = torch.zeros( + batch_size, self.obs_dim, self.obs_dim, dtype=m_x.dtype, device=m_x.device + ) + + X_i = X_train.unsqueeze(1) + X_j = X_train.unsqueeze(0) + diff = X_i - X_j + joint_mean_flat = joint_mean.unsqueeze(1).unsqueeze(1) + + for a in range(self.obs_dim): + for b in range(self.obs_dim): + l2_a = lengthscales[a].to(device=device, dtype=dtype) ** 2 + l2_b = lengthscales[b].to(device=device, dtype=dtype) ** 2 + + inv_L_a = 1.0 / l2_a + inv_L_b = 1.0 / l2_b + inv_L_sum = inv_L_a + inv_L_b + Lambda_ab = 1.0 / inv_L_sum + + z_bar = Lambda_ab * (X_i * inv_L_a + X_j * inv_L_b) + z = z_bar.unsqueeze(0) - joint_mean_flat + + z_flat = z.view( + 
batch_size, num_train_pts * num_train_pts, self.input_dim + ) + + R_ab = joint_var @ torch.diag(inv_L_sum) + torch.eye( + self.input_dim, dtype=m_x.dtype, device=m_x.device + ).unsqueeze(0) + + inv_L_plus = 1.0 / (l2_a + l2_b) + exp1 = -0.5 * torch.sum(diff * inv_L_plus * diff, dim=-1) + + M_ab = joint_var + torch.diag(Lambda_ab).unsqueeze(0) + + solved_z_flat = torch.linalg.solve( + M_ab, z_flat.transpose(-2, -1) + ).transpose(-2, -1) + exp2 = (-0.5 * torch.sum(z_flat * solved_z_flat, dim=-1)).view( + batch_size, num_train_pts, num_train_pts + ) + + det_R_ab = torch.linalg.det(R_ab) + c_ab = variances[a] * variances[b] / torch.sqrt(det_R_ab) + + Q_ab = c_ab.view(-1, 1, 1) * torch.exp(exp1.unsqueeze(0) + exp2) + + Qb = torch.matmul(Q_ab, beta[b]) + pred_cov[:, a, b] = ( + torch.matmul(beta[a].unsqueeze(0), Qb.unsqueeze(-1)) + .squeeze(-1) + .squeeze(-1) + ) + + if a == b: + invK_Q = torch.matmul(inv_K[a].unsqueeze(0), Q_ab) + trace_val = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1) + + pred_cov[:, a, a] += variances[a] - trace_val + noises[a].item() + + outer_mean = torch.bmm(pred_mean.unsqueeze(-1), pred_mean.unsqueeze(-2)) + pred_cov = pred_cov - outer_mean + + pred_cov = (pred_cov + pred_cov.transpose(-2, -1)) / 2.0 + + m_dx = pred_mean + s_dx = pred_cov + c_xdx = cross_cov + + cov_xf = upper @ c_xdx + + m_x = m_x + m_dx + + s_x = s_x + s_dx + cov_xf + cov_xf.transpose(-2, -1) + + s_x = (s_x + s_x.transpose(-2, -1)) / 2.0 + s_x = s_x + 1e-8 * torch.eye(self.obs_dim, device=s_x.device).expand( + s_x.shape[0], -1, -1 + ) + return m_x, s_x + + def deterministic_forward( + self, action: TensorDictBase, observation: TensorDictBase + ) -> tuple[torch.Tensor, torch.Tensor]: + observation_mean = observation.get("mean") + action_mean = action.get("mean") + + x_flat = observation_mean.view(-1, self.obs_dim) + u_flat = action_mean.view(-1, self.action_dim) + + X_test = torch.cat([x_flat, u_flat], dim=-1) + + means, stds = [], [] + + with torch.no_grad(): + for gp in 
self.model_list.models: + posterior = gp.posterior(X_test) + means.append(posterior.mean.squeeze(-1)) + stds.append(torch.sqrt(posterior.variance).squeeze(-1)) + + delta_mean_flat = torch.stack(means, dim=-1) + delta_std_flat = torch.stack(stds, dim=-1) + + batch_shape = observation_mean.shape[:-1] + delta_mean = delta_mean_flat.view(*batch_shape, self.obs_dim) + delta_std = delta_std_flat.view(*batch_shape, self.obs_dim) + + return observation_mean + delta_mean, torch.diag_embed(delta_std**2) + + +class ImaginedEnv(ModelBasedEnvBase): + def __init__( + self, + world_model_module: TensorDictModule, + base_env: EnvBase, + batch_size: int | torch.Size | Sequence[int] | None = None, + **kwargs + ) -> None: + if batch_size is not None: + self.batch_size = ( + torch.Size(batch_size) + if not isinstance(batch_size, torch.Size) + else batch_size + ) + elif len(base_env.batch_size) == 0: + self.batch_size = torch.Size([1]) + else: + self.batch_size = base_env.batch_size + + super().__init__( + world_model_module, + device=base_env.device, + batch_size=self.batch_size, + **kwargs + ) + + self.observation_spec = base_env.observation_spec.expand( + self.batch_size + ).clone() + self.action_spec = base_env.action_spec.expand(self.batch_size).clone() + self.reward_spec = base_env.reward_spec.expand(self.batch_size).clone() + self.done_spec = base_env.done_spec.expand(self.batch_size).clone() + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + tensordict = self.world_model(tensordict) + + reward = torch.zeros(*tensordict.shape, 1, device=self.device) + done = torch.zeros(*tensordict.shape, 1, dtype=torch.bool, device=self.device) + out = TensorDict( + { + "observation": tensordict.get("next_observation"), + "reward": reward, + "done": done, + "terminated": done.clone(), + }, + tensordict.shape, + ) + return out + + def _reset( + self, tensordict: TensorDictBase | None = None, **kwargs + ) -> TensorDictBase: + if tensordict is None: + tensordict = TensorDict({}, 
batch_size=self.batch_size, device=self.device) + + if ( + tensordict.get(("observation", "var"), None) is not None + and tensordict.get(("observation", "mean"), None) is not None + ): + return tensordict.copy() + + obs = tensordict.get("observation", None) + if obs is None: + obs = self.observation_spec.rand(shape=self.batch_size).get("observation") + if obs.ndim == 1: + obs = obs.expand(self.batch_size, -1) + + obs = obs.to(self.device) + B, D = obs.shape + + out = TensorDict( + { + ("observation", "mean"): obs, + ("observation", "var"): torch.zeros( + B, D, D, dtype=obs.dtype, device=self.device + ), + }, + batch_size=self.batch_size, + device=self.device, + ) + + out.set("done", torch.zeros(B, 1, dtype=torch.bool, device=self.device)) + out.set("terminated", torch.zeros(B, 1, dtype=torch.bool, device=self.device)) + + return out + + +class RBFController(nn.Module): + def __init__( + self, + input_dim: int, + output_dim: int, + max_action: float | torch.Tensor, + n_basis: int = 10, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.max_action = max_action + self.n_basis = n_basis + + self.centers = nn.Parameter(torch.randn(n_basis, input_dim) * 0.5) + self.weights = nn.Parameter(torch.randn(n_basis, output_dim) * 0.1) + self.lengthscales = nn.Parameter(torch.ones(input_dim)) + self.variance = 1.0 + + @staticmethod + def squash_sin( + m: torch.Tensor, s: torch.Tensor, max_action: float | torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, K = m.shape + device = m.device + dtype = m.dtype + + if not isinstance(max_action, torch.Tensor): + max_action = torch.tensor(max_action, dtype=dtype, device=device) + + max_action = max_action.view(-1) + if max_action.shape[0] == 1 and K > 1: + max_action = max_action.expand(K) + + diag_s = torch.diagonal(s, dim1=-2, dim2=-1) + + M = max_action * torch.exp(-diag_s / 2.0) * torch.sin(m) + + lq = -(diag_s.unsqueeze(-1) + diag_s.unsqueeze(-2)) / 2.0 + q = 
torch.exp(lq) + + m_diff = m.unsqueeze(-1) - m.unsqueeze(-2) + m_sum = m.unsqueeze(-1) + m.unsqueeze(-2) + + S = (torch.exp(lq + s) - q) * torch.cos(m_diff) - ( + torch.exp(lq - s) - q + ) * torch.cos(m_sum) + + outer_max = max_action.unsqueeze(1) * max_action.unsqueeze(0) + S = outer_max.unsqueeze(0) * S / 2.0 + + C = torch.diag_embed(max_action * torch.exp(-diag_s / 2.0) * torch.cos(m)) + + return M, S, C + + def forward( + self, m: torch.Tensor, S: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, D = m.shape + N = self.n_basis + device = m.device + + iL = torch.diag(1.0 / self.lengthscales) + iL_batch = iL.unsqueeze(0) + + inp = self.centers.unsqueeze(0) - m.unsqueeze(1) + + B_mat = iL_batch @ S @ iL_batch + torch.eye( + D, device=device, dtype=m.dtype + ).unsqueeze(0) + + iN = inp @ iL + + t = torch.linalg.solve(B_mat, iN.mT).mT + + exp_term = torch.exp(-0.5 * torch.sum(iN * t, dim=-1)) + detB = torch.linalg.det(B_mat) + c = self.variance / torch.sqrt(detB) + phi_mean = c.unsqueeze(-1) * exp_term + + M = phi_mean @ self.weights + + tiL = t @ iL + V = torch.bmm(tiL.mT, phi_mean.unsqueeze(-1) * self.weights) + + c_i = self.centers.unsqueeze(1) + c_j = self.centers.unsqueeze(0) + diff = c_i - c_j + c_bar = (c_i + c_j) / 2.0 + + inv_Lambda = 1.0 / (self.lengthscales**2) + exp1 = -0.25 * torch.sum((diff**2) * inv_Lambda, dim=-1) + + Lambda_half = torch.diag((self.lengthscales**2) / 2.0) + B_q = S + Lambda_half.unsqueeze(0) + + z = c_bar.unsqueeze(0) - m.unsqueeze(1).unsqueeze(1) + z_flat = z.view(B, N * N, D) + + solved_z_flat = torch.linalg.solve(B_q, z_flat.mT).mT + exp2 = -0.5 * torch.sum(z_flat * solved_z_flat, dim=-1).view(B, N, N) + + log_det_Lambda_half = torch.sum(torch.log((self.lengthscales**2) / 2.0)) + log_det_B_q = torch.logdet(B_q) + c_q = torch.exp(0.5 * (log_det_Lambda_half - log_det_B_q)) + + Q = (self.variance**2 * c_q.view(B, 1, 1)) * torch.exp( + exp1.unsqueeze(0) + exp2 + ) + + W_batch = 
self.weights.unsqueeze(0).expand(B, N, -1) + S_action = torch.bmm(W_batch.mT, torch.bmm(Q, W_batch)) + + M_out = torch.bmm(M.unsqueeze(-1), M.unsqueeze(1)) + S_action = S_action - M_out + + S_action = (S_action + S_action.mT) / 2.0 + S_action = ( + S_action + + torch.eye(self.output_dim, device=device, dtype=m.dtype).unsqueeze(0) + * 1e-6 + ) + + if self.max_action is not None: + M, S_action, C = self.squash_sin(M, S_action, self.max_action) + V = torch.bmm(V, C) + + return M, S_action, V From d60c8d05bb59e88951bc9fe9db72f8b9ddf42626 Mon Sep 17 00:00:00 2001 From: Pedro Rosa Date: Tue, 3 Mar 2026 00:54:58 -0300 Subject: [PATCH 2/8] add pilco objective --- sota-implementations/pilco/pilco.py | 19 ++--- torchrl/objectives/__init__.py | 2 + torchrl/objectives/pilco.py | 119 ++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 11 deletions(-) create mode 100644 torchrl/objectives/pilco.py diff --git a/sota-implementations/pilco/pilco.py b/sota-implementations/pilco/pilco.py index da41ec2c41c..83bb4e7cffd 100644 --- a/sota-implementations/pilco/pilco.py +++ b/sota-implementations/pilco/pilco.py @@ -8,15 +8,10 @@ from torchrl._utils import get_available_device from torchrl.envs import EnvBase from torchrl.envs.utils import RandomPolicy +from torchrl.objectives import ExponentialQuadraticCost from torchrl.record.loggers import generate_exp_name, get_logger, Logger -from utils import ( - BoTorchGPWorldModel, - ImaginedEnv, - make_env, - pendulum_cost, - RBFController, -) +from utils import BoTorchGPWorldModel, ImaginedEnv, make_env, RBFController def pilco_loop( @@ -65,6 +60,7 @@ def pilco_loop( } ) + cost_module = ExponentialQuadraticCost(reduction="none").to(env.device) for epoch in range(cfg.pilco.epochs): base_world_model = BoTorchGPWorldModel( obs_dim=obs_dim, action_dim=action_dim @@ -94,9 +90,8 @@ def pilco_loop( tensordict=reset_td, ) - obs = imagined_data["observation"] - cost = pendulum_cost(obs) - loss = cost.mean() + loss_td = 
cost_module(imagined_data) + loss = loss_td.get("loss_cost").sum(dim=-1).mean() loss.backward() optimizer.step() @@ -143,7 +138,9 @@ def policy_for_env(td: TensorDictBase) -> TensorDictBase: return td test_rollout = env.rollout( - max_steps=1000, policy=policy_for_env, break_when_any_done=True + max_steps=100, + policy=policy_for_env, + break_when_any_done=True, # TODO change the max_steps back maybe? ) reward = test_rollout["episode_reward"][-1].item() diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index 2df2da650ca..f8e47d73519 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -18,6 +18,7 @@ from torchrl.objectives.gail import GAILLoss from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss from torchrl.objectives.multiagent import QMixerLoss +from torchrl.objectives.pilco import ExponentialQuadraticCost from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from torchrl.objectives.redq import REDQLoss from torchrl.objectives.reinforce import ReinforceLoss @@ -52,6 +53,7 @@ "DreamerActorLoss", "DreamerModelLoss", "DreamerValueLoss", + "ExponentialQuadraticCost", "GAILLoss", "HardUpdate", "IQLLoss", diff --git a/torchrl/objectives/pilco.py b/torchrl/objectives/pilco.py new file mode 100644 index 00000000000..ae523415b43 --- /dev/null +++ b/torchrl/objectives/pilco.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass + +import torch +from tensordict import TensorDict, TensorDictBase +from torchrl.objectives.common import LossModule + + +class ExponentialQuadraticCost(LossModule): + """Computes the expected saturating cost for a Gaussian-distributed state. + + This serves as a smooth, unimodal approximation of a 0-1 cost over a target area, + allowing for analytic gradient computation during policy search (e.g., PILCO). + Calculates E_{x_t}[c(x_t)] over N(m, s) as defined in Eq. (24) and (25) of + Deisenroth & Rasmussen (2011). 
+ + Args: + target (torch.Tensor, optional): The target state vector. Defaults to the origin. + weights (torch.Tensor, optional): The precision matrix mapping state dimensions + to the cost distance metric. Defaults to the identity matrix. + reduction (str, optional): Specifies the reduction to apply to the output: + 'mean' | 'sum' | 'none'. Defaults to 'mean'. + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for configurable tensordict keys.""" + + loc: str | tuple[str, ...] = ("observation", "mean") + scale: str | tuple[str, ...] = ("observation", "var") + loss_cost: str | tuple[str, ...] = "loss_cost" + + default_keys = _AcceptedKeys + + def __init__( + self, + target: torch.Tensor | None = None, + weights: torch.Tensor | None = None, + reduction: str = "mean", + ): + super().__init__() + self._tensor_keys = self._AcceptedKeys() + self.reduction = reduction + + self.register_buffer("target", target) + self.register_buffer("weights", weights) + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + m = tensordict.get(self.tensor_keys.loc) + s = tensordict.get(self.tensor_keys.scale) + + batch_shape = m.shape[:-1] + D = m.shape[-1] + device = m.device + dtype = m.dtype + + weights = ( + self.weights + if self.weights is not None + else torch.eye(D, device=device, dtype=dtype) + ) + target = ( + self.target + if self.target is not None + else torch.zeros(D, device=device, dtype=dtype) + ) + + if target.dim() == 1: + target_shape = (*[1] * len(batch_shape), D) + target = target.view(*target_shape).expand(*batch_shape, D) + + eye = torch.eye(D, device=device, dtype=dtype) + eye_batch = eye.view(*[1] * len(batch_shape), D, D) + + # diff: Distance from the current mean to the target (x - x_target) + diff = (m - target).unsqueeze(-1) + + # L_w, V_w: Eigenvalues and eigenvectors of the precision weight matrix + L_w, V_w = torch.linalg.eigh(weights) + L_w = torch.clamp(L_w, min=0.0) + + # U: Scaled transformation matrix for the cost 
weighting + U = V_w @ torch.diag_embed(torch.sqrt(L_w)) @ V_w.transpose(-2, -1) + + # A_sym: Covariance transformation required for computing the expected cost integral + # U is (D, D), s is (*batch_shape, D, D) + A_sym = eye_batch + torch.matmul(U, torch.matmul(s, U)) + + jitter = 1e-5 + A_sym = A_sym + jitter * eye_batch + + # L: Cholesky decomposition of A_sym for numerical stability + L = torch.linalg.cholesky(A_sym) + + # Determinant and exponential terms for the closed-form expected cost + log_det = 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1) + det_term = torch.exp(-0.5 * log_det) + + # Mahalanobis distance components scaled by the target weights + # U @ diff needs broadcasting + v = torch.matmul(U.view(*[1] * len(batch_shape), D, D), diff) + tmp = torch.cholesky_solve(v, L) + quad = torch.matmul(v.transpose(-2, -1), tmp) + exp_term = (-0.5 * quad).squeeze(-1).squeeze(-1) + + # Expected cost bounded in [0, 1] + cost = 1.0 - det_term * torch.exp(exp_term) + + if self.reduction == "mean": + loss = cost.mean() + out_batch_size = [] + elif self.reduction == "sum": + loss = cost.sum() + out_batch_size = [] + elif self.reduction == "none": + loss = cost + out_batch_size = batch_shape + else: + raise ValueError(f"Unsupported reduction: {self.reduction}") + return TensorDict({self.tensor_keys.loss_cost: loss}, batch_size=out_batch_size) From 2eba8f75ab9e7b0cecb29764e57ea3c2a75242f2 Mon Sep 17 00:00:00 2001 From: Pedro Rosa Date: Fri, 13 Mar 2026 20:25:05 +0100 Subject: [PATCH 3/8] Move world-model to core --- sota-implementations/pilco/pilco.py | 9 +- sota-implementations/pilco/utils.py | 289 ---------------------- torchrl/modules/models/gp.py | 362 ++++++++++++++++++++++++++++ 3 files changed, 367 insertions(+), 293 deletions(-) create mode 100644 torchrl/modules/models/gp.py diff --git a/sota-implementations/pilco/pilco.py b/sota-implementations/pilco/pilco.py index 83bb4e7cffd..8c4849134c4 100644 --- a/sota-implementations/pilco/pilco.py +++ 
b/sota-implementations/pilco/pilco.py @@ -8,10 +8,11 @@ from torchrl._utils import get_available_device from torchrl.envs import EnvBase from torchrl.envs.utils import RandomPolicy +from torchrl.modules.models import GPWorldModel from torchrl.objectives import ExponentialQuadraticCost from torchrl.record.loggers import generate_exp_name, get_logger, Logger -from utils import BoTorchGPWorldModel, ImaginedEnv, make_env, RBFController +from utils import ImaginedEnv, make_env, RBFController def pilco_loop( @@ -62,9 +63,9 @@ def pilco_loop( cost_module = ExponentialQuadraticCost(reduction="none").to(env.device) for epoch in range(cfg.pilco.epochs): - base_world_model = BoTorchGPWorldModel( - obs_dim=obs_dim, action_dim=action_dim - ).to(env.device) + base_world_model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim).to( + env.device + ) base_world_model.fit(rollout) base_world_model.freeze_and_detach() diff --git a/sota-implementations/pilco/utils.py b/sota-implementations/pilco/utils.py index 139526a6241..43713fe8565 100644 --- a/sota-implementations/pilco/utils.py +++ b/sota-implementations/pilco/utils.py @@ -2,12 +2,6 @@ import torch import torch.nn as nn -from botorch.fit import fit_gpytorch_mll - -from botorch.models import ModelListGP, SingleTaskGP -from gpytorch.kernels import RBFKernel, ScaleKernel -from gpytorch.mlls import SumMarginalLogLikelihood -from gpytorch.priors import GammaPrior from tensordict import TensorDict, TensorDictBase from tensordict.nn import TensorDictModule @@ -85,289 +79,6 @@ def pendulum_cost( return (1.0 - det_term * torch.exp(exp_term)).sum(dim=1) -class BoTorchGPWorldModel(nn.Module): - def __init__(self, obs_dim: int, action_dim: int) -> None: - super().__init__() - self.obs_dim = obs_dim - self.action_dim = action_dim - self.input_dim = obs_dim + action_dim - - self.model_list: ModelListGP | None = None - - self.register_buffer("X_train", torch.empty(0)) - self.register_buffer("lengthscales", torch.zeros(self.obs_dim, 
self.input_dim)) - self.register_buffer("variances", torch.zeros(self.obs_dim, 1)) - self.register_buffer("noises", torch.zeros(self.obs_dim)) - self._cached_inv_K: torch.Tensor | None = None - self._cached_beta: torch.Tensor | None = None - - @property - def device(self) -> torch.device: - return self.lengthscales.device - - def fit(self, dataset: TensorDictBase) -> None: - obs = dataset["observation"] - action = dataset["action"] - next_obs = dataset[("next", "observation")] - - X_train = torch.cat([obs, action], dim=-1).detach().to(self.device) - y_train = (next_obs - obs).detach().to(self.device) - self.X_train = X_train - - models = [] - for i in range(self.obs_dim): - train_x = X_train - train_y = y_train[:, i].unsqueeze(-1) - - covar_module = ScaleKernel( - RBFKernel( - ard_num_dims=self.input_dim, lengthscale_prior=GammaPrior(1.1, 0.1) - ), - outputscale_prior=GammaPrior(1.5, 0.5), - ) - - gp = SingleTaskGP( - train_X=train_x, train_Y=train_y, covar_module=covar_module - ) - gp.likelihood.noise_covar.register_prior( - "noise_prior", GammaPrior(1.2, 0.05), "noise" - ) - - models.append(gp) - - self.model_list = ModelListGP(*models).to(self.device) - mll = SumMarginalLogLikelihood(self.model_list.likelihood, self.model_list) - - fit_gpytorch_mll(mll) - self._extract_parameters(y_train) - - def _extract_parameters(self, y_train: torch.Tensor) -> None: - lengthscales, variances, noises, inv_Ks, betas = [], [], [], [], [] - - for i, gp in enumerate(self.model_list.models): - gp.eval() - gp.likelihood.eval() - - ls = gp.covar_module.base_kernel.lengthscale.squeeze().detach() - var = gp.covar_module.outputscale.detach() - noise = gp.likelihood.noise.squeeze().detach() - - lengthscales.append(ls) - variances.append(var) - noises.append(noise) - - X_scaled = self.X_train / ls - dist = torch.cdist(X_scaled, X_scaled, p=2) ** 2 - K = var * torch.exp(-0.5 * dist) - - K_noisy = K + (noise + 1e-6) * torch.eye( - self.X_train.size(0), device=self.device - ) - - L = 
torch.linalg.cholesky(K_noisy) - eye = torch.eye(L.size(0), dtype=L.dtype, device=L.device) - inv_K = torch.cholesky_solve(eye, L) - - y = y_train[:, i].unsqueeze(-1) - beta = torch.cholesky_solve(y, L).squeeze(-1) - - inv_Ks.append(inv_K) - betas.append(beta) - - self.lengthscales = torch.stack(lengthscales) - self.variances = torch.stack(variances).unsqueeze(-1) - self.noises = torch.stack(noises) - - self._cached_inv_K = torch.stack(inv_Ks) - self._cached_beta = torch.stack(betas) - - def compute_factorizations(self) -> tuple[torch.Tensor, torch.Tensor]: - return self._cached_inv_K, self._cached_beta - - def _gather_gp_params(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.lengthscales, self.variances, self.noises - - def forward( - self, action: TensorDictBase, observation: TensorDictBase - ) -> tuple[torch.Tensor, torch.Tensor]: - observation_uncertain = False - x_var = observation.get("var") - if x_var is not None: - observation_uncertain = not torch.all( - torch.isclose(x_var, torch.zeros_like(x_var)) - ) - if observation_uncertain: - return self.uncertain_forward(action, observation) - else: - return self.deterministic_forward(action, observation) - - def freeze_and_detach(self) -> None: - pass - - def uncertain_forward( - self, action: TensorDictBase, obs: TensorDictBase - ) -> tuple[torch.Tensor, torch.Tensor]: - inv_K, beta = self.compute_factorizations() - lengthscales, variances, noises = self._gather_gp_params() - - m_x, s_x = obs.get("mean"), obs.get("var") - m_u, s_u, c_xu = ( - action.get("mean"), - action.get("var"), - action.get("cross_covariance"), - ) - - device, dtype = m_x.device, m_x.dtype - - joint_mean = torch.cat([m_x, m_u], dim=-1) - - s_ = s_x @ c_xu - upper = torch.cat([s_x, s_], dim=-1) - lower = torch.cat([s_.transpose(-1, -2), s_u], dim=-1) - - joint_var = torch.cat([upper, lower], dim=-2) - - X_train = self.X_train - num_train_pts = X_train.shape[0] - batch_size = joint_mean.shape[0] - - inp = X_train - 
joint_mean.unsqueeze(1) - - inv_L = torch.diag_embed(1.0 / lengthscales).to(dtype=dtype, device=device) - inv_N = inp.unsqueeze(1) @ inv_L.unsqueeze(0) - - B_mat = inv_L.unsqueeze(0) @ joint_var.unsqueeze(1) @ inv_L.unsqueeze(0) - B_mat = B_mat + torch.eye( - self.input_dim, dtype=m_x.dtype, device=m_x.device - ).view(1, 1, self.input_dim, self.input_dim) - - t = torch.linalg.solve(B_mat, inv_N.transpose(-2, -1)).transpose(-2, -1) - - scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2) - lb = scaled_exp * beta.unsqueeze(0) - - det_B = torch.linalg.det(B_mat) - c = variances.squeeze(1).unsqueeze(0) / torch.sqrt(det_B) - - pred_mean = torch.sum(lb, dim=-1) * c.squeeze(0) - - t_inv_L = t @ inv_L.unsqueeze(0) - - cross_cov_E_D = torch.matmul( - t_inv_L.transpose(-2, -1), lb.unsqueeze(-1) - ).squeeze(-1) * c.unsqueeze(-1) - cross_cov = cross_cov_E_D.transpose(-2, -1) - - pred_cov = torch.zeros( - batch_size, self.obs_dim, self.obs_dim, dtype=m_x.dtype, device=m_x.device - ) - - X_i = X_train.unsqueeze(1) - X_j = X_train.unsqueeze(0) - diff = X_i - X_j - joint_mean_flat = joint_mean.unsqueeze(1).unsqueeze(1) - - for a in range(self.obs_dim): - for b in range(self.obs_dim): - l2_a = lengthscales[a].to(device=device, dtype=dtype) ** 2 - l2_b = lengthscales[b].to(device=device, dtype=dtype) ** 2 - - inv_L_a = 1.0 / l2_a - inv_L_b = 1.0 / l2_b - inv_L_sum = inv_L_a + inv_L_b - Lambda_ab = 1.0 / inv_L_sum - - z_bar = Lambda_ab * (X_i * inv_L_a + X_j * inv_L_b) - z = z_bar.unsqueeze(0) - joint_mean_flat - - z_flat = z.view( - batch_size, num_train_pts * num_train_pts, self.input_dim - ) - - R_ab = joint_var @ torch.diag(inv_L_sum) + torch.eye( - self.input_dim, dtype=m_x.dtype, device=m_x.device - ).unsqueeze(0) - - inv_L_plus = 1.0 / (l2_a + l2_b) - exp1 = -0.5 * torch.sum(diff * inv_L_plus * diff, dim=-1) - - M_ab = joint_var + torch.diag(Lambda_ab).unsqueeze(0) - - solved_z_flat = torch.linalg.solve( - M_ab, z_flat.transpose(-2, -1) - ).transpose(-2, -1) - exp2 = 
(-0.5 * torch.sum(z_flat * solved_z_flat, dim=-1)).view( - batch_size, num_train_pts, num_train_pts - ) - - det_R_ab = torch.linalg.det(R_ab) - c_ab = variances[a] * variances[b] / torch.sqrt(det_R_ab) - - Q_ab = c_ab.view(-1, 1, 1) * torch.exp(exp1.unsqueeze(0) + exp2) - - Qb = torch.matmul(Q_ab, beta[b]) - pred_cov[:, a, b] = ( - torch.matmul(beta[a].unsqueeze(0), Qb.unsqueeze(-1)) - .squeeze(-1) - .squeeze(-1) - ) - - if a == b: - invK_Q = torch.matmul(inv_K[a].unsqueeze(0), Q_ab) - trace_val = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1) - - pred_cov[:, a, a] += variances[a] - trace_val + noises[a].item() - - outer_mean = torch.bmm(pred_mean.unsqueeze(-1), pred_mean.unsqueeze(-2)) - pred_cov = pred_cov - outer_mean - - pred_cov = (pred_cov + pred_cov.transpose(-2, -1)) / 2.0 - - m_dx = pred_mean - s_dx = pred_cov - c_xdx = cross_cov - - cov_xf = upper @ c_xdx - - m_x = m_x + m_dx - - s_x = s_x + s_dx + cov_xf + cov_xf.transpose(-2, -1) - - s_x = (s_x + s_x.transpose(-2, -1)) / 2.0 - s_x = s_x + 1e-8 * torch.eye(self.obs_dim, device=s_x.device).expand( - s_x.shape[0], -1, -1 - ) - return m_x, s_x - - def deterministic_forward( - self, action: TensorDictBase, observation: TensorDictBase - ) -> tuple[torch.Tensor, torch.Tensor]: - observation_mean = observation.get("mean") - action_mean = action.get("mean") - - x_flat = observation_mean.view(-1, self.obs_dim) - u_flat = action_mean.view(-1, self.action_dim) - - X_test = torch.cat([x_flat, u_flat], dim=-1) - - means, stds = [], [] - - with torch.no_grad(): - for gp in self.model_list.models: - posterior = gp.posterior(X_test) - means.append(posterior.mean.squeeze(-1)) - stds.append(torch.sqrt(posterior.variance).squeeze(-1)) - - delta_mean_flat = torch.stack(means, dim=-1) - delta_std_flat = torch.stack(stds, dim=-1) - - batch_shape = observation_mean.shape[:-1] - delta_mean = delta_mean_flat.view(*batch_shape, self.obs_dim) - delta_std = delta_std_flat.view(*batch_shape, self.obs_dim) - - return 
observation_mean + delta_mean, torch.diag_embed(delta_std**2) - - class ImaginedEnv(ModelBasedEnvBase): def __init__( self, diff --git a/torchrl/modules/models/gp.py b/torchrl/modules/models/gp.py new file mode 100644 index 00000000000..1df42e1aa72 --- /dev/null +++ b/torchrl/modules/models/gp.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import torch.nn as nn +from tensordict import TensorDictBase + +try: + from botorch.fit import fit_gpytorch_mll + from botorch.models import ModelListGP, SingleTaskGP + from gpytorch.kernels import RBFKernel, ScaleKernel + from gpytorch.mlls import SumMarginalLogLikelihood + from gpytorch.priors import GammaPrior + + _has_botorch = True +except ImportError: + _has_botorch = False + + +class GPWorldModel(nn.Module): + """Gaussian Process World Model. + + This module implements a Gaussian Process (GP) based world model using BoTorch and GPyTorch. + It models the transition dynamics of an environment by predicting the change in observation + given the current observation and action. + + Args: + obs_dim (int): The dimension of the observation space. + action_dim (int): The dimension of the action space. + """ + + def __init__(self, obs_dim: int, action_dim: int) -> None: + if not _has_botorch: + raise ImportError( + "botorch and gpytorch are required to use GPWorldModel. " + "Please install them to proceed." 
+ ) + super().__init__() + self.obs_dim = obs_dim + self.action_dim = action_dim + self.input_dim = obs_dim + action_dim + + self.model_list: ModelListGP | None = None + + self.register_buffer("X_train", torch.empty(0)) + self.register_buffer("lengthscales", torch.zeros(self.obs_dim, self.input_dim)) + self.register_buffer("variances", torch.zeros(self.obs_dim, 1)) + self.register_buffer("noises", torch.zeros(self.obs_dim)) + self._cached_inv_K: torch.Tensor | None = None + self._cached_beta: torch.Tensor | None = None + + @property + def device(self) -> torch.device: + return self.lengthscales.device + + def fit(self, dataset: TensorDictBase) -> None: + """Fits the Gaussian Process model to the provided dataset. + + The dataset must contain the ``"observation"``, ``"action"``, and + ``("next", "observation")`` keys. The model predicts the difference + between the next observation and the current observation. + + Args: + dataset (TensorDictBase): A dataset of collected transitions. + """ + obs = dataset["observation"] + action = dataset["action"] + next_obs = dataset[("next", "observation")] + + X_train = torch.cat([obs, action], dim=-1).detach().to(self.device) + y_train = (next_obs - obs).detach().to(self.device) + self.X_train = X_train + + models = [] + for i in range(self.obs_dim): + train_x = X_train + train_y = y_train[:, i].unsqueeze(-1) + + covar_module = ScaleKernel( + RBFKernel( + ard_num_dims=self.input_dim, lengthscale_prior=GammaPrior(1.1, 0.1) + ), + outputscale_prior=GammaPrior(1.5, 0.5), + ) + + gp = SingleTaskGP( + train_X=train_x, train_Y=train_y, covar_module=covar_module + ) + gp.likelihood.noise_covar.register_prior( + "noise_prior", GammaPrior(1.2, 0.05), "noise" + ) + + models.append(gp) + + self.model_list = ModelListGP(*models).to(self.device) + mll = SumMarginalLogLikelihood(self.model_list.likelihood, self.model_list) + + fit_gpytorch_mll(mll) + self._extract_parameters(y_train) + + def _extract_parameters(self, y_train: torch.Tensor) -> 
None: + lengthscales, variances, noises, inv_Ks, betas = [], [], [], [], [] + + for i, gp in enumerate(self.model_list.models): + gp.eval() + gp.likelihood.eval() + + ls = gp.covar_module.base_kernel.lengthscale.squeeze().detach() + var = gp.covar_module.outputscale.detach() + noise = gp.likelihood.noise.squeeze().detach() + + lengthscales.append(ls) + variances.append(var) + noises.append(noise) + + X_scaled = self.X_train / ls + dist = torch.cdist(X_scaled, X_scaled, p=2) ** 2 + K = var * torch.exp(-0.5 * dist) + + K_noisy = K + (noise + 1e-6) * torch.eye( + self.X_train.size(0), device=self.device + ) + + L = torch.linalg.cholesky(K_noisy) + eye = torch.eye(L.size(0), dtype=L.dtype, device=L.device) + inv_K = torch.cholesky_solve(eye, L) + + y = y_train[:, i].unsqueeze(-1) + beta = torch.cholesky_solve(y, L).squeeze(-1) + + inv_Ks.append(inv_K) + betas.append(beta) + + self.lengthscales = torch.stack(lengthscales) + self.variances = torch.stack(variances).unsqueeze(-1) + self.noises = torch.stack(noises) + + self._cached_inv_K = torch.stack(inv_Ks) + self._cached_beta = torch.stack(betas) + + def compute_factorizations(self) -> tuple[torch.Tensor, torch.Tensor]: + return self._cached_inv_K, self._cached_beta + + def _gather_gp_params(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.lengthscales, self.variances, self.noises + + def forward( + self, action: TensorDictBase, observation: TensorDictBase + ) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass for the GPWorldModel. + + Routes the request to either the deterministic or uncertain forward pass + depending on whether the observation input contains variance. + + Args: + action (TensorDictBase): The action tensordict. + observation (TensorDictBase): The observation tensordict. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing the mean and + variance tensors of the next observation. 
+ """ + observation_uncertain = False + x_var = observation.get("var", None) + if x_var is not None: + observation_uncertain = not torch.all( + torch.isclose(x_var, torch.zeros_like(x_var)) + ) + if observation_uncertain: + return self.uncertain_forward(action, observation) + else: + return self.deterministic_forward(action, observation) + + def freeze_and_detach(self) -> None: + """Freezes the model and detaches gradients.""" + + def uncertain_forward( + self, action: TensorDictBase, obs: TensorDictBase + ) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates the forward pass when the observation has uncertainty (non-zero variance). + + Propagates uncertainty through the Gaussian Process via exact moment matching. + + Args: + action (TensorDictBase): A tensordict containing ``"mean"``, ``"var"``, and + ``"cross_covariance"`` of the action. + obs (TensorDictBase): A tensordict containing the ``"mean"`` and ``"var"`` + of the current observation. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Next observation mean and variance matrices. 
+ """ + inv_K, beta = self.compute_factorizations() + lengthscales, variances, noises = self._gather_gp_params() + + m_x, s_x = obs.get("mean"), obs.get("var") + m_u, s_u, c_xu = ( + action.get("mean"), + action.get("var"), + action.get("cross_covariance"), + ) + + device, dtype = m_x.device, m_x.dtype + + joint_mean = torch.cat([m_x, m_u], dim=-1) + + s_ = s_x @ c_xu + upper = torch.cat([s_x, s_], dim=-1) + lower = torch.cat([s_.transpose(-1, -2), s_u], dim=-1) + + joint_var = torch.cat([upper, lower], dim=-2) + + X_train = self.X_train + num_train_pts = X_train.shape[0] + batch_size = joint_mean.shape[0] + + inp = X_train - joint_mean.unsqueeze(1) + + inv_L = torch.diag_embed(1.0 / lengthscales).to(dtype=dtype, device=device) + inv_N = inp.unsqueeze(1) @ inv_L.unsqueeze(0) + + B_mat = inv_L.unsqueeze(0) @ joint_var.unsqueeze(1) @ inv_L.unsqueeze(0) + B_mat = B_mat + torch.eye( + self.input_dim, dtype=m_x.dtype, device=m_x.device + ).view(1, 1, self.input_dim, self.input_dim) + + t = torch.linalg.solve(B_mat, inv_N.transpose(-2, -1)).transpose(-2, -1) + + scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2) + lb = scaled_exp * beta.unsqueeze(0) + + det_B = torch.linalg.det(B_mat) + c = variances.squeeze(1).unsqueeze(0) / torch.sqrt(det_B) + + pred_mean = torch.sum(lb, dim=-1) * c.squeeze(0) + + t_inv_L = t @ inv_L.unsqueeze(0) + + cross_cov_E_D = torch.matmul( + t_inv_L.transpose(-2, -1), lb.unsqueeze(-1) + ).squeeze(-1) * c.unsqueeze(-1) + cross_cov = cross_cov_E_D.transpose(-2, -1) + + pred_cov = torch.zeros( + batch_size, self.obs_dim, self.obs_dim, dtype=m_x.dtype, device=m_x.device + ) + + X_i = X_train.unsqueeze(1) + X_j = X_train.unsqueeze(0) + diff = X_i - X_j + joint_mean_flat = joint_mean.unsqueeze(1).unsqueeze(1) + + for a in range(self.obs_dim): + for b in range(self.obs_dim): + l2_a = lengthscales[a].to(device=device, dtype=dtype) ** 2 + l2_b = lengthscales[b].to(device=device, dtype=dtype) ** 2 + + inv_L_a = 1.0 / l2_a + inv_L_b = 1.0 / l2_b + 
inv_L_sum = inv_L_a + inv_L_b + Lambda_ab = 1.0 / inv_L_sum + + z_bar = Lambda_ab * (X_i * inv_L_a + X_j * inv_L_b) + z = z_bar.unsqueeze(0) - joint_mean_flat + + z_flat = z.view( + batch_size, num_train_pts * num_train_pts, self.input_dim + ) + + R_ab = joint_var @ torch.diag(inv_L_sum) + torch.eye( + self.input_dim, dtype=m_x.dtype, device=m_x.device + ).unsqueeze(0) + + inv_L_plus = 1.0 / (l2_a + l2_b) + exp1 = -0.5 * torch.sum(diff * inv_L_plus * diff, dim=-1) + + M_ab = joint_var + torch.diag(Lambda_ab).unsqueeze(0) + + solved_z_flat = torch.linalg.solve( + M_ab, z_flat.transpose(-2, -1) + ).transpose(-2, -1) + exp2 = (-0.5 * torch.sum(z_flat * solved_z_flat, dim=-1)).view( + batch_size, num_train_pts, num_train_pts + ) + + det_R_ab = torch.linalg.det(R_ab) + c_ab = variances[a] * variances[b] / torch.sqrt(det_R_ab) + + Q_ab = c_ab.view(-1, 1, 1) * torch.exp(exp1.unsqueeze(0) + exp2) + + Qb = torch.matmul(Q_ab, beta[b]) + pred_cov[:, a, b] = ( + torch.matmul(beta[a].unsqueeze(0), Qb.unsqueeze(-1)) + .squeeze(-1) + .squeeze(-1) + ) + + if a == b: + invK_Q = torch.matmul(inv_K[a].unsqueeze(0), Q_ab) + trace_val = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1) + + pred_cov[:, a, a] += variances[a] - trace_val + noises[a].item() + + outer_mean = torch.bmm(pred_mean.unsqueeze(-1), pred_mean.unsqueeze(-2)) + pred_cov = pred_cov - outer_mean + + pred_cov = (pred_cov + pred_cov.transpose(-2, -1)) / 2.0 + + m_dx = pred_mean + s_dx = pred_cov + c_xdx = cross_cov + + cov_xf = upper @ c_xdx + + m_x = m_x + m_dx + + s_x = s_x + s_dx + cov_xf + cov_xf.transpose(-2, -1) + + s_x = (s_x + s_x.transpose(-2, -1)) / 2.0 + s_x = s_x + 1e-8 * torch.eye(self.obs_dim, device=s_x.device).expand( + s_x.shape[0], -1, -1 + ) + return m_x, s_x + + def deterministic_forward( + self, action: TensorDictBase, observation: TensorDictBase + ) -> tuple[torch.Tensor, torch.Tensor]: + """Calculates the forward pass when the input observation is deterministic (no variance). 
+ + Args: + action (TensorDictBase): A tensordict containing the ``"mean"`` of the action. + observation (TensorDictBase): A tensordict containing the ``"mean"`` of the + current observation. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Next observation mean and variance matrices. + """ + observation_mean = observation.get("mean") + action_mean = action.get("mean") + + x_flat = observation_mean.view(-1, self.obs_dim) + u_flat = action_mean.view(-1, self.action_dim) + + X_test = torch.cat([x_flat, u_flat], dim=-1) + + means, stds = [], [] + + with torch.no_grad(): + for gp in self.model_list.models: + posterior = gp.posterior(X_test) + means.append(posterior.mean.squeeze(-1)) + stds.append(torch.sqrt(posterior.variance).squeeze(-1)) + + delta_mean_flat = torch.stack(means, dim=-1) + delta_std_flat = torch.stack(stds, dim=-1) + + batch_shape = observation_mean.shape[:-1] + delta_mean = delta_mean_flat.view(*batch_shape, self.obs_dim) + delta_std = delta_std_flat.view(*batch_shape, self.obs_dim) + + return observation_mean + delta_mean, torch.diag_embed(delta_std**2) From 67b2c7fce4330c3c60df8a876af52cc3d01a43d0 Mon Sep 17 00:00:00 2001 From: Pedro Rosa Date: Sat, 14 Mar 2026 14:26:00 +0100 Subject: [PATCH 4/8] move the rbfcontroller to core and remove unused function --- sota-implementations/pilco/pilco.py | 4 +- sota-implementations/pilco/utils.py | 182 ----------------------- torchrl/modules/models/__init__.py | 4 + torchrl/modules/models/rbf_controller.py | 172 +++++++++++++++++++++ 4 files changed, 178 insertions(+), 184 deletions(-) create mode 100644 torchrl/modules/models/rbf_controller.py diff --git a/sota-implementations/pilco/pilco.py b/sota-implementations/pilco/pilco.py index 8c4849134c4..b448ebf6062 100644 --- a/sota-implementations/pilco/pilco.py +++ b/sota-implementations/pilco/pilco.py @@ -8,11 +8,11 @@ from torchrl._utils import get_available_device from torchrl.envs import EnvBase from torchrl.envs.utils import RandomPolicy -from 
torchrl.modules.models import GPWorldModel +from torchrl.modules.models import GPWorldModel, RBFController from torchrl.objectives import ExponentialQuadraticCost from torchrl.record.loggers import generate_exp_name, get_logger, Logger -from utils import ImaginedEnv, make_env, RBFController +from utils import ImaginedEnv, make_env def pilco_loop( diff --git a/sota-implementations/pilco/utils.py b/sota-implementations/pilco/utils.py index 43713fe8565..6de2fc373d4 100644 --- a/sota-implementations/pilco/utils.py +++ b/sota-implementations/pilco/utils.py @@ -1,7 +1,6 @@ from collections.abc import Sequence import torch -import torch.nn as nn from tensordict import TensorDict, TensorDictBase from tensordict.nn import TensorDictModule @@ -28,57 +27,6 @@ def make_env( return env -def pendulum_cost( - obs: TensorDictBase, - weights: torch.Tensor | None = None, - target: torch.Tensor | None = None, -) -> torch.Tensor: - """ - obs["mean"]: [B, T, D] - obs["var"] : [B, T, D, D] - """ - m = obs.get("mean") - s = obs.get("var") - - B, T, D = m.shape - device = m.device - dtype = m.dtype - - if weights is None: - diag_vals = torch.tensor([1.0, 1.0, 1.0, 1.0], device=device, dtype=dtype) - weights = torch.diag(diag_vals) - - if target is None: - target = torch.zeros(D, device=device, dtype=dtype) - - if target.dim() == 1: - target = target.view(1, 1, D).expand(B, T, D) - - eye = torch.eye(D, device=device, dtype=dtype).view(1, 1, D, D) - diff = (m - target).unsqueeze(-1) # [B, T, D, 1] - - L_w, V_w = torch.linalg.eigh(weights) - L_w = torch.clamp(L_w, min=0.0) - U = V_w @ torch.diag_embed(torch.sqrt(L_w)) @ V_w.transpose(-2, -1) - - A_sym = eye + torch.matmul(U, torch.matmul(s, U)) - - jitter = 1e-5 - A_sym = A_sym + jitter * eye - - L = torch.linalg.cholesky(A_sym) - - log_det = 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1) - det_term = torch.exp(-0.5 * log_det) - - v = torch.matmul(U, diff) - tmp = torch.cholesky_solve(v, L) - quad = torch.matmul(v.transpose(-2, 
-1), tmp) - exp_term = (-0.5 * quad).squeeze(-1).squeeze(-1) - - return (1.0 - det_term * torch.exp(exp_term)).sum(dim=1) - - class ImaginedEnv(ModelBasedEnvBase): def __init__( self, @@ -164,133 +112,3 @@ def _reset( out.set("terminated", torch.zeros(B, 1, dtype=torch.bool, device=self.device)) return out - - -class RBFController(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - max_action: float | torch.Tensor, - n_basis: int = 10, - ) -> None: - super().__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.max_action = max_action - self.n_basis = n_basis - - self.centers = nn.Parameter(torch.randn(n_basis, input_dim) * 0.5) - self.weights = nn.Parameter(torch.randn(n_basis, output_dim) * 0.1) - self.lengthscales = nn.Parameter(torch.ones(input_dim)) - self.variance = 1.0 - - @staticmethod - def squash_sin( - m: torch.Tensor, s: torch.Tensor, max_action: float | torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - B, K = m.shape - device = m.device - dtype = m.dtype - - if not isinstance(max_action, torch.Tensor): - max_action = torch.tensor(max_action, dtype=dtype, device=device) - - max_action = max_action.view(-1) - if max_action.shape[0] == 1 and K > 1: - max_action = max_action.expand(K) - - diag_s = torch.diagonal(s, dim1=-2, dim2=-1) - - M = max_action * torch.exp(-diag_s / 2.0) * torch.sin(m) - - lq = -(diag_s.unsqueeze(-1) + diag_s.unsqueeze(-2)) / 2.0 - q = torch.exp(lq) - - m_diff = m.unsqueeze(-1) - m.unsqueeze(-2) - m_sum = m.unsqueeze(-1) + m.unsqueeze(-2) - - S = (torch.exp(lq + s) - q) * torch.cos(m_diff) - ( - torch.exp(lq - s) - q - ) * torch.cos(m_sum) - - outer_max = max_action.unsqueeze(1) * max_action.unsqueeze(0) - S = outer_max.unsqueeze(0) * S / 2.0 - - C = torch.diag_embed(max_action * torch.exp(-diag_s / 2.0) * torch.cos(m)) - - return M, S, C - - def forward( - self, m: torch.Tensor, S: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - B, D = 
m.shape - N = self.n_basis - device = m.device - - iL = torch.diag(1.0 / self.lengthscales) - iL_batch = iL.unsqueeze(0) - - inp = self.centers.unsqueeze(0) - m.unsqueeze(1) - - B_mat = iL_batch @ S @ iL_batch + torch.eye( - D, device=device, dtype=m.dtype - ).unsqueeze(0) - - iN = inp @ iL - - t = torch.linalg.solve(B_mat, iN.mT).mT - - exp_term = torch.exp(-0.5 * torch.sum(iN * t, dim=-1)) - detB = torch.linalg.det(B_mat) - c = self.variance / torch.sqrt(detB) - phi_mean = c.unsqueeze(-1) * exp_term - - M = phi_mean @ self.weights - - tiL = t @ iL - V = torch.bmm(tiL.mT, phi_mean.unsqueeze(-1) * self.weights) - - c_i = self.centers.unsqueeze(1) - c_j = self.centers.unsqueeze(0) - diff = c_i - c_j - c_bar = (c_i + c_j) / 2.0 - - inv_Lambda = 1.0 / (self.lengthscales**2) - exp1 = -0.25 * torch.sum((diff**2) * inv_Lambda, dim=-1) - - Lambda_half = torch.diag((self.lengthscales**2) / 2.0) - B_q = S + Lambda_half.unsqueeze(0) - - z = c_bar.unsqueeze(0) - m.unsqueeze(1).unsqueeze(1) - z_flat = z.view(B, N * N, D) - - solved_z_flat = torch.linalg.solve(B_q, z_flat.mT).mT - exp2 = -0.5 * torch.sum(z_flat * solved_z_flat, dim=-1).view(B, N, N) - - log_det_Lambda_half = torch.sum(torch.log((self.lengthscales**2) / 2.0)) - log_det_B_q = torch.logdet(B_q) - c_q = torch.exp(0.5 * (log_det_Lambda_half - log_det_B_q)) - - Q = (self.variance**2 * c_q.view(B, 1, 1)) * torch.exp( - exp1.unsqueeze(0) + exp2 - ) - - W_batch = self.weights.unsqueeze(0).expand(B, N, -1) - S_action = torch.bmm(W_batch.mT, torch.bmm(Q, W_batch)) - - M_out = torch.bmm(M.unsqueeze(-1), M.unsqueeze(1)) - S_action = S_action - M_out - - S_action = (S_action + S_action.mT) / 2.0 - S_action = ( - S_action - + torch.eye(self.output_dim, device=device, dtype=m.dtype).unsqueeze(0) - * 1e-6 - ) - - if self.max_action is not None: - M, S_action, C = self.squash_sin(M, S_action, self.max_action) - V = torch.bmm(V, C) - - return M, S_action, V diff --git a/torchrl/modules/models/__init__.py 
b/torchrl/modules/models/__init__.py index 98d34666cf8..a5b0bb09fbb 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -16,6 +16,7 @@ NoisyLinear, reset_noise, ) +from .gp import GPWorldModel from .llm import GPT2RewardModel from .model_based import ( DreamerActor, @@ -46,6 +47,7 @@ QMixer, VDNMixer, ) +from .rbf_controller import RBFController from .utils import Squeeze2dLayer, SqueezeLayer __all__ = [ @@ -53,6 +55,7 @@ "BatchRenorm1d", "DecisionTransformer", "GPT2RewardModel", + "GPWorldModel", "ConsistentDropout", "ConsistentDropoutModule", "NoisyLazyLinear", @@ -81,6 +84,7 @@ "MultiAgentNetBase", "QMixer", "VDNMixer", + "RBFController", "Squeeze2dLayer", "SqueezeLayer", ] diff --git a/torchrl/modules/models/rbf_controller.py b/torchrl/modules/models/rbf_controller.py new file mode 100644 index 00000000000..32b37dc8184 --- /dev/null +++ b/torchrl/modules/models/rbf_controller.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from torch import nn + + +class RBFController(nn.Module): + """A Radial Basis Function (RBF) controller. + + Args: + input_dim (int): The dimensionality of the input space. + output_dim (int): The dimensionality of the output space. + max_action (float or torch.Tensor): The maximum action magnitude used for the squashing function. + n_basis (int, optional): The number of basis functions to use. Defaults to 10. 
+ """ + + def __init__( + self, + input_dim: int, + output_dim: int, + max_action: float | torch.Tensor, + n_basis: int = 10, + ) -> None: + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.max_action = max_action + self.n_basis = n_basis + + self.centers = nn.Parameter(torch.randn(n_basis, input_dim) * 0.5) + self.weights = nn.Parameter(torch.randn(n_basis, output_dim) * 0.1) + self.lengthscales = nn.Parameter(torch.ones(input_dim)) + self.variance = 1.0 + + @staticmethod + def squash_sin( + m: torch.Tensor, s: torch.Tensor, max_action: float | torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Squashes the output using a sine function to keep actions within the bounded range. + + Args: + m (torch.Tensor): The mean of the distribution. + s (torch.Tensor): The covariance matrix of the distribution. + max_action (float or torch.Tensor): The maximum magnitude of the action bounds. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - M (torch.Tensor): The squashed mean. + - S (torch.Tensor): The squashed covariance. + - C (torch.Tensor): The cross-covariance between input and output. 
+ """ + B, K = m.shape + device = m.device + dtype = m.dtype + + if not isinstance(max_action, torch.Tensor): + max_action = torch.tensor(max_action, dtype=dtype, device=device) + + max_action = max_action.view(-1) + if max_action.shape[0] == 1 and K > 1: + max_action = max_action.expand(K) + + diag_s = torch.diagonal(s, dim1=-2, dim2=-1) + + M = max_action * torch.exp(-diag_s / 2.0) * torch.sin(m) + + lq = -(diag_s.unsqueeze(-1) + diag_s.unsqueeze(-2)) / 2.0 + q = torch.exp(lq) + + m_diff = m.unsqueeze(-1) - m.unsqueeze(-2) + m_sum = m.unsqueeze(-1) + m.unsqueeze(-2) + + S = (torch.exp(lq + s) - q) * torch.cos(m_diff) - ( + torch.exp(lq - s) - q + ) * torch.cos(m_sum) + + outer_max = max_action.unsqueeze(1) * max_action.unsqueeze(0) + S = outer_max.unsqueeze(0) * S / 2.0 + + C = torch.diag_embed(max_action * torch.exp(-diag_s / 2.0) * torch.cos(m)) + + return M, S, C + + def forward( + self, m: torch.Tensor, S: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Computes the forward pass of the RBF Controller. + + Args: + m (torch.Tensor): The mean of the input tensor. + S (torch.Tensor): The covariance matrix of the input tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - M (torch.Tensor): The mean of the action distribution. + - S_action (torch.Tensor): The covariance of the action distribution. + - V (torch.Tensor): The input-output cross-covariance. 
+ """ + B, D = m.shape + N = self.n_basis + device = m.device + + iL = torch.diag(1.0 / self.lengthscales) + iL_batch = iL.unsqueeze(0) + + inp = self.centers.unsqueeze(0) - m.unsqueeze(1) + + B_mat = iL_batch @ S @ iL_batch + torch.eye( + D, device=device, dtype=m.dtype + ).unsqueeze(0) + + iN = inp @ iL + + t = torch.linalg.solve(B_mat, iN.mT).mT + + exp_term = torch.exp(-0.5 * torch.sum(iN * t, dim=-1)) + detB = torch.linalg.det(B_mat) + c = self.variance / torch.sqrt(detB) + phi_mean = c.unsqueeze(-1) * exp_term + + M = phi_mean @ self.weights + + tiL = t @ iL + V = torch.bmm(tiL.mT, phi_mean.unsqueeze(-1) * self.weights) + + c_i = self.centers.unsqueeze(1) + c_j = self.centers.unsqueeze(0) + diff = c_i - c_j + c_bar = (c_i + c_j) / 2.0 + + inv_Lambda = 1.0 / (self.lengthscales**2) + exp1 = -0.25 * torch.sum((diff**2) * inv_Lambda, dim=-1) + + Lambda_half = torch.diag((self.lengthscales**2) / 2.0) + B_q = S + Lambda_half.unsqueeze(0) + + z = c_bar.unsqueeze(0) - m.unsqueeze(1).unsqueeze(1) + z_flat = z.view(B, N * N, D) + + solved_z_flat = torch.linalg.solve(B_q, z_flat.mT).mT + exp2 = -0.5 * torch.sum(z_flat * solved_z_flat, dim=-1).view(B, N, N) + + log_det_Lambda_half = torch.sum(torch.log((self.lengthscales**2) / 2.0)) + log_det_B_q = torch.logdet(B_q) + c_q = torch.exp(0.5 * (log_det_Lambda_half - log_det_B_q)) + + Q = (self.variance**2 * c_q.view(B, 1, 1)) * torch.exp( + exp1.unsqueeze(0) + exp2 + ) + + W_batch = self.weights.unsqueeze(0).expand(B, N, -1) + S_action = torch.bmm(W_batch.mT, torch.bmm(Q, W_batch)) + + M_out = torch.bmm(M.unsqueeze(-1), M.unsqueeze(1)) + S_action = S_action - M_out + + S_action = (S_action + S_action.mT) / 2.0 + S_action = ( + S_action + + torch.eye(self.output_dim, device=device, dtype=m.dtype).unsqueeze(0) + * 1e-6 + ) + + if self.max_action is not None: + M, S_action, C = self.squash_sin(M, S_action, self.max_action) + V = torch.bmm(V, C) + + return M, S_action, V From 438710aec5f7833646d56c668e7c2a60f515c8ae Mon Sep 17 
00:00:00 2001 From: Vincent Moens Date: Sat, 14 Mar 2026 16:09:47 +0000 Subject: [PATCH 5/8] [Feature] Port PILCO components to core with tests, docs, and CI Port RBFController, ImaginedEnv, and MeanActionSelector from sota-implementations/pilco/utils.py to torchrl core. Add unit tests, documentation entries, and a botorch CI job in test-linux-libs. - torchrl/modules/models/rbf_controller.py: RBF controller for moment-matching policy search with full docstrings - torchrl/envs/model_based/imagined.py: general-purpose imagination env for model-based policy search - torchrl/envs/transforms/mean_action_selector.py: transform bridging Gaussian belief-space policies with standard environments - Improve GPWorldModel: slogdet for numerical stability, remove .item() - Register GPWorldModel and RBFController in module exports - Add all new components to docs - Add 39 unit tests in test/test_objectives.py - Add botorch CI job to test-linux-libs workflow - Update sota-implementations/pilco to import from core Made-with: Cursor --- .../scripts_botorch/environment.yml | 23 + .../linux_libs/scripts_botorch/install.sh | 53 ++ .../scripts_botorch/post_process.sh | 6 + .../linux_libs/scripts_botorch/run_test.sh | 33 + .../linux_libs/scripts_botorch/setup_env.sh | 44 ++ .github/workflows/test-linux-libs.yml | 38 ++ docs/source/reference/envs_api.rst | 1 + docs/source/reference/envs_transforms.rst | 1 + docs/source/reference/modules_models.rst | 12 + docs/source/reference/objectives_other.rst | 1 + sota-implementations/pilco/pilco.py | 58 +- sota-implementations/pilco/utils.py | 286 +------- test/test_objectives.py | 627 +++++++++++++++++- torchrl/envs/__init__.py | 5 +- torchrl/envs/model_based/__init__.py | 3 +- torchrl/envs/model_based/imagined.py | 164 +++++ torchrl/envs/transforms/__init__.py | 2 + .../envs/transforms/mean_action_selector.py | 103 +++ torchrl/modules/__init__.py | 4 + torchrl/modules/models/__init__.py | 32 +- torchrl/modules/models/gp.py | 45 +- 
torchrl/modules/models/rbf_controller.py | 223 +++++++ 22 files changed, 1406 insertions(+), 358 deletions(-) create mode 100644 .github/unittest/linux_libs/scripts_botorch/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_botorch/install.sh create mode 100755 .github/unittest/linux_libs/scripts_botorch/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_botorch/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_botorch/setup_env.sh create mode 100644 torchrl/envs/model_based/imagined.py create mode 100644 torchrl/envs/transforms/mean_action_selector.py create mode 100644 torchrl/modules/models/rbf_controller.py diff --git a/.github/unittest/linux_libs/scripts_botorch/environment.yml b/.github/unittest/linux_libs/scripts_botorch/environment.yml new file mode 100644 index 00000000000..c6a5013405e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_botorch/environment.yml @@ -0,0 +1,23 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-json-report + - pytest-error-for-skips + - expecttest + - pybind11[global] + - pyyaml + - scipy + - botorch + - gpytorch + - psutil diff --git a/.github/unittest/linux_libs/scripts_botorch/install.sh b/.github/unittest/linux_libs/scripts_botorch/install.sh new file mode 100755 index 00000000000..395e15ea99e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_botorch/install.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION + +set -euxo pipefail + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as 
determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with cu128" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu128 + fi +else + printf "Failed to install pytorch" + exit 1 +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git --progress-bar off + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python -m pip install -e . 
--no-build-isolation + +# smoke test +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_botorch/post_process.sh b/.github/unittest/linux_libs/scripts_botorch/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_botorch/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_botorch/run_test.sh b/.github/unittest/linux_libs/scripts_botorch/run_test.sh new file mode 100755 index 00000000000..3d732357ef6 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_botorch/run_test.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False + +python -m torch.utils.collect_env +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +export MKL_THREADING_LAYER=GNU + +# smoke test +python -c "import botorch; print('botorch', botorch.__version__)" +python -c "import gpytorch; print('gpytorch', gpytorch.__version__)" + +# JSON report for flaky test tracking +json_report_dir="${RUNNER_ARTIFACT_DIR:-${root_dir}}" +json_report_args="--json-report --json-report-file=${json_report_dir}/test-results-botorch.json --json-report-indent=2" + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_objectives.py ${json_report_args} --instafail -v --durations 200 --capture no -k TestGPWorldModel --error-for-skips +coverage combine -q +coverage xml -i + +# Upload test results with metadata for flaky tracking +python .github/unittest/helpers/upload_test_results.py || echo "Warning: Failed to process test results for flaky tracking" diff --git a/.github/unittest/linux_libs/scripts_botorch/setup_env.sh 
b/.github/unittest/linux_libs/scripts_botorch/setup_env.sh new file mode 100755 index 00000000000..d7dbd1bb7e6 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_botorch/setup_env.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +apt-get update && apt-get upgrade -y && apt-get install -y git cmake +git config --global --add safe.directory '*' +apt-get install -y wget gcc g++ + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. 
Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index b1840a11fef..0f413792711 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -93,6 +93,44 @@ jobs: bash .github/unittest/linux_libs/scripts_brax/run_all.sh + unittests-botorch: + strategy: + matrix: + python_version: ["3.10"] + cuda_arch_version: ["12.8"] + if: ${{ github.event_name == 'push' || github.event_name == 'workflow_call' || github.event_name == 'workflow_dispatch' || contains(github.event.pull_request.labels.*.name, 'Modules') || contains(github.event.pull_request.labels.*.name, 'Objectives') }} + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "12.8" + docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.10" + export CU_VERSION="12.8" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export TD_GET_DEFAULTS_TO_NONE=1 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_botorch/setup_env.sh + bash .github/unittest/linux_libs/scripts_botorch/install.sh + bash .github/unittest/linux_libs/scripts_botorch/run_test.sh + bash .github/unittest/linux_libs/scripts_botorch/post_process.sh + # unittests-d4rl: # strategy: # matrix: diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst index 
bf6ba8b9a96..66f91682377 100644 --- a/docs/source/reference/envs_api.rst +++ b/docs/source/reference/envs_api.rst @@ -191,6 +191,7 @@ Domain-specific ModelBasedEnvBase model_based.dreamer.DreamerEnv model_based.dreamer.DreamerDecoder + model_based.imagined.ImaginedEnv Helpers ------- diff --git a/docs/source/reference/envs_transforms.rst b/docs/source/reference/envs_transforms.rst index e3f8ab55fab..b345c493131 100644 --- a/docs/source/reference/envs_transforms.rst +++ b/docs/source/reference/envs_transforms.rst @@ -273,6 +273,7 @@ Available Transforms Hash InitTracker LineariseRewards + MeanActionSelector ModuleTransform MultiAction NoopResetEnv diff --git a/docs/source/reference/modules_models.rst b/docs/source/reference/modules_models.rst index be3e74ef0c7..68891fe4e67 100644 --- a/docs/source/reference/modules_models.rst +++ b/docs/source/reference/modules_models.rst @@ -16,3 +16,15 @@ Modules for model-based reinforcement learning, including world models and dynam RSSMPosterior RSSMPrior RSSMRollout + +PILCO +----- + +Components for moment-matching model-based policy search (PILCO). + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + GPWorldModel + RBFController diff --git a/docs/source/reference/objectives_other.rst b/docs/source/reference/objectives_other.rst index 018268ed7f6..b97d3efea50 100644 --- a/docs/source/reference/objectives_other.rst +++ b/docs/source/reference/objectives_other.rst @@ -15,3 +15,4 @@ Additional loss modules for specialized algorithms. 
DreamerActorLoss DreamerModelLoss DreamerValueLoss + ExponentialQuadraticCost diff --git a/sota-implementations/pilco/pilco.py b/sota-implementations/pilco/pilco.py index 8c4849134c4..024359e0fd3 100644 --- a/sota-implementations/pilco/pilco.py +++ b/sota-implementations/pilco/pilco.py @@ -3,16 +3,18 @@ import torch from omegaconf import DictConfig -from tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict from tensordict.nn import TensorDictModule from torchrl._utils import get_available_device -from torchrl.envs import EnvBase +from torchrl.envs import EnvBase, TransformedEnv +from torchrl.envs.model_based import ImaginedEnv +from torchrl.envs.transforms import MeanActionSelector from torchrl.envs.utils import RandomPolicy -from torchrl.modules.models import GPWorldModel +from torchrl.modules.models import GPWorldModel, RBFController from torchrl.objectives import ExponentialQuadraticCost from torchrl.record.loggers import generate_exp_name, get_logger, Logger -from utils import ImaginedEnv, make_env, RBFController +from utils import make_env def pilco_loop( @@ -61,6 +63,8 @@ def pilco_loop( } ) + eval_env = TransformedEnv(env, MeanActionSelector()) + cost_module = ExponentialQuadraticCost(reduction="none").to(env.device) for epoch in range(cfg.pilco.epochs): base_world_model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim).to( @@ -102,50 +106,14 @@ def pilco_loop( "train/trajectory_cost", loss.item(), step=logger_step ) - def policy_for_env(td: TensorDictBase) -> TensorDictBase: - obs = td["observation"] - device, dtype = obs.device, obs.dtype - - is_unbatched = obs.ndim == 1 - if is_unbatched: - obs = obs.unsqueeze(0) - - batch_shape = obs.shape[:-1] - D = obs.shape[-1] - - policy_in = TensorDict( - { - "observation": TensorDict( - { - "mean": obs, - "var": torch.zeros( - (*batch_shape, D, D), device=device, dtype=dtype - ), - }, - batch_size=batch_shape, - ) - }, - batch_size=batch_shape, - device=device, - ) - - policy_out = 
policy_module(policy_in) - action_mean = policy_out["action", "mean"] - - if is_unbatched: - action_mean = action_mean.squeeze(0) - - td["action"] = action_mean - return td - - test_rollout = env.rollout( + test_rollout = eval_env.rollout( max_steps=100, - policy=policy_for_env, - break_when_any_done=True, # TODO change the max_steps back maybe? + policy=policy_module, + break_when_any_done=True, ) - reward = test_rollout["episode_reward"][-1].item() - steps = test_rollout["step_count"].max().item() + reward = test_rollout["episode_reward"][-1].tolist() + steps = test_rollout["step_count"].max().tolist() if logger: logger.log_scalar("eval/reward", reward, step=logger_step) diff --git a/sota-implementations/pilco/utils.py b/sota-implementations/pilco/utils.py index 43713fe8565..034cc90ca24 100644 --- a/sota-implementations/pilco/utils.py +++ b/sota-implementations/pilco/utils.py @@ -1,296 +1,14 @@ -from collections.abc import Sequence - import torch -import torch.nn as nn - -from tensordict import TensorDict, TensorDictBase -from tensordict.nn import TensorDictModule - -from torchrl.envs import ( - EnvBase, - GymEnv, - ModelBasedEnvBase, - RewardSum, - StepCounter, - TransformedEnv, -) +from torchrl.envs import GymEnv, RewardSum, StepCounter, TransformedEnv def make_env( env_name: str, device: str | torch.device, from_pixels: bool = False ) -> TransformedEnv: - """Creates the transformed environment.""" + """Creates the transformed environment for PILCO experiments.""" env = TransformedEnv( GymEnv(env_name, pixels_only=False, from_pixels=from_pixels, device=device) ) env.append_transform(RewardSum()) env.append_transform(StepCounter()) return env - - -def pendulum_cost( - obs: TensorDictBase, - weights: torch.Tensor | None = None, - target: torch.Tensor | None = None, -) -> torch.Tensor: - """ - obs["mean"]: [B, T, D] - obs["var"] : [B, T, D, D] - """ - m = obs.get("mean") - s = obs.get("var") - - B, T, D = m.shape - device = m.device - dtype = m.dtype - - if 
weights is None: - diag_vals = torch.tensor([1.0, 1.0, 1.0, 1.0], device=device, dtype=dtype) - weights = torch.diag(diag_vals) - - if target is None: - target = torch.zeros(D, device=device, dtype=dtype) - - if target.dim() == 1: - target = target.view(1, 1, D).expand(B, T, D) - - eye = torch.eye(D, device=device, dtype=dtype).view(1, 1, D, D) - diff = (m - target).unsqueeze(-1) # [B, T, D, 1] - - L_w, V_w = torch.linalg.eigh(weights) - L_w = torch.clamp(L_w, min=0.0) - U = V_w @ torch.diag_embed(torch.sqrt(L_w)) @ V_w.transpose(-2, -1) - - A_sym = eye + torch.matmul(U, torch.matmul(s, U)) - - jitter = 1e-5 - A_sym = A_sym + jitter * eye - - L = torch.linalg.cholesky(A_sym) - - log_det = 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1) - det_term = torch.exp(-0.5 * log_det) - - v = torch.matmul(U, diff) - tmp = torch.cholesky_solve(v, L) - quad = torch.matmul(v.transpose(-2, -1), tmp) - exp_term = (-0.5 * quad).squeeze(-1).squeeze(-1) - - return (1.0 - det_term * torch.exp(exp_term)).sum(dim=1) - - -class ImaginedEnv(ModelBasedEnvBase): - def __init__( - self, - world_model_module: TensorDictModule, - base_env: EnvBase, - batch_size: int | torch.Size | Sequence[int] | None = None, - **kwargs - ) -> None: - if batch_size is not None: - self.batch_size = ( - torch.Size(batch_size) - if not isinstance(batch_size, torch.Size) - else batch_size - ) - elif len(base_env.batch_size) == 0: - self.batch_size = torch.Size([1]) - else: - self.batch_size = base_env.batch_size - - super().__init__( - world_model_module, - device=base_env.device, - batch_size=self.batch_size, - **kwargs - ) - - self.observation_spec = base_env.observation_spec.expand( - self.batch_size - ).clone() - self.action_spec = base_env.action_spec.expand(self.batch_size).clone() - self.reward_spec = base_env.reward_spec.expand(self.batch_size).clone() - self.done_spec = base_env.done_spec.expand(self.batch_size).clone() - - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - 
tensordict = self.world_model(tensordict) - - reward = torch.zeros(*tensordict.shape, 1, device=self.device) - done = torch.zeros(*tensordict.shape, 1, dtype=torch.bool, device=self.device) - out = TensorDict( - { - "observation": tensordict.get("next_observation"), - "reward": reward, - "done": done, - "terminated": done.clone(), - }, - tensordict.shape, - ) - return out - - def _reset( - self, tensordict: TensorDictBase | None = None, **kwargs - ) -> TensorDictBase: - if tensordict is None: - tensordict = TensorDict({}, batch_size=self.batch_size, device=self.device) - - if ( - tensordict.get(("observation", "var"), None) is not None - and tensordict.get(("observation", "mean"), None) is not None - ): - return tensordict.copy() - - obs = tensordict.get("observation", None) - if obs is None: - obs = self.observation_spec.rand(shape=self.batch_size).get("observation") - if obs.ndim == 1: - obs = obs.expand(self.batch_size, -1) - - obs = obs.to(self.device) - B, D = obs.shape - - out = TensorDict( - { - ("observation", "mean"): obs, - ("observation", "var"): torch.zeros( - B, D, D, dtype=obs.dtype, device=self.device - ), - }, - batch_size=self.batch_size, - device=self.device, - ) - - out.set("done", torch.zeros(B, 1, dtype=torch.bool, device=self.device)) - out.set("terminated", torch.zeros(B, 1, dtype=torch.bool, device=self.device)) - - return out - - -class RBFController(nn.Module): - def __init__( - self, - input_dim: int, - output_dim: int, - max_action: float | torch.Tensor, - n_basis: int = 10, - ) -> None: - super().__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.max_action = max_action - self.n_basis = n_basis - - self.centers = nn.Parameter(torch.randn(n_basis, input_dim) * 0.5) - self.weights = nn.Parameter(torch.randn(n_basis, output_dim) * 0.1) - self.lengthscales = nn.Parameter(torch.ones(input_dim)) - self.variance = 1.0 - - @staticmethod - def squash_sin( - m: torch.Tensor, s: torch.Tensor, max_action: float | 
torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - B, K = m.shape - device = m.device - dtype = m.dtype - - if not isinstance(max_action, torch.Tensor): - max_action = torch.tensor(max_action, dtype=dtype, device=device) - - max_action = max_action.view(-1) - if max_action.shape[0] == 1 and K > 1: - max_action = max_action.expand(K) - - diag_s = torch.diagonal(s, dim1=-2, dim2=-1) - - M = max_action * torch.exp(-diag_s / 2.0) * torch.sin(m) - - lq = -(diag_s.unsqueeze(-1) + diag_s.unsqueeze(-2)) / 2.0 - q = torch.exp(lq) - - m_diff = m.unsqueeze(-1) - m.unsqueeze(-2) - m_sum = m.unsqueeze(-1) + m.unsqueeze(-2) - - S = (torch.exp(lq + s) - q) * torch.cos(m_diff) - ( - torch.exp(lq - s) - q - ) * torch.cos(m_sum) - - outer_max = max_action.unsqueeze(1) * max_action.unsqueeze(0) - S = outer_max.unsqueeze(0) * S / 2.0 - - C = torch.diag_embed(max_action * torch.exp(-diag_s / 2.0) * torch.cos(m)) - - return M, S, C - - def forward( - self, m: torch.Tensor, S: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - B, D = m.shape - N = self.n_basis - device = m.device - - iL = torch.diag(1.0 / self.lengthscales) - iL_batch = iL.unsqueeze(0) - - inp = self.centers.unsqueeze(0) - m.unsqueeze(1) - - B_mat = iL_batch @ S @ iL_batch + torch.eye( - D, device=device, dtype=m.dtype - ).unsqueeze(0) - - iN = inp @ iL - - t = torch.linalg.solve(B_mat, iN.mT).mT - - exp_term = torch.exp(-0.5 * torch.sum(iN * t, dim=-1)) - detB = torch.linalg.det(B_mat) - c = self.variance / torch.sqrt(detB) - phi_mean = c.unsqueeze(-1) * exp_term - - M = phi_mean @ self.weights - - tiL = t @ iL - V = torch.bmm(tiL.mT, phi_mean.unsqueeze(-1) * self.weights) - - c_i = self.centers.unsqueeze(1) - c_j = self.centers.unsqueeze(0) - diff = c_i - c_j - c_bar = (c_i + c_j) / 2.0 - - inv_Lambda = 1.0 / (self.lengthscales**2) - exp1 = -0.25 * torch.sum((diff**2) * inv_Lambda, dim=-1) - - Lambda_half = torch.diag((self.lengthscales**2) / 2.0) - B_q = S + 
Lambda_half.unsqueeze(0) - - z = c_bar.unsqueeze(0) - m.unsqueeze(1).unsqueeze(1) - z_flat = z.view(B, N * N, D) - - solved_z_flat = torch.linalg.solve(B_q, z_flat.mT).mT - exp2 = -0.5 * torch.sum(z_flat * solved_z_flat, dim=-1).view(B, N, N) - - log_det_Lambda_half = torch.sum(torch.log((self.lengthscales**2) / 2.0)) - log_det_B_q = torch.logdet(B_q) - c_q = torch.exp(0.5 * (log_det_Lambda_half - log_det_B_q)) - - Q = (self.variance**2 * c_q.view(B, 1, 1)) * torch.exp( - exp1.unsqueeze(0) + exp2 - ) - - W_batch = self.weights.unsqueeze(0).expand(B, N, -1) - S_action = torch.bmm(W_batch.mT, torch.bmm(Q, W_batch)) - - M_out = torch.bmm(M.unsqueeze(-1), M.unsqueeze(1)) - S_action = S_action - M_out - - S_action = (S_action + S_action.mT) / 2.0 - S_action = ( - S_action - + torch.eye(self.output_dim, device=device, dtype=m.dtype).unsqueeze(0) - * 1e-6 - ) - - if self.max_action is not None: - M, S_action, C = self.squash_sin(M, S_action, self.max_action) - V = torch.bmm(V, C) - - return M, S_action, V diff --git a/test/test_objectives.py b/test/test_objectives.py index 39ad38aa2f6..275fc31faf3 100644 --- a/test/test_objectives.py +++ b/test/test_objectives.py @@ -58,7 +58,8 @@ from torchrl.envs import EnvBase, GymEnv, InitTracker, SerialEnv from torchrl.envs.libs.gym import _has_gym from torchrl.envs.model_based.dreamer import DreamerEnv -from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv +from torchrl.envs.model_based.imagined import ImaginedEnv +from torchrl.envs.transforms import MeanActionSelector, TensorDictPrimer, TransformedEnv from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type from torchrl.modules import ( DistributionalQValueActor, @@ -82,6 +83,7 @@ RSSMRollout, ) from torchrl.modules.models.models import MLP +from torchrl.modules.models.rbf_controller import RBFController from torchrl.modules.tensordict_module.actors import ( Actor, ActorCriticOperator, @@ -105,6 +107,7 @@ DreamerModelLoss, 
DreamerValueLoss, DTLoss, + ExponentialQuadraticCost, GAILLoss, IQLLoss, KLPENPPOLoss, @@ -175,6 +178,7 @@ FUNCTORCH_ERR = str(err) _has_transformers = bool(importlib.util.find_spec("transformers")) +_has_botorch = bool(importlib.util.find_spec("botorch")) TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) IS_WINDOWS = sys.platform == "win32" @@ -18295,6 +18299,627 @@ def test_make_value_estimator_with_gae_instance(self, device): assert loss_fn.value_type is GAE +class TestRBFController: + @pytest.mark.parametrize("input_dim", [2, 4]) + @pytest.mark.parametrize("output_dim", [1, 3]) + @pytest.mark.parametrize("n_basis", [5, 10]) + def test_forward_shapes(self, input_dim, output_dim, n_basis): + max_action = torch.ones(output_dim) + controller = RBFController( + input_dim=input_dim, + output_dim=output_dim, + max_action=max_action, + n_basis=n_basis, + ).double() + + batch_size = 3 + mean = torch.randn(batch_size, input_dim, dtype=torch.float64) + cov = ( + torch.eye(input_dim, dtype=torch.float64) + .unsqueeze(0) + .expand(batch_size, -1, -1) + * 0.1 + ) + + action_mean, action_cov, cross_cov = controller(mean, cov) + + assert action_mean.shape == (batch_size, output_dim) + assert action_cov.shape == (batch_size, output_dim, output_dim) + assert cross_cov.shape == (batch_size, input_dim, output_dim) + + def test_action_covariance_is_symmetric(self): + controller = RBFController( + input_dim=4, output_dim=2, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + cov = torch.eye(4, dtype=torch.float64).unsqueeze(0).expand(2, -1, -1) * 0.1 + + _, action_cov, _ = controller(mean, cov) + + torch.testing.assert_close( + action_cov, action_cov.transpose(-2, -1), atol=1e-6, rtol=1e-5 + ) + + def test_action_covariance_is_positive_semidefinite(self): + controller = RBFController( + input_dim=4, output_dim=2, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + cov = 
torch.eye(4, dtype=torch.float64).unsqueeze(0).expand(2, -1, -1) * 0.1 + + _, action_cov, _ = controller(mean, cov) + + eigenvalues = torch.linalg.eigvalsh(action_cov) + assert ( + eigenvalues >= -1e-6 + ).all(), f"Negative eigenvalues found: {eigenvalues}" + + @pytest.mark.parametrize("max_action", [0.5, 1.0, 2.0]) + def test_squash_sin_bounds(self, max_action): + mean = torch.randn(10, 3, dtype=torch.float64) + cov = torch.eye(3, dtype=torch.float64).unsqueeze(0).expand(10, -1, -1) * 0.01 + + squashed_mean, squashed_cov, cross_cov = RBFController.squash_sin( + mean, cov, max_action + ) + + assert (squashed_mean.abs() <= max_action + 1e-6).all() + assert squashed_cov.shape == (10, 3, 3) + assert cross_cov.shape == (10, 3, 3) + + def test_deterministic_with_zero_variance(self): + controller = RBFController( + input_dim=4, output_dim=1, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + zero_cov = torch.zeros(2, 4, 4, dtype=torch.float64) + + action_mean1, _, _ = controller(mean, zero_cov) + action_mean2, _, _ = controller(mean, zero_cov) + + torch.testing.assert_close(action_mean1, action_mean2) + + def test_gradients_flow(self): + controller = RBFController( + input_dim=4, output_dim=1, max_action=1.0, n_basis=5 + ).double() + + mean = torch.randn(2, 4, dtype=torch.float64) + cov = torch.eye(4, dtype=torch.float64).unsqueeze(0).expand(2, -1, -1) * 0.1 + + action_mean, action_cov, cross_cov = controller(mean, cov) + loss = action_mean.sum() + action_cov.sum() + loss.backward() + + for name, param in controller.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + + def test_as_tensordict_module(self): + controller = RBFController( + input_dim=4, output_dim=1, max_action=1.0, n_basis=5 + ).double() + + module = TensorDictModule( + module=controller, + in_keys=[("observation", "mean"), ("observation", "var")], + out_keys=[ + ("action", "mean"), + ("action", "var"), + ("action", "cross_covariance"), + 
], + ) + + td = TensorDict( + { + ("observation", "mean"): torch.randn(2, 4, dtype=torch.float64), + ("observation", "var"): torch.eye(4, dtype=torch.float64) + .unsqueeze(0) + .expand(2, -1, -1) + * 0.1, + }, + batch_size=[2], + ) + + out = module(td) + assert ("action", "mean") in out.keys(True) + assert ("action", "var") in out.keys(True) + assert ("action", "cross_covariance") in out.keys(True) + + +class TestExponentialQuadraticCost: + def test_forward_shapes_default(self): + cost = ExponentialQuadraticCost(reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.randn(2, 5, 4), + ("observation", "var"): torch.eye(4) + .unsqueeze(0) + .unsqueeze(0) + .expand(2, 5, -1, -1) + * 0.1, + }, + batch_size=[2, 5], + ) + + out = cost(td) + loss = out["loss_cost"] + assert loss.shape == (2, 5) + + def test_cost_at_target_is_low(self): + target = torch.zeros(4) + cost = ExponentialQuadraticCost(target=target, reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.zeros(1, 4), + ("observation", "var"): torch.eye(4).unsqueeze(0) * 1e-6, + }, + batch_size=[1], + ) + + out = cost(td) + assert out["loss_cost"].item() < 0.01 + + def test_cost_far_from_target_is_high(self): + target = torch.zeros(4) + cost = ExponentialQuadraticCost(target=target, reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.ones(1, 4) * 10.0, + ("observation", "var"): torch.eye(4).unsqueeze(0) * 0.1, + }, + batch_size=[1], + ) + + out = cost(td) + assert out["loss_cost"].item() > 0.9 + + def test_cost_bounded_zero_one(self): + cost = ExponentialQuadraticCost(reduction="none") + + td = TensorDict( + { + ("observation", "mean"): torch.randn(10, 4), + ("observation", "var"): torch.eye(4).unsqueeze(0).expand(10, -1, -1) + * 0.1, + }, + batch_size=[10], + ) + + out = cost(td) + loss = out["loss_cost"] + assert (loss >= -1e-6).all() + assert (loss <= 1.0 + 1e-6).all() + + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) + def 
test_reductions(self, reduction): + cost = ExponentialQuadraticCost(reduction=reduction) + + td = TensorDict( + { + ("observation", "mean"): torch.randn(3, 5, 4), + ("observation", "var"): torch.eye(4) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, 5, -1, -1) + * 0.1, + }, + batch_size=[3, 5], + ) + + out = cost(td) + loss = out["loss_cost"] + + if reduction == "none": + assert loss.shape == (3, 5) + else: + assert loss.shape == () + + def test_custom_weights_and_target(self): + weights = torch.diag(torch.tensor([2.0, 0.5, 1.0, 1.0])) + target = torch.tensor([1.0, 0.0, 0.0, 0.0]) + cost = ExponentialQuadraticCost( + target=target, weights=weights, reduction="none" + ) + + td = TensorDict( + { + ("observation", "mean"): target.unsqueeze(0), + ("observation", "var"): torch.eye(4).unsqueeze(0) * 1e-6, + }, + batch_size=[1], + ) + + out = cost(td) + assert out["loss_cost"].item() < 0.01 + + def test_gradients_flow(self): + cost = ExponentialQuadraticCost(reduction="mean") + + mean = torch.randn(2, 4, requires_grad=True) + var = torch.eye(4).unsqueeze(0).expand(2, -1, -1) * 0.1 + + td = TensorDict( + {("observation", "mean"): mean, ("observation", "var"): var}, + batch_size=[2], + ) + + out = cost(td) + out["loss_cost"].backward() + assert mean.grad is not None + + +class TestImaginedEnv: + @staticmethod + def _make_dummy_world_model(obs_dim, action_dim): + class DummyWM(torch.nn.Module): + def __init__(self, obs_dim): + super().__init__() + self.obs_dim = obs_dim + + def forward(self, action, observation): + mean = observation.get("mean") + var = ( + torch.eye( + self.obs_dim, device=mean.device, dtype=mean.dtype + ).expand(*mean.shape[:-1], -1, -1) + * 0.01 + ) + return mean + 0.1, var + + return TensorDictModule( + DummyWM(obs_dim), + in_keys=["action", "observation"], + out_keys=[("next_observation", "mean"), ("next_observation", "var")], + ) + + @staticmethod + def _make_base_env(obs_dim, action_dim): + class StubEnv(EnvBase): + def __init__(self, obs_dim, 
action_dim): + super().__init__(batch_size=torch.Size([])) + self.observation_spec = Composite( + observation=Unbounded(shape=(obs_dim,)) + ) + self.action_spec = Unbounded(shape=(action_dim,)) + self.reward_spec = Unbounded(shape=(1,)) + + def _reset(self, tensordict=None): + return TensorDict( + {"observation": torch.zeros(obs_dim)}, + batch_size=self.batch_size, + ) + + def _step(self, tensordict): + return TensorDict( + { + "observation": torch.randn(obs_dim), + "reward": torch.zeros(1), + "done": torch.tensor(False).unsqueeze(0), + "terminated": torch.tensor(False).unsqueeze(0), + }, + batch_size=self.batch_size, + ) + + def _set_seed(self, seed): + pass + + return StubEnv(obs_dim, action_dim) + + def test_creation(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv(world_model_module=wm, base_env=base_env) + assert env.batch_size == torch.Size([1]) + + def test_creation_with_batch_size(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv(world_model_module=wm, base_env=base_env, batch_size=[3]) + assert env.batch_size == torch.Size([3]) + + def test_reset_with_observation(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, action_dim) + + env = ImaginedEnv(world_model_module=wm, base_env=base_env) + + reset_td = TensorDict( + { + ("observation", "mean"): torch.zeros(1, obs_dim), + ("observation", "var"): torch.eye(obs_dim).unsqueeze(0) * 1e-3, + }, + batch_size=[1], + ) + + out = env.reset(reset_td) + assert ("observation", "mean") in out.keys(True) + assert ("observation", "var") in out.keys(True) + + def test_step(self): + obs_dim, action_dim = 4, 1 + wm = self._make_dummy_world_model(obs_dim, action_dim) + base_env = self._make_base_env(obs_dim, 
class TestMeanActionSelector:
    """Checks for ``MeanActionSelector``, both standalone and inside a ``TransformedEnv``."""

    @staticmethod
    def _make_base_env(obs_dim, action_dim):
        """Return an unbatched stub environment that emits flat observations."""

        class StubEnv(EnvBase):
            def __init__(self, n_obs, n_act):
                super().__init__(batch_size=torch.Size([]))
                self.observation_spec = Composite(
                    observation=Unbounded(shape=(n_obs,))
                )
                self.action_spec = Unbounded(shape=(n_act,))
                self.reward_spec = Unbounded(shape=(1,))

            def _reset(self, tensordict=None):
                # Every episode starts from the all-zero observation.
                start = {"observation": torch.zeros(obs_dim)}
                return TensorDict(start, batch_size=self.batch_size)

            def _step(self, tensordict):
                # Random next observation, zero reward, never done.
                transition = {
                    "observation": torch.randn(obs_dim),
                    "reward": torch.zeros(1),
                    "done": torch.tensor(False).unsqueeze(0),
                    "terminated": torch.tensor(False).unsqueeze(0),
                }
                return TensorDict(transition, batch_size=self.batch_size)

            def _set_seed(self, seed):
                pass

        return StubEnv(obs_dim, action_dim)

    def test_forward_wraps_observation(self):
        """A flat observation becomes a (mean, var) belief with a D x D covariance."""
        selector = MeanActionSelector()
        flat_obs = torch.randn(4)
        td = TensorDict({"observation": flat_obs.clone()}, batch_size=[])

        result = selector._call(td)
        nested_keys = result.keys(True)
        assert ("observation", "mean") in nested_keys
        assert ("observation", "var") in nested_keys
        assert result["observation", "var"].shape == (4, 4)
        torch.testing.assert_close(result["observation", "mean"], flat_obs)

    def test_inverse_extracts_action_mean(self):
        """The inverse pass writes the action mean as the flat ``"action"`` entry."""
        selector = MeanActionSelector()
        expected_action = torch.randn(2)
        td = TensorDict(
            {
                ("action", "mean"): expected_action,
                ("action", "var"): torch.eye(2),
            },
            batch_size=[],
        )

        result = selector._inv_call(td)
        assert "action" in result.keys()
        torch.testing.assert_close(result["action"], expected_action)

    def test_with_transformed_env_reset(self):
        """Resetting a transformed env yields the belief-style observation keys."""
        obs_dim, action_dim = 4, 1
        wrapped = TransformedEnv(
            self._make_base_env(obs_dim, action_dim), MeanActionSelector()
        )

        reset_td = wrapped.reset()
        assert ("observation", "mean") in reset_td.keys(True)
        assert ("observation", "var") in reset_td.keys(True)

    def test_observation_spec_transformed(self):
        """The transformed observation spec advertises both belief entries."""
        obs_dim, action_dim = 4, 1
        wrapped = TransformedEnv(
            self._make_base_env(obs_dim, action_dim), MeanActionSelector()
        )

        spec = wrapped.observation_spec
        assert ("observation", "mean") in spec.keys(True)
        assert ("observation", "var") in spec.keys(True)

    def test_zero_variance_on_reset(self):
        """The covariance produced on reset is exactly zero."""
        obs_dim, action_dim = 4, 1
        wrapped = TransformedEnv(
            self._make_base_env(obs_dim, action_dim), MeanActionSelector()
        )

        covariance = wrapped.reset()["observation", "var"]
        torch.testing.assert_close(covariance, torch.zeros(obs_dim, obs_dim))


@pytest.mark.skipif(not _has_botorch, reason="botorch/gpytorch not installed")
class TestGPWorldModel:
    """Shape-level checks for ``GPWorldModel`` fitting and both forward modes."""

    @staticmethod
    def _fitted_model(obs_dim, action_dim, n_samples=20, as_double=False):
        """Build a model, fit it on random near-identity transitions and freeze it.

        The import stays local so the module can be collected without botorch.
        With ``as_double`` the data and the model are cast to float64 before fitting.
        """
        from torchrl.modules.models.gp import GPWorldModel

        model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim)

        def draw(*shape):
            sample = torch.randn(*shape)
            return sample.double() if as_double else sample

        obs = draw(n_samples, obs_dim)
        action = draw(n_samples, action_dim)
        next_obs = obs + 0.1 * draw(n_samples, obs_dim)

        dataset = TensorDict(
            {
                "observation": obs,
                "action": action,
                ("next", "observation"): next_obs,
            },
            batch_size=[n_samples],
        )

        if as_double:
            model.double()
        model.fit(dataset)
        model.freeze_and_detach()
        return model

    def test_creation(self):
        """Constructor records the dimensions and their sum as the GP input size."""
        from torchrl.modules.models.gp import GPWorldModel

        model = GPWorldModel(obs_dim=4, action_dim=1)
        assert model.obs_dim == 4
        assert model.action_dim == 1
        assert model.input_dim == 5

    def test_fit_and_deterministic_forward(self):
        """Point inputs give a mean of shape (B, D) and a covariance of shape (B, D, D)."""
        obs_dim, action_dim = 2, 1
        model = self._fitted_model(obs_dim, action_dim)

        query_obs = TensorDict({"mean": torch.randn(3, obs_dim)}, batch_size=[3])
        query_action = TensorDict(
            {"mean": torch.randn(3, action_dim)}, batch_size=[3]
        )

        mean, var = model.deterministic_forward(query_action, query_obs)

        assert mean.shape == (3, obs_dim)
        assert var.shape == (3, obs_dim, obs_dim)

    def test_uncertain_forward(self):
        """Moment matching returns correctly shaped, symmetric covariances."""
        obs_dim, action_dim = 2, 1
        model = self._fitted_model(obs_dim, action_dim, as_double=True)

        batch = 2
        obs_belief = TensorDict(
            {
                "mean": torch.randn(batch, obs_dim, dtype=torch.float64),
                "var": torch.eye(obs_dim, dtype=torch.float64)
                .unsqueeze(0)
                .expand(batch, -1, -1)
                * 0.01,
            },
            batch_size=[batch],
        )
        action_belief = TensorDict(
            {
                "mean": torch.randn(batch, action_dim, dtype=torch.float64),
                "var": torch.eye(action_dim, dtype=torch.float64)
                .unsqueeze(0)
                .expand(batch, -1, -1)
                * 0.01,
                "cross_covariance": torch.zeros(
                    batch, obs_dim, action_dim, dtype=torch.float64
                ),
            },
            batch_size=[batch],
        )

        mean, var = model.uncertain_forward(action_belief, obs_belief)

        assert mean.shape == (batch, obs_dim)
        assert var.shape == (batch, obs_dim, obs_dim)

        torch.testing.assert_close(var, var.transpose(-2, -1), atol=1e-5, rtol=1e-4)

    def test_forward_dispatch(self):
        """``forward`` accepts both mean-only inputs and full Gaussian beliefs."""
        obs_dim, action_dim = 2, 1
        model = self._fitted_model(obs_dim, action_dim)

        # Mean-only inputs.
        point_obs = TensorDict({"mean": torch.randn(2, obs_dim)}, batch_size=[2])
        point_action = TensorDict(
            {"mean": torch.randn(2, action_dim)}, batch_size=[2]
        )
        mean, var = model(point_action, point_obs)
        assert mean.shape == (2, obs_dim)

        # Inputs that also carry covariances.
        belief_obs = TensorDict(
            {
                "mean": torch.randn(2, obs_dim),
                "var": torch.eye(obs_dim).unsqueeze(0).expand(2, -1, -1) * 0.1,
            },
            batch_size=[2],
        )
        belief_action = TensorDict(
            {
                "mean": torch.randn(2, action_dim),
                "var": torch.eye(action_dim).unsqueeze(0).expand(2, -1, -1) * 0.01,
                "cross_covariance": torch.zeros(2, obs_dim, action_dim),
            },
            batch_size=[2],
        )
        mean, var = model(belief_action, belief_obs)
        assert mean.shape == (2, obs_dim)


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
class ImaginedEnv(ModelBasedEnvBase):
    """Imagination environment for model-based policy search.

    Wraps a learned world model (e.g. a Gaussian Process) as a standard
    TorchRL environment so that imagined rollouts can be collected with
    :meth:`~torchrl.envs.EnvBase.rollout`. Observations carry both mean
    and covariance (under keys ``("observation", "mean")`` and
    ``("observation", "var")``) to support uncertainty-aware moment-matching
    controllers.

    The environment never terminates on its own -- rollout length is
    controlled solely by the ``max_steps`` argument of
    :meth:`~torchrl.envs.EnvBase.rollout`. The ``done`` and ``terminated``
    flags are always ``False`` and the reward written by ``_step`` is always
    zero.

    Args:
        world_model_module (TensorDictModule): A :class:`~tensordict.nn.TensorDictModule`
            that takes ``"action"`` and ``"observation"`` entries and produces
            ``("next_observation", "mean")`` and ``("next_observation", "var")``.
        base_env (EnvBase): The real environment whose specs (observation, action,
            reward, done) are copied into this imagined environment.
        batch_size (int, Sequence[int], torch.Size, optional): Override batch size.
            A bare ``int`` is interpreted as a one-dimensional batch of that size.
            If ``None``, inferred from ``base_env`` (with a minimum of ``[1]``).

    Examples:
        >>> import torch
        >>> from tensordict.nn import TensorDictModule
        >>> # A toy world model that returns the input mean and an identity covariance
        >>> class DummyWorldModel(torch.nn.Module):
        ...     def __init__(self, obs_dim):
        ...         super().__init__()
        ...         self.obs_dim = obs_dim
        ...     def forward(self, action, observation):
        ...         mean = observation.get("mean")
        ...         var = torch.eye(self.obs_dim).expand(*mean.shape[:-1], -1, -1)
        ...         return mean, var
        >>> obs_dim, act_dim = 4, 1
        >>> wm = TensorDictModule(
        ...     DummyWorldModel(obs_dim),
        ...     in_keys=["action", "observation"],
        ...     out_keys=[("next_observation", "mean"), ("next_observation", "var")],
        ... )
    """

    def __init__(
        self,
        world_model_module: TensorDictModule,
        base_env: EnvBase,
        batch_size: int | torch.Size | Sequence[int] | None = None,
        **kwargs,
    ) -> None:
        if batch_size is None:
            # An unbatched base env is promoted to a single-element batch.
            if len(base_env.batch_size) == 0:
                batch_size = torch.Size([1])
            else:
                batch_size = base_env.batch_size
        elif isinstance(batch_size, int):
            # torch.Size requires an iterable; a bare int cannot be passed directly.
            batch_size = torch.Size([batch_size])
        elif not isinstance(batch_size, torch.Size):
            batch_size = torch.Size(batch_size)

        super().__init__(
            world_model_module,
            device=base_env.device,
            batch_size=batch_size,
            allow_done_after_reset=True,
            **kwargs,
        )

        # Copy the base env specs, broadcast to this env's batch size.
        self.observation_spec = base_env.observation_spec.expand(
            self.batch_size
        ).clone()
        self.action_spec = base_env.action_spec.expand(self.batch_size).clone()
        self.reward_spec = base_env.reward_spec.expand(self.batch_size).clone()
        self.done_spec = base_env.done_spec.expand(self.batch_size).clone()

    def any_done(self, tensordict) -> bool:
        """Returns False -- imagination rollouts never terminate.

        Overridden to avoid CUDA sync from ``done.any()`` in the parent class.
        """
        return False

    def maybe_reset(self, tensordict):
        """No-op -- imagination rollouts do not need partial resets.

        Overridden to avoid CUDA sync from done checks in the parent class.
        """
        return tensordict

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Advance one imagined step by querying the world model.

        The world model's ``"next_observation"`` entry becomes the new
        ``"observation"``; reward is zero and both done flags are ``False``.
        """
        tensordict = self.world_model(tensordict)

        reward = torch.zeros(*tensordict.shape, 1, device=self.device)
        done = torch.zeros(*tensordict.shape, 1, dtype=torch.bool, device=self.device)
        out = TensorDict(
            {
                "observation": tensordict.get("next_observation"),
                "reward": reward,
                "done": done,
                "terminated": done.clone(),
            },
            tensordict.shape,
        )
        return out

    def _reset(
        self, tensordict: TensorDictBase | None = None, **kwargs
    ) -> TensorDictBase:
        """Produce the initial state belief.

        * If ``tensordict`` already holds ``("observation", "mean")`` and
          ``("observation", "var")``, a copy of it is returned unchanged.
        * If it holds a flat ``"observation"`` tensor (or a nested entry with
          only a ``"mean"``), that tensor becomes the mean and the covariance
          is set to zero.
        * Otherwise an observation is sampled from ``observation_spec``.
        """
        if tensordict is None:
            tensordict = TensorDict({}, batch_size=self.batch_size, device=self.device)

        # Inspect the top-level entry first so that a flat tensor is never
        # probed with a nested key.
        obs = tensordict.get("observation", None)
        if isinstance(obs, TensorDictBase):
            mean = obs.get("mean", None)
            if mean is not None and obs.get("var", None) is not None:
                return tensordict.copy()
            # Incomplete belief: fall back to the mean (may be None).
            obs = mean

        if obs is None:
            # The spec was already expanded to ``batch_size`` in ``__init__``;
            # passing a shape here again would add a second batch dimension.
            obs = self.observation_spec.rand().get("observation")
        if obs.ndim == 1:
            obs = obs.expand(*self.batch_size, -1)

        obs = obs.to(self.device)
        batch_shape = obs.shape[:-1]
        D = obs.shape[-1]

        out = TensorDict(
            {
                ("observation", "mean"): obs,
                # A deterministic start state has zero covariance.
                ("observation", "var"): torch.zeros(
                    *batch_shape, D, D, dtype=obs.dtype, device=self.device
                ),
            },
            batch_size=self.batch_size,
            device=self.device,
        )

        out.set(
            "done",
            torch.zeros(*batch_shape, 1, dtype=torch.bool, device=self.device),
        )
        out.set(
            "terminated",
            torch.zeros(*batch_shape, 1, dtype=torch.bool, device=self.device),
        )

        return out
class MeanActionSelector(Transform):
    """Bridges Gaussian belief-space policies with standard environments.

    Gaussian policies used in moment-matching model-based RL (e.g. PILCO) operate
    on state *beliefs* -- ``(mean, covariance)`` pairs -- and produce
    action distributions with ``("action", "mean")``, ``("action", "var")``, etc.
    This transform adapts a standard environment so that such a policy can be
    used directly with :meth:`~torchrl.envs.EnvBase.rollout`:

    * **Forward** (env output -> policy input): wraps the flat ``"observation"``
      tensor into ``("observation", "mean")`` with a zero-covariance
      ``("observation", "var")``, representing a deterministic state belief.
      An observation that is already nested is passed through untouched.
    * **Inverse** (policy output -> env input): extracts ``("action", "mean")``
      from the policy output and writes it as the flat ``"action"`` for the
      base environment step.

    Args:
        observation_key (str, optional): The observation key to read from the
            base environment. Defaults to ``"observation"``.
        action_key (str, optional): The action key expected by the base
            environment. Defaults to ``"action"``.

    Examples:
        >>> import torch
        >>> from torchrl.envs import GymEnv, TransformedEnv
        >>> from torchrl.envs.transforms import MeanActionSelector
        >>> base_env = GymEnv("Pendulum-v1")
        >>> env = TransformedEnv(base_env, MeanActionSelector())
        >>> td = env.reset()
        >>> # The policy now sees ("observation", "mean") and ("observation", "var")
        >>> print(td["observation", "mean"].shape)
        >>> print(td["observation", "var"].shape)
    """

    def __init__(
        self,
        observation_key: str = "observation",
        action_key: str = "action",
    ) -> None:
        super().__init__(
            in_keys=[observation_key],
            out_keys=[(observation_key, "mean"), (observation_key, "var")],
            in_keys_inv=[action_key],
            out_keys_inv=[(action_key, "mean")],
        )
        self._observation_key = observation_key
        self._action_key = action_key

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Replace the flat observation with a ``(mean, var)`` belief, in place."""
        obs = tensordict.get(self._observation_key)

        # Already a nested belief: nothing to wrap.
        if isinstance(obs, TensorDictBase):
            return tensordict

        *batch_shape, D = obs.shape

        # The flat tensor must be removed before a nested entry can take its key.
        tensordict.pop(self._observation_key)

        tensordict.set((self._observation_key, "mean"), obs)
        tensordict.set(
            (self._observation_key, "var"),
            torch.zeros(*batch_shape, D, D, device=obs.device, dtype=obs.dtype),
        )

        return tensordict

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Overwrite the action entry with its mean, if a mean is present."""
        action_mean = tensordict.get((self._action_key, "mean"), None)
        if action_mean is not None:
            tensordict.set(self._action_key, action_mean)
        return tensordict

    def transform_observation_spec(self, observation_spec):
        """Turn the flat observation spec into a ``mean``/``var`` composite.

        The composite's own shape holds only the leading (batch) dimensions of
        the original spec; the feature dimension belongs to the leaves.
        """
        obs_spec = observation_spec[self._observation_key]

        # Mirror ``_call``: an already-nested observation is left as is.
        if isinstance(obs_spec, Composite):
            return observation_spec

        D = obs_spec.shape[-1]
        observation_spec[self._observation_key] = Composite(
            mean=obs_spec.clone(),
            var=Unbounded(
                shape=(*obs_spec.shape, D),
                dtype=obs_spec.dtype,
                device=obs_spec.device,
            ),
            shape=obs_spec.shape[:-1],
        )
        return observation_spec

    def _reset(self, tensordict, tensordict_reset):
        """Apply the forward wrapping to the reset output."""
        return self._call(tensordict_reset)
"OnlineDTActor", "QMixer", - "VDNMixer", + "RBFController", + "RSSMPosterior", + "RSSMPrior", + "RSSMRollout", "Squeeze2dLayer", "SqueezeLayer", + "VDNMixer", + "reset_noise", ] diff --git a/torchrl/modules/models/gp.py b/torchrl/modules/models/gp.py index 1df42e1aa72..c9ef5ac97f4 100644 --- a/torchrl/modules/models/gp.py +++ b/torchrl/modules/models/gp.py @@ -19,15 +19,29 @@ class GPWorldModel(nn.Module): - """Gaussian Process World Model. + """Gaussian Process World Model for moment-matching model-based RL. - This module implements a Gaussian Process (GP) based world model using BoTorch and GPyTorch. - It models the transition dynamics of an environment by predicting the change in observation - given the current observation and action. + Fits one independent single-task GP per observation dimension using + BoTorch/GPyTorch. Each GP models the *transition residual* + ``delta_i = next_obs_i - obs_i`` given the concatenated ``[obs, action]`` + input. After fitting, the model supports two forward modes: + + * **Deterministic**: point predictions via the GP posterior mean/variance. + * **Uncertain** (moment-matching): propagates Gaussian beliefs + ``N(m, S)`` through the GP analytically, yielding the next-state + belief ``N(m', S')``. This is the core computation in PILCO + (Deisenroth & Rasmussen, 2011). + + Requires ``botorch`` and ``gpytorch`` as optional dependencies. Args: - obs_dim (int): The dimension of the observation space. - action_dim (int): The dimension of the action space. + obs_dim (int): Dimensionality of the observation space. + action_dim (int): Dimensionality of the action space. 
+ + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> model = GPWorldModel(obs_dim=4, action_dim=1) # doctest: +SKIP """ def __init__(self, obs_dim: int, action_dim: int) -> None: @@ -140,9 +154,16 @@ def _extract_parameters(self, y_train: torch.Tensor) -> None: self._cached_beta = torch.stack(betas) def compute_factorizations(self) -> tuple[torch.Tensor, torch.Tensor]: + """Returns the cached kernel inverse and weight vectors. + + Returns: + inv_K (Tensor): Inverse kernel matrices, shape ``(obs_dim, N, N)``. + beta (Tensor): Weight vectors ``K^{-1} y``, shape ``(obs_dim, N)``. + """ return self._cached_inv_K, self._cached_beta def _gather_gp_params(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns the extracted hyperparameters of each per-dimension GP.""" return self.lengthscales, self.variances, self.noises def forward( @@ -173,7 +194,7 @@ def forward( return self.deterministic_forward(action, observation) def freeze_and_detach(self) -> None: - """Freezes the model and detaches gradients.""" + """Freezes the model parameters so they are not updated during policy optimisation.""" def uncertain_forward( self, action: TensorDictBase, obs: TensorDictBase @@ -230,8 +251,8 @@ def uncertain_forward( scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2) lb = scaled_exp * beta.unsqueeze(0) - det_B = torch.linalg.det(B_mat) - c = variances.squeeze(1).unsqueeze(0) / torch.sqrt(det_B) + _, log_det_B = torch.linalg.slogdet(B_mat) + c = variances.squeeze(1).unsqueeze(0) * torch.exp(-0.5 * log_det_B) pred_mean = torch.sum(lb, dim=-1) * c.squeeze(0) @@ -284,8 +305,8 @@ def uncertain_forward( batch_size, num_train_pts, num_train_pts ) - det_R_ab = torch.linalg.det(R_ab) - c_ab = variances[a] * variances[b] / torch.sqrt(det_R_ab) + _, log_det_R_ab = torch.linalg.slogdet(R_ab) + c_ab = variances[a] * variances[b] * torch.exp(-0.5 * log_det_R_ab) Q_ab = c_ab.view(-1, 1, 1) * torch.exp(exp1.unsqueeze(0) + exp2) @@ -300,7 +321,7 @@ def 
class RBFController(nn.Module):
    """Radial Basis Function controller for moment-matching policy search.

    Implements a policy that maps Gaussian-distributed state beliefs
    ``(mean, covariance)`` to Gaussian-distributed actions using an RBF network
    followed by a sinusoidal squashing function. The moment-matching formulas
    allow analytic gradient computation through the policy during model-based
    optimization (e.g., PILCO).

    The controller uses ``n_basis`` RBF basis functions, each parameterised
    by a centre vector and a shared diagonal lengthscale. The output is a
    weighted sum of basis activations, optionally squashed through
    :meth:`squash_sin` to enforce action bounds.

    Reference: Deisenroth & Rasmussen, "PILCO: A Model-Based and Data-Efficient
    Approach to Policy Search", ICML 2011.

    Args:
        input_dim (int): Dimensionality of the state (observation) space.
        output_dim (int): Dimensionality of the action space.
        max_action (float, Tensor or None): Element-wise upper bound on action
            magnitude. When not ``None``, actions are squashed through
            :meth:`squash_sin`; with ``None`` the raw RBF moments are returned.
        n_basis (int, optional): Number of RBF basis functions.
            Defaults to ``10``.

    Inputs:
        mean (Tensor): State mean of shape ``(*batch, input_dim)``.
        covariance (Tensor): State covariance of shape
            ``(*batch, input_dim, input_dim)``.

    Returns:
        action_mean (Tensor): Action mean of shape ``(*batch, output_dim)``.
        action_covariance (Tensor): Action covariance of shape
            ``(*batch, output_dim, output_dim)``.
        cross_covariance (Tensor): Input–output cross-covariance term of shape
            ``(*batch, input_dim, output_dim)``.

    Examples:
        >>> import torch
        >>> controller = RBFController(input_dim=4, output_dim=1, max_action=2.0, n_basis=5)
        >>> mean = torch.randn(2, 4)
        >>> covariance = torch.eye(4).unsqueeze(0).expand(2, -1, -1) * 0.1
        >>> action_mean, action_cov, cross_cov = controller(mean, covariance)
        >>> action_mean.shape
        torch.Size([2, 1])
        >>> action_cov.shape
        torch.Size([2, 1, 1])
        >>> cross_cov.shape
        torch.Size([2, 4, 1])
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        max_action: float | torch.Tensor | None,
        n_basis: int = 10,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.max_action = max_action
        self.n_basis = n_basis

        self.centers = nn.Parameter(torch.randn(n_basis, input_dim) * 0.5)
        self.weights = nn.Parameter(torch.randn(n_basis, output_dim) * 0.1)
        self.lengthscales = nn.Parameter(torch.ones(input_dim))
        # Signal variance of the basis functions; a fixed constant, not learned.
        self.variance = 1.0

    @staticmethod
    def squash_sin(
        mean: torch.Tensor,
        covariance: torch.Tensor,
        max_action: float | torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Propagates a Gaussian through an element-wise ``max_action * sin(x)`` squashing.

        Computes the moments of the transformed distribution using
        the moment-matching identities for sine applied to Gaussian inputs.

        Args:
            mean (Tensor): Input mean, shape ``(*batch, K)``.
            covariance (Tensor): Input covariance, shape ``(*batch, K, K)``.
            max_action (float or Tensor): Per-dimension action bound. A scalar
                (or one-element tensor) is broadcast to all ``K`` dimensions.

        Returns:
            squashed_mean (Tensor): Output mean, shape ``(*batch, K)``.
            squashed_covariance (Tensor): Output covariance, shape ``(*batch, K, K)``.
            cross_covariance (Tensor): Diagonal input–output term, shape ``(*batch, K, K)``.
        """
        K = mean.shape[-1]
        device = mean.device
        dtype = mean.dtype

        if isinstance(max_action, torch.Tensor):
            # A bound given as a tensor may live on another device than the input.
            max_action = max_action.to(device=device)
        else:
            max_action = torch.tensor(max_action, dtype=dtype, device=device)

        max_action = max_action.reshape(-1)
        if max_action.shape[0] == 1 and K > 1:
            max_action = max_action.expand(K)

        diag_cov = torch.diagonal(covariance, dim1=-2, dim2=-1)
        damping = torch.exp(-diag_cov / 2.0)

        squashed_mean = max_action * damping * torch.sin(mean)

        # lq[i, j] = -(s_ii + s_jj) / 2
        lq = -(diag_cov.unsqueeze(-1) + diag_cov.unsqueeze(-2)) / 2.0
        q = torch.exp(lq)

        mean_diff = mean.unsqueeze(-1) - mean.unsqueeze(-2)
        mean_sum = mean.unsqueeze(-1) + mean.unsqueeze(-2)

        squashed_covariance = (torch.exp(lq + covariance) - q) * torch.cos(
            mean_diff
        ) - (torch.exp(lq - covariance) - q) * torch.cos(mean_sum)

        outer_max = max_action.unsqueeze(-2) * max_action.unsqueeze(-1)
        squashed_covariance = outer_max * squashed_covariance / 2.0

        cross_covariance = torch.diag_embed(max_action * damping * torch.cos(mean))

        return squashed_mean, squashed_covariance, cross_covariance

    def forward(
        self, mean: torch.Tensor, covariance: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Map a Gaussian state belief to the moments of the action distribution.

        Args:
            mean (Tensor): State mean, shape ``(*batch, input_dim)``.
            covariance (Tensor): State covariance, shape
                ``(*batch, input_dim, input_dim)``.

        Returns:
            Tuple ``(action_mean, action_cov, cross_cov)`` with shapes
            ``(*batch, output_dim)``, ``(*batch, output_dim, output_dim)`` and
            ``(*batch, input_dim, output_dim)``.
        """
        batch_shape = mean.shape[:-1]
        D = mean.shape[-1]
        N = self.n_basis
        device = mean.device
        dtype = mean.dtype

        # Flatten batch dimensions for computation
        mean_flat = mean.reshape(-1, D)
        covariance_flat = covariance.reshape(-1, D, D)
        B = mean_flat.shape[0]

        inv_lengthscale = torch.diag(1.0 / self.lengthscales)

        # Offsets from the belief mean to every centre: (B, N, D)
        inp = self.centers.unsqueeze(0) - mean_flat.unsqueeze(1)

        B_mat = inv_lengthscale @ covariance_flat @ inv_lengthscale + torch.eye(
            D, device=device, dtype=dtype
        )

        scaled_inp = inp @ inv_lengthscale

        t = torch.linalg.solve(B_mat, scaled_inp.mT).mT

        exp_term = torch.exp(-0.5 * torch.sum(scaled_inp * t, dim=-1))
        # Only the log-determinant is needed; the sign is discarded.
        _, log_det = torch.linalg.slogdet(B_mat)
        normalizer = self.variance * torch.exp(-0.5 * log_det)
        phi_mean = normalizer.unsqueeze(-1) * exp_term

        action_mean = phi_mean @ self.weights

        # NOTE(review): this term is built from inv(covariance + Lambda) @ (centre - mean)
        # and is never left-multiplied by ``covariance``, so it looks like the
        # PILCO-style "inverse input covariance times cross-covariance" quantity;
        # confirm against what the consuming world model expects.
        t_scaled = t @ inv_lengthscale
        cross_cov = torch.bmm(t_scaled.mT, phi_mean.unsqueeze(-1) * self.weights)

        # Pairwise basis covariance (Eq. A.42–A.45 in Deisenroth thesis)
        centers_i = self.centers.unsqueeze(1)
        centers_j = self.centers.unsqueeze(0)
        diff = centers_i - centers_j
        center_bar = (centers_i + centers_j) / 2.0

        squared_lengthscales = self.lengthscales**2
        exp1 = -0.25 * torch.sum((diff**2) / squared_lengthscales, dim=-1)

        lambda_half = torch.diag(squared_lengthscales / 2.0)
        B_q = covariance_flat + lambda_half.unsqueeze(0)

        z = center_bar.unsqueeze(0) - mean_flat.unsqueeze(1).unsqueeze(1)
        z_flat = z.reshape(B, N * N, D)

        solved_z_flat = torch.linalg.solve(B_q, z_flat.mT).mT
        exp2 = -0.5 * torch.sum(z_flat * solved_z_flat, dim=-1).reshape(B, N, N)

        log_det_lambda_half = torch.sum(torch.log(squared_lengthscales / 2.0))
        _, log_det_bq = torch.linalg.slogdet(B_q)
        c_q = torch.exp(0.5 * (log_det_lambda_half - log_det_bq))

        Q = (self.variance**2 * c_q.view(B, 1, 1)) * torch.exp(
            exp1.unsqueeze(0) + exp2
        )

        W_batch = self.weights.unsqueeze(0).expand(B, N, -1)
        action_cov = torch.bmm(W_batch.mT, torch.bmm(Q, W_batch))

        outer_mean = torch.bmm(action_mean.unsqueeze(-1), action_mean.unsqueeze(1))
        action_cov = action_cov - outer_mean

        # Symmetrise and add a small jitter to the diagonal.
        action_cov = (action_cov + action_cov.mT) / 2.0
        action_cov = (
            action_cov
            + torch.eye(self.output_dim, device=device, dtype=dtype).unsqueeze(0)
            * 1e-6
        )

        if self.max_action is not None:
            action_mean, action_cov, C = self.squash_sin(
                action_mean, action_cov, self.max_action
            )
            # Chain the squashing term onto the RBF cross-covariance term.
            cross_cov = torch.bmm(cross_cov, C)

        # Reshape back to original batch shape
        action_mean = action_mean.reshape(*batch_shape, self.output_dim)
        action_cov = action_cov.reshape(*batch_shape, self.output_dim, self.output_dim)
        cross_cov = cross_cov.reshape(*batch_shape, D, self.output_dim)

        return action_mean, action_cov, cross_cov
"3,4,5" EXTRA_ARGS=() # extra Hydra overrides forwarded to the training script # ---- Parse arguments -------------------------------------------------------- @@ -38,6 +39,7 @@ for arg in "$@"; do --build-only) BUILD_ONLY=true ;; --dmcontrol) MODE="dmcontrol" ;; --isaac) MODE="isaac" ;; + --gpus=*) GPUS="${arg#--gpus=}" ;; *) EXTRA_ARGS+=("$arg") ;; esac done @@ -45,15 +47,38 @@ done # Avoid "'': unknown terminal type" in headless containers export TERM="${TERM:-xterm}" +# Resolve GPU set early so we can use it for zombie cleanup +if [[ -n "$GPUS" ]]; then + export CUDA_VISIBLE_DEVICES="$GPUS" +elif [[ "$MODE" == "isaac" ]]; then + export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2}" +fi + echo "============================================================" echo " setup-and-run.sh" echo " mode=$MODE build_only=$BUILD_ONLY" +echo " gpus=${CUDA_VISIBLE_DEVICES:-}" echo " extra_args=${EXTRA_ARGS[*]:-}" echo "============================================================" -# ---- 0) Kill zombie Python processes from previous runs --------------------- -echo "* Killing leftover Python processes..." -pkill -9 -f python || true +# ---- 0) Kill zombie Python processes on the SAME GPUs ---------------------- +# Only kill dreamer processes whose CUDA_VISIBLE_DEVICES matches ours, +# so that a second experiment on different GPUs is left untouched. +echo "* Killing leftover dreamer processes on GPUs=${CUDA_VISIBLE_DEVICES:-}..." 
+if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + # Find dreamer_isaac.py PIDs whose /proc/$pid/environ contains our GPU set + for pid in $(pgrep -f "dreamer_isaac.py|dreamer.py" 2>/dev/null || true); do + proc_env=$(tr '\0' '\n' < /proc/$pid/environ 2>/dev/null || true) + proc_gpus=$(echo "$proc_env" | grep '^CUDA_VISIBLE_DEVICES=' | head -1 | cut -d= -f2) + if [[ "$proc_gpus" == "$CUDA_VISIBLE_DEVICES" ]] || [[ -z "$proc_gpus" ]]; then + echo " Killing PID $pid (CUDA_VISIBLE_DEVICES=$proc_gpus)" + kill -9 "$pid" 2>/dev/null || true + fi + done +else + # No GPU constraint — kill all dreamer processes + pkill -9 -f "dreamer_isaac.py|dreamer.py" || true +fi sleep 1 # ---- 1) System dependencies ------------------------------------------------ @@ -203,8 +228,8 @@ echo "============================================================" cd "$REPO_DIR" if [[ "$MODE" == "isaac" ]]; then - # Expose 3 GPUs: GPU 0 = sim, GPU 1 = training, GPU 2 = eval (rendering) - export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2}" + # GPUs already set above: GPU0 = sim, GPU1 = training, GPU2 = eval (rendering) + echo " CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" $PYTHON "sota-implementations/dreamer/dreamer_isaac.py" "${EXTRA_ARGS[@]}" else export MUJOCO_GL=egl diff --git a/torchrl/modules/models/gp.py b/torchrl/modules/models/gp.py index c9ef5ac97f4..b7edd9c29d3 100644 --- a/torchrl/modules/models/gp.py +++ b/torchrl/modules/models/gp.py @@ -2,20 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree.
+import importlib.util + import torch import torch.nn as nn from tensordict import TensorDictBase -try: - from botorch.fit import fit_gpytorch_mll - from botorch.models import ModelListGP, SingleTaskGP - from gpytorch.kernels import RBFKernel, ScaleKernel - from gpytorch.mlls import SumMarginalLogLikelihood - from gpytorch.priors import GammaPrior - - _has_botorch = True -except ImportError: - _has_botorch = False +_has_gpytorch = importlib.util.find_spec("gpytorch") is not None +_has_botorch = importlib.util.find_spec("botorch") is not None class GPWorldModel(nn.Module): @@ -45,7 +39,7 @@ class GPWorldModel(nn.Module): """ def __init__(self, obs_dim: int, action_dim: int) -> None: - if not _has_botorch: + if not _has_botorch or not _has_gpytorch: raise ImportError( "botorch and gpytorch are required to use GPWorldModel. " "Please install them to proceed." @@ -55,7 +49,7 @@ def __init__(self, obs_dim: int, action_dim: int) -> None: self.action_dim = action_dim self.input_dim = obs_dim + action_dim - self.model_list: ModelListGP | None = None + self.model_list = None self.register_buffer("X_train", torch.empty(0)) self.register_buffer("lengthscales", torch.zeros(self.obs_dim, self.input_dim)) @@ -78,6 +72,12 @@ def fit(self, dataset: TensorDictBase) -> None: Args: dataset (TensorDictBase): A dataset of collected transitions. 
""" + from botorch.fit import fit_gpytorch_mll + from botorch.models import ModelListGP, SingleTaskGP + from gpytorch.kernels import RBFKernel, ScaleKernel + from gpytorch.mlls import SumMarginalLogLikelihood + from gpytorch.priors import GammaPrior + obs = dataset["observation"] action = dataset["action"] next_obs = dataset[("next", "observation")] @@ -244,7 +244,7 @@ def uncertain_forward( B_mat = inv_L.unsqueeze(0) @ joint_var.unsqueeze(1) @ inv_L.unsqueeze(0) B_mat = B_mat + torch.eye( self.input_dim, dtype=m_x.dtype, device=m_x.device - ).view(1, 1, self.input_dim, self.input_dim) + ).reshape(1, 1, self.input_dim, self.input_dim) t = torch.linalg.solve(B_mat, inv_N.transpose(-2, -1)).transpose(-2, -1) From efa4032ce3dce2704ca8455dbc5d0b333d16c713 Mon Sep 17 00:00:00 2001 From: Pedro Rosa Date: Sat, 14 Mar 2026 15:22:20 +0100 Subject: [PATCH 7/8] integrate GPWorldModel with TensorDict API and improve ImaginedEnv compatibility - Added optional `pilco` dependency group (botorch, gpytorch) to `pyproject.toml`. - Refactored `GPWorldModel` to follow the TensorDict module interface: - `forward` now accepts and returns a `TensorDict` instead of tuples. - Added configurable `in_keys` and `out_keys` for flexible integration. - Default keys support probabilistic inputs (`action.mean`, `action.var`, `action.cross_covariance`, `observation.mean`, `observation.var`). - Deterministic and uncertain forward passes now write results directly into the TensorDict. - Removed `freeze_and_detach` utility as it is no longer required. - Updated internal logic to read/write through TensorDict keys. - Updated `ImaginedEnv`: - Added `next_observation_key` argument to specify where the world model writes predicted observations. - Default key is `("next", "observation")`. - Adjusted environment step to read observations using this configurable key. - Updated documentation examples to reflect new TensorDict conventions. 
- Simplified PILCO implementation: - World model is now used directly instead of wrapping it with `TensorDictModule`. - Replaced `freeze_and_detach()` with `eval()` mode. - Adapted rollout handling to align keys between real and imagined trajectories (selecting mean values for observation/action and matching rollout structure). - Ensured compatibility when concatenating evaluation rollouts into the dataset. --- pyproject.toml | 4 + sota-implementations/pilco/pilco.py | 19 +++-- torchrl/envs/model_based/imagined.py | 18 ++++- torchrl/modules/models/gp.py | 106 +++++++++++++++++---------- 4 files changed, 98 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9b5b4a22d2f..f83ac2a7b57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,10 @@ marl = ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot; python_version>='3 open_spiel = ["open_spiel>=1.5"] brax = ["jax>=0.7.0; python_version>='3.11'", "brax; python_version>='3.11'"] procgen = ["procgen"] +pilco = [ + "botorch", + "gpytorch", +] # Base LLM dependencies (no inference backend - use llm-vllm or llm-sglang) llm = [ "transformers", diff --git a/sota-implementations/pilco/pilco.py b/sota-implementations/pilco/pilco.py index 024359e0fd3..63fb765f169 100644 --- a/sota-implementations/pilco/pilco.py +++ b/sota-implementations/pilco/pilco.py @@ -71,16 +71,10 @@ def pilco_loop( env.device ) base_world_model.fit(rollout) - base_world_model.freeze_and_detach() - - world_model_module = TensorDictModule( - module=base_world_model, - in_keys=["action", "observation"], - out_keys=[("next_observation", "mean"), ("next_observation", "var")], - ) + base_world_model.eval() imagined_env = ImaginedEnv( - world_model_module=world_model_module, + world_model_module=base_world_model, base_env=env, ) reset_td = initial_observation.expand(*imagined_env.batch_size) @@ -119,6 +113,15 @@ def pilco_loop( logger.log_scalar("eval/reward", reward, step=logger_step) logger.log_scalar("eval/steps", 
steps, step=logger_step) + test_rollout.set("observation", test_rollout.get(("observation", "mean"))) + test_rollout.set("action", test_rollout.get(("action", "mean"))) + test_rollout.set( + ("next", "observation"), test_rollout.get(("next", "observation", "mean")) + ) + + test_rollout = test_rollout.select( + *rollout.keys(include_nested=True, leaves_only=True) + ) rollout = tensordict.cat([rollout, test_rollout], dim=0) if len(rollout) > cfg.pilco.max_rollout_length: diff --git a/torchrl/envs/model_based/imagined.py b/torchrl/envs/model_based/imagined.py index 60300b92d89..2d13d02e9d8 100644 --- a/torchrl/envs/model_based/imagined.py +++ b/torchrl/envs/model_based/imagined.py @@ -36,6 +36,8 @@ class ImaginedEnv(ModelBasedEnvBase): reward, done) are copied into this imagined environment. batch_size (int, Sequence[int], torch.Size, optional): Override batch size. If ``None``, inferred from ``base_env`` (with a minimum of ``[1]``). + next_observation_key (str or tuple of str, optional): The key where the world + model writes the predicted next observation. Defaults to ``("next", "observation")``. Examples: >>> import torch @@ -43,21 +45,26 @@ class ImaginedEnv(ModelBasedEnvBase): >>> from tensordict.nn import TensorDictModule >>> from torchrl.envs.model_based import ImaginedEnv, ModelBasedEnvBase >>> from torchrl.data import Composite, Unbounded + >>> base_env = GymEnv("Pendulum-v1") + >>> obs_dim = base_env.observation_spec["observation"].shape[-1] >>> # A toy world model that returns zero-mean, identity covariance >>> class DummyWorldModel(torch.nn.Module): ... def __init__(self, obs_dim): ... super().__init__() ... self.obs_dim = obs_dim ... def forward(self, action, observation): - ... mean = observation.get("mean") + ... # Assuming observation comes in as a dict with a "mean" key + ... mean = observation.get("mean", observation) ... var = torch.eye(self.obs_dim).expand(*mean.shape[:-1], -1, -1) ... 
return mean, var - >>> obs_dim, act_dim = 4, 1 >>> wm = TensorDictModule( ... DummyWorldModel(obs_dim), ... in_keys=["action", "observation"], - ... out_keys=[("next_observation", "mean"), ("next_observation", "var")], + ... out_keys=[("next", "observation", "mean"), ("next", "observation", "var")], ... ) + >>> imagined_env = ImaginedEnv(wm, base_env, next_observation_key=("next", "observation")) + >>> # Collect an imagined rollout + >>> rollout = imagined_env.rollout(max_steps=5, policy=RandomPolicy(imagined_env.action_spec)) """ def __init__( @@ -65,8 +72,11 @@ def __init__( world_model_module: TensorDictModule, base_env: EnvBase, batch_size: int | torch.Size | Sequence[int] | None = None, + next_observation_key: str | tuple[str, ...] = ("next", "observation"), **kwargs, ) -> None: + self.next_observation_key = next_observation_key + if batch_size is not None: batch_size = ( torch.Size(batch_size) @@ -114,7 +124,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: done = torch.zeros(*tensordict.shape, 1, dtype=torch.bool, device=self.device) out = TensorDict( { - "observation": tensordict.get("next_observation"), + "observation": tensordict.get(self.next_observation_key), "reward": reward, "done": done, "terminated": done.clone(), diff --git a/torchrl/modules/models/gp.py b/torchrl/modules/models/gp.py index b7edd9c29d3..c588e780527 100644 --- a/torchrl/modules/models/gp.py +++ b/torchrl/modules/models/gp.py @@ -29,8 +29,13 @@ class GPWorldModel(nn.Module): Requires ``botorch`` and ``gpytorch`` as optional dependencies. Args: - obs_dim (int): Dimensionality of the observation space. - action_dim (int): Dimensionality of the action space. + obs_dim (int): The dimension of the observation space. + action_dim (int): The dimension of the action space. + in_keys (list[str | tuple[str, ...]] | None, optional): The keys to read from the + input TensorDict. Defaults to ["action", "observation"]. 
+ out_keys (list[str | tuple[str, ...]] | None, optional): The keys to write the + predicted mean and variance to in the output TensorDict. + Defaults to [("next", "observation"), ("next", "observation_var")]. Examples: >>> import torch @@ -38,7 +43,13 @@ class GPWorldModel(nn.Module): >>> model = GPWorldModel(obs_dim=4, action_dim=1) # doctest: +SKIP """ - def __init__(self, obs_dim: int, action_dim: int) -> None: + def __init__( + self, + obs_dim: int, + action_dim: int, + in_keys: list[str | tuple[str, ...]] | None = None, + out_keys: list[str | tuple[str, ...]] | None = None, + ) -> None: if not _has_botorch or not _has_gpytorch: raise ImportError( "botorch and gpytorch are required to use GPWorldModel. " @@ -49,6 +60,27 @@ def __init__(self, obs_dim: int, action_dim: int) -> None: self.action_dim = action_dim self.input_dim = obs_dim + action_dim + self.in_keys = ( + in_keys + if in_keys is not None + else [ + ("action", "mean"), + ("action", "var"), + ("action", "cross_covariance"), + ("observation", "mean"), + ("observation", "var"), + ] + ) + + self.out_keys = ( + out_keys + if out_keys is not None + else [ + ("next", "observation", "mean"), + ("next", "observation", "var"), + ] + ) + self.model_list = None self.register_buffer("X_train", torch.empty(0)) @@ -166,60 +198,53 @@ def _gather_gp_params(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Returns the extracted hyperparameters of each per-dimension GP.""" return self.lengthscales, self.variances, self.noises - def forward( - self, action: TensorDictBase, observation: TensorDictBase - ) -> tuple[torch.Tensor, torch.Tensor]: + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Forward pass for the GPWorldModel. Routes the request to either the deterministic or uncertain forward pass depending on whether the observation input contains variance. Args: - action (TensorDictBase): The action tensordict. - observation (TensorDictBase): The observation tensordict. 
+ tensordict (TensorDictBase): The input tensordict containing the action and observation. Returns: tuple[torch.Tensor, torch.Tensor]: A tuple containing the mean and variance tensors of the next observation. """ + u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys + + x_var = tensordict.get(x_var_key, None) observation_uncertain = False - x_var = observation.get("var", None) if x_var is not None: observation_uncertain = not torch.all( torch.isclose(x_var, torch.zeros_like(x_var)) ) + if observation_uncertain: - return self.uncertain_forward(action, observation) + return self.uncertain_forward(tensordict) else: - return self.deterministic_forward(action, observation) - - def freeze_and_detach(self) -> None: - """Freezes the model parameters so they are not updated during policy optimisation.""" + return self.deterministic_forward(tensordict) - def uncertain_forward( - self, action: TensorDictBase, obs: TensorDictBase - ) -> tuple[torch.Tensor, torch.Tensor]: + def uncertain_forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Calculates the forward pass when the observation has uncertainty (non-zero variance). Propagates uncertainty through the Gaussian Process via exact moment matching. Args: - action (TensorDictBase): A tensordict containing ``"mean"``, ``"var"``, and - ``"cross_covariance"`` of the action. - obs (TensorDictBase): A tensordict containing the ``"mean"`` and ``"var"`` - of the current observation. + tensordict (TensorDictBase): A tensordict containing the action and observation tensors. Returns: - tuple[torch.Tensor, torch.Tensor]: Next observation mean and variance matrices. + TensorDictBase: Next observation mean and variance matrices. 
""" inv_K, beta = self.compute_factorizations() lengthscales, variances, noises = self._gather_gp_params() + u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys - m_x, s_x = obs.get("mean"), obs.get("var") + m_x, s_x = tensordict.get(x_mean_key), tensordict.get(x_var_key) m_u, s_u, c_xu = ( - action.get("mean"), - action.get("var"), - action.get("cross_covariance"), + tensordict.get(u_mean_key), + tensordict.get(u_var_key), + tensordict.get(u_cc_key), ) device, dtype = m_x.device, m_x.dtype @@ -342,23 +367,24 @@ def uncertain_forward( s_x = s_x + 1e-8 * torch.eye(self.obs_dim, device=s_x.device).expand( s_x.shape[0], -1, -1 ) - return m_x, s_x - def deterministic_forward( - self, action: TensorDictBase, observation: TensorDictBase - ) -> tuple[torch.Tensor, torch.Tensor]: + out_mean_key, out_var_key = self.out_keys + tensordict.set(out_mean_key, m_x) + tensordict.set(out_var_key, s_x) + return tensordict + + def deterministic_forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Calculates the forward pass when the input observation is deterministic (no variance). Args: - action (TensorDictBase): A tensordict containing the ``"mean"`` of the action. - observation (TensorDictBase): A tensordict containing the ``"mean"`` of the - current observation. + tensordict (TensorDictBase): A tensordict containing the action and observation tensors. Returns: - tuple[torch.Tensor, torch.Tensor]: Next observation mean and variance matrices. + TensorDictBase: Next observation mean and variance matrices. 
""" - observation_mean = observation.get("mean") - action_mean = action.get("mean") + u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys + observation_mean = tensordict.get(x_mean_key) + action_mean = tensordict.get(u_mean_key) x_flat = observation_mean.view(-1, self.obs_dim) u_flat = action_mean.view(-1, self.action_dim) @@ -380,4 +406,10 @@ def deterministic_forward( delta_mean = delta_mean_flat.view(*batch_shape, self.obs_dim) delta_std = delta_std_flat.view(*batch_shape, self.obs_dim) - return observation_mean + delta_mean, torch.diag_embed(delta_std**2) + m_x = observation_mean + delta_mean + s_x = torch.diag_embed(delta_std**2) + + out_mean_key, out_var_key = self.out_keys + tensordict.set(out_mean_key, m_x) + tensordict.set(out_var_key, s_x) + return tensordict From 671b265951a65bd582edf8233bda8bf13ba0efb3 Mon Sep 17 00:00:00 2001 From: Pedro Rosa Date: Mon, 16 Mar 2026 20:56:05 +0100 Subject: [PATCH 8/8] vetorize the uncertain forward and fix tests --- test/test_objectives.py | 127 ++++--- torchrl/modules/models/gp.py | 716 +++++++++++++++++++++++------------ 2 files changed, 542 insertions(+), 301 deletions(-) diff --git a/test/test_objectives.py b/test/test_objectives.py index 275fc31faf3..13c4d71eba8 100644 --- a/test/test_objectives.py +++ b/test/test_objectives.py @@ -18648,10 +18648,17 @@ def test_reset_with_observation(self): def test_step(self): obs_dim, action_dim = 4, 1 + next_observation_key = ( + "next_observation" # ("next", "observation") could also be a possibility + ) wm = self._make_dummy_world_model(obs_dim, action_dim) base_env = self._make_base_env(obs_dim, action_dim) - env = ImaginedEnv(world_model_module=wm, base_env=base_env) + env = ImaginedEnv( + world_model_module=wm, + base_env=base_env, + next_observation_key=next_observation_key, + ) td = TensorDict( { @@ -18784,7 +18791,7 @@ def test_creation(self): model = GPWorldModel(obs_dim=4, action_dim=1) assert model.obs_dim == 4 assert model.action_dim == 1 - assert 
model.input_dim == 5 + assert model.state_action_dim == 5 def test_fit_and_deterministic_forward(self): from torchrl.modules.models.gp import GPWorldModel @@ -18793,9 +18800,9 @@ def test_fit_and_deterministic_forward(self): model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim) n_samples = 20 - obs = torch.randn(n_samples, obs_dim) - action = torch.randn(n_samples, action_dim) - next_obs = obs + 0.1 * torch.randn(n_samples, obs_dim) + obs = torch.randn(n_samples, obs_dim).double() + action = torch.randn(n_samples, action_dim).double() + next_obs = obs + 0.1 * torch.randn(n_samples, obs_dim).double() dataset = TensorDict( { @@ -18807,15 +18814,20 @@ def test_fit_and_deterministic_forward(self): ) model.fit(dataset) - model.freeze_and_detach() + model.eval() - test_obs = TensorDict({"mean": torch.randn(3, obs_dim)}, batch_size=[3]) - test_action = TensorDict({"mean": torch.randn(3, action_dim)}, batch_size=[3]) + td = TensorDict( + { + ("observation", "mean"): torch.randn(3, obs_dim), + ("action", "mean"): torch.randn(3, action_dim), + }, + batch_size=[3], + ) - mean, var = model.deterministic_forward(test_action, test_obs) + forward_td = model.deterministic_forward(td) - assert mean.shape == (3, obs_dim) - assert var.shape == (3, obs_dim, obs_dim) + assert forward_td[("next", "observation", "mean")].shape == (3, obs_dim) + assert forward_td[("next", "observation", "var")].shape == (3, obs_dim, obs_dim) def test_uncertain_forward(self): from torchrl.modules.models.gp import GPWorldModel @@ -18839,35 +18851,38 @@ def test_uncertain_forward(self): model.double() model.fit(dataset) - model.freeze_and_detach() + model.eval() batch = 2 - test_obs = TensorDict( - { - "mean": torch.randn(batch, obs_dim, dtype=torch.float64), - "var": torch.eye(obs_dim, dtype=torch.float64) - .unsqueeze(0) - .expand(batch, -1, -1) - * 0.01, - }, - batch_size=[batch], - ) - test_action = TensorDict( + td = TensorDict( { - "mean": torch.randn(batch, action_dim, dtype=torch.float64), - 
"var": torch.eye(action_dim, dtype=torch.float64) - .unsqueeze(0) - .expand(batch, -1, -1) - * 0.01, - "cross_covariance": torch.zeros( - batch, obs_dim, action_dim, dtype=torch.float64 - ), + "observation": { + "mean": torch.randn(batch, obs_dim, dtype=torch.float64), + "var": torch.eye(obs_dim, dtype=torch.float64) + .unsqueeze(0) + .expand(batch, -1, -1) + * 0.01, + }, + "action": { + "mean": torch.randn(batch, action_dim, dtype=torch.float64), + "var": torch.eye(action_dim, dtype=torch.float64) + .unsqueeze(0) + .expand(batch, -1, -1) + * 0.01, + "cross_covariance": torch.zeros( + batch, obs_dim, action_dim, dtype=torch.float64 + ), + }, }, batch_size=[batch], ) - mean, var = model.uncertain_forward(test_action, test_obs) + forward_td = model.uncertain_forward(td) + mean, var = ( + forward_td[("next", "observation", "mean")], + forward_td[("next", "observation", "var")], + ) assert mean.shape == (batch, obs_dim) assert var.shape == (batch, obs_dim, obs_dim) @@ -18880,9 +18895,9 @@ def test_forward_dispatch(self): model = GPWorldModel(obs_dim=obs_dim, action_dim=action_dim) n_samples = 20 - obs = torch.randn(n_samples, obs_dim) - action = torch.randn(n_samples, action_dim) - next_obs = obs + 0.1 * torch.randn(n_samples, obs_dim) + obs = torch.randn(n_samples, obs_dim).double() + action = torch.randn(n_samples, action_dim).double() + next_obs = obs + 0.1 * torch.randn(n_samples, obs_dim).double() dataset = TensorDict( { @@ -18894,29 +18909,33 @@ def test_forward_dispatch(self): ) model.fit(dataset) - model.freeze_and_detach() - - det_obs = TensorDict({"mean": torch.randn(2, obs_dim)}, batch_size=[2]) - det_action = TensorDict({"mean": torch.randn(2, action_dim)}, batch_size=[2]) - mean, var = model(det_action, det_obs) - assert mean.shape == (2, obs_dim) + model.eval() - unc_obs = TensorDict( - { - "mean": torch.randn(2, obs_dim), - "var": torch.eye(obs_dim).unsqueeze(0).expand(2, -1, -1) * 0.1, - }, - batch_size=[2], - ) - unc_action = TensorDict( + batch = 2 + 
td = TensorDict( { - "mean": torch.randn(2, action_dim), - "var": torch.eye(action_dim).unsqueeze(0).expand(2, -1, -1) * 0.01, - "cross_covariance": torch.zeros(2, obs_dim, action_dim), + "observation": { + "mean": torch.randn(batch, obs_dim, dtype=torch.float64), + "var": torch.eye(obs_dim, dtype=torch.float64) + .unsqueeze(0) + .expand(batch, -1, -1) + * 0.01, + }, + "action": { + "mean": torch.randn(batch, action_dim, dtype=torch.float64), + "var": torch.eye(action_dim, dtype=torch.float64) + .unsqueeze(0) + .expand(batch, -1, -1) + * 0.01, + "cross_covariance": torch.zeros( + batch, obs_dim, action_dim, dtype=torch.float64 + ), + }, }, - batch_size=[2], + batch_size=[batch], ) - mean, var = model(unc_action, unc_obs) + forward_td = model(td) + mean = forward_td[("next", "observation", "mean")] assert mean.shape == (2, obs_dim) diff --git a/torchrl/modules/models/gp.py b/torchrl/modules/models/gp.py index c588e780527..18ed35f37d4 100644 --- a/torchrl/modules/models/gp.py +++ b/torchrl/modules/models/gp.py @@ -2,6 +2,22 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +# +# Variable naming follows Deisenroth & Rasmussen (2011), "PILCO: A Model-Based +# and Data-Efficient Approach to Policy Search" (cited inline as "Eq. N"). +# +# Key symbols +# ----------- +# x̃ := [x, u] concatenated state-action input (Eq. 1 / Sec. 2.1) +# Δ := x_t - x_{t-1} transition residual (Sec. 2.1) +# K_a Gram matrix K_{a,ij}=k_a(x̃_i,x̃_j) (Eq. 6) +# β_a := (K_a + σ²_ε I)^{-1}y_a GP weight vector (Eq. 7) +# q_a kernel-mean vector (Eq. 15) +# Q_{ab} cross-kernel matrix (Eqs. 21-22) +# μ̃ / Σ̃ joint state-action mean/cov (Sec. 2.2) +# μ_Δ / Σ_Δ predictive mean/cov of Δ (Eqs. 14, 17-23) +# μ_t / Σ_t next-state mean/cov (Eqs. 10-11) + import importlib.util import torch @@ -13,34 +29,57 @@ class GPWorldModel(nn.Module): - """Gaussian Process World Model for moment-matching model-based RL. 
+ """Gaussian Process world model with moment-matching uncertainty propagation. - Fits one independent single-task GP per observation dimension using - BoTorch/GPyTorch. Each GP models the *transition residual* - ``delta_i = next_obs_i - obs_i`` given the concatenated ``[obs, action]`` - input. After fitting, the model supports two forward modes: + Implements the probabilistic dynamics model from PILCO + (Deisenroth & Rasmussen, 2011). One independent GP is fit per state + dimension, each predicting the transition residual + ``Δ = x_t - x_{t-1}`` from the concatenated state-action input + ``x̃ = [x, u]`` (Sec. 2.1). - * **Deterministic**: point predictions via the GP posterior mean/variance. - * **Uncertain** (moment-matching): propagates Gaussian beliefs - ``N(m, S)`` through the GP analytically, yielding the next-state - belief ``N(m', S')``. This is the core computation in PILCO - (Deisenroth & Rasmussen, 2011). + :meth:`forward` supports two modes depending on whether the input + observation carries non-zero variance: - Requires ``botorch`` and ``gpytorch`` as optional dependencies. + - **Deterministic**: uses the GP posterior mean and variance directly + (Eqs. 7-8). + - **Uncertain** (moment-matching): propagates a Gaussian belief + ``N(μ, Σ)`` through the GP analytically (Eqs. 10-23). + + .. note:: + Requires ``botorch`` and ``gpytorch`` as optional dependencies. Args: - obs_dim (int): The dimension of the observation space. - action_dim (int): The dimension of the action space. - in_keys (list[str | tuple[str, ...]] | None, optional): The keys to read from the - input TensorDict. Defaults to ["action", "observation"]. - out_keys (list[str | tuple[str, ...]] | None, optional): The keys to write the - predicted mean and variance to in the output TensorDict. - Defaults to [("next", "observation"), ("next", "observation_var")]. + obs_dim (int): Dimension D of the observation (state) space. + action_dim (int): Dimension F of the action (control) space. 
+ in_keys (list of NestedKey, optional): Keys to read from the input + :class:`~tensordict.TensorDictBase`. Must contain five entries in + order: action mean, action covariance, state-action + cross-covariance, observation mean, observation covariance. + Defaults to ``[("action", "mean"), ("action", "var"), + ("action", "cross_covariance"), ("observation", "mean"), + ("observation", "var")]``. + out_keys (list of NestedKey, optional): Keys to write the predicted + next-state mean and covariance to. Defaults to + ``[("next", "observation", "mean"), + ("next", "observation", "var")]``. Examples: >>> import torch >>> from tensordict import TensorDict - >>> model = GPWorldModel(obs_dim=4, action_dim=1) # doctest: +SKIP + >>> model = GPWorldModel(obs_dim=4, action_dim=1) + >>> dataset = TensorDict( + ... { + ... "observation": torch.randn(50, 4), + ... "action": torch.randn(50, 1), + ... ("next", "observation"): torch.randn(50, 4), + ... }, + ... batch_size=[50], + ... ) + >>> model.fit(dataset) + + Reference: + Deisenroth, M. P. & Rasmussen, C. E. (2011). PILCO: A model-based + and data-efficient approach to policy search. *ICML*. """ def __init__( @@ -56,9 +95,9 @@ def __init__( "Please install them to proceed." ) super().__init__() - self.obs_dim = obs_dim - self.action_dim = action_dim - self.input_dim = obs_dim + action_dim + self.obs_dim = obs_dim # D in the paper + self.action_dim = action_dim # F in the paper + self.state_action_dim = obs_dim + action_dim # D+F, dimension of x̃ (Sec. 
2.1) self.in_keys = ( in_keys @@ -83,26 +122,51 @@ def __init__( self.model_list = None - self.register_buffer("X_train", torch.empty(0)) - self.register_buffer("lengthscales", torch.zeros(self.obs_dim, self.input_dim)) - self.register_buffer("variances", torch.zeros(self.obs_dim, 1)) - self.register_buffer("noises", torch.zeros(self.obs_dim)) - self._cached_inv_K: torch.Tensor | None = None - self._cached_beta: torch.Tensor | None = None + # X̃ = [x̃_1, ..., x̃_n] ∈ R^{n×(D+F)} – training inputs (Sec. 2.1) + self.register_buffer("X_tilde_train", torch.empty(0)) + + # ℓ_a ∈ R^{D+F} – ARD length-scales for each output dimension a (Eq. 6). + # Stored as [D, D+F]; the full matrix Λ_a = diag(ℓ_a²) is never + # materialised — ℓ_a is squared on the fly wherever needed. + # Note: GPyTorch's .lengthscale returns ℓ directly (not ℓ²). + self.register_buffer("ell", torch.zeros(self.obs_dim, self.state_action_dim)) + + # α²_a – signal variance for each output dimension a (Eq. 6); shape [D, 1] + self.register_buffer("alpha_sq", torch.zeros(self.obs_dim, 1)) + + # σ²_{ε_a} – noise variance for each output dimension a (Sec. 2.1); shape [D] + self.register_buffer("sigma_sq_eps", torch.zeros(self.obs_dim)) + + # (K_a + σ²_{ε_a} I)^{-1} – cached inverse Gram matrices (Eq. 7); shape [D, n, n]. + # Registered as buffers so they survive .to(device) and state_dict round-trips. + self.register_buffer("_cached_inv_K_noisy", None) + + # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a – GP weight vectors (Eq. 7); shape [D, n]. + # Registered as a buffer so it survives .to(device) and state_dict round-trips. + self.register_buffer("_cached_beta", None) @property def device(self) -> torch.device: - return self.lengthscales.device + return self.ell.device def fit(self, dataset: TensorDictBase) -> None: - """Fits the Gaussian Process model to the provided dataset. + """Fit one GP per state dimension to a dataset of transitions. 
- The dataset must contain the ``"observation"``, ``"action"``, and - ``("next", "observation")`` keys. The model predicts the difference - between the next observation and the current observation. + Constructs training inputs ``X̃ = [x, u]`` and targets + ``Δ_a = x_{t,a} - x_{t-1,a}``, then maximises the marginal + log-likelihood to learn SE kernel hyper-parameters + (ℓ_a, α²_a, σ²_{ε_a}) for each output dimension (Sec. 2.1, Eq. 6). + + .. note:: + The dataset is expected to be flat with shape ``[n, *]``. If your + replay buffer returns multi-dimensional batches (e.g. ``[B, T, *]``), + call ``dataset.reshape(-1)`` before passing it here. Args: - dataset (TensorDictBase): A dataset of collected transitions. + dataset (TensorDictBase): Transition dataset with keys + ``"observation"`` of shape ``(n, D)``, + ``"action"`` of shape ``(n, F)``, and + ``("next", "observation")`` of shape ``(n, D)``. """ from botorch.fit import fit_gpytorch_mll from botorch.models import ModelListGP, SingleTaskGP @@ -110,306 +174,464 @@ def fit(self, dataset: TensorDictBase) -> None: from gpytorch.mlls import SumMarginalLogLikelihood from gpytorch.priors import GammaPrior - obs = dataset["observation"] - action = dataset["action"] - next_obs = dataset[("next", "observation")] + x_t_minus_1 = dataset["observation"] # x_{t-1} ∈ R^{n×D} + u_t_minus_1 = dataset["action"] # u_{t-1} ∈ R^{n×F} + x_t = dataset[("next", "observation")] # x_t ∈ R^{n×D} + + # x̃ = [x_{t-1}, u_{t-1}] ∈ R^{n×(D+F)} – training inputs (Sec. 2.1) + X_tilde_train = ( + torch.cat([x_t_minus_1, u_t_minus_1], dim=-1).detach().to(self.device) + ) + + # Δ ∈ R^{n×D}, Δ_{i,a} = x_{t,a} - x_{t-1,a} – training targets (Sec. 
2.1) + Delta_train = (x_t - x_t_minus_1).detach().to(self.device) - X_train = torch.cat([obs, action], dim=-1).detach().to(self.device) - y_train = (next_obs - obs).detach().to(self.device) - self.X_train = X_train + self.X_tilde_train = X_tilde_train models = [] - for i in range(self.obs_dim): - train_x = X_train - train_y = y_train[:, i].unsqueeze(-1) + for a in range(self.obs_dim): + # Each GP_a models p(Δ_a | x̃) independently (Sec. 2.1) + Delta_a = Delta_train[:, a].unsqueeze(-1) # y_a ∈ R^{n×1} covar_module = ScaleKernel( + # SE kernel k_a(x̃, x̃') with ARD length-scales (one ℓ_{a,i} + # per input dimension, Eq. 6) RBFKernel( - ard_num_dims=self.input_dim, lengthscale_prior=GammaPrior(1.1, 0.1) + ard_num_dims=self.state_action_dim, + lengthscale_prior=GammaPrior(1.1, 0.1), ), - outputscale_prior=GammaPrior(1.5, 0.5), + outputscale_prior=GammaPrior(1.5, 0.5), # prior on α²_a (Eq. 6) ) - gp = SingleTaskGP( - train_X=train_x, train_Y=train_y, covar_module=covar_module + gp_a = SingleTaskGP( + train_X=X_tilde_train, + train_Y=Delta_a, + covar_module=covar_module, ) - gp.likelihood.noise_covar.register_prior( - "noise_prior", GammaPrior(1.2, 0.05), "noise" + gp_a.likelihood.noise_covar.register_prior( + "noise_prior", + GammaPrior(1.2, 0.05), + "noise", # prior on σ²_{ε_a} (Sec. 2.1) ) - models.append(gp) + models.append(gp_a) self.model_list = ModelListGP(*models).to(self.device) mll = SumMarginalLogLikelihood(self.model_list.likelihood, self.model_list) - fit_gpytorch_mll(mll) - self._extract_parameters(y_train) + fit_gpytorch_mll(mll) # evidence maximisation (Sec. 
2.1) + self._extract_and_cache_parameters(Delta_train) - def _extract_parameters(self, y_train: torch.Tensor) -> None: - lengthscales, variances, noises, inv_Ks, betas = [], [], [], [], [] + def _extract_and_cache_parameters(self, Delta_train: torch.Tensor) -> None: + # Extract learned hyper-parameters from each GP_a and pre-compute the + # quantities that are fixed after fitting: + # ℓ_a, α²_a, σ²_{ε_a} (Eq. 6 / Sec. 2.1) + # (K_a + σ²_{ε_a} I)^{-1} (Eq. 7) + # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a (Eq. 7) + ell_list, alpha_sq_list, sigma_sq_eps_list = [], [], [] + inv_K_noisy_list, beta_list = [], [] - for i, gp in enumerate(self.model_list.models): - gp.eval() - gp.likelihood.eval() + n = self.X_tilde_train.shape[0] # number of training points - ls = gp.covar_module.base_kernel.lengthscale.squeeze().detach() - var = gp.covar_module.outputscale.detach() - noise = gp.likelihood.noise.squeeze().detach() + for a, gp_a in enumerate(self.model_list.models): + gp_a.eval() + gp_a.likelihood.eval() - lengthscales.append(ls) - variances.append(var) - noises.append(noise) + # ℓ_a ∈ R^{D+F} – ARD length-scales for GP_a (Eq. 6). + # GPyTorch's .lengthscale returns ℓ directly (not ℓ²). + ell_a = gp_a.covar_module.base_kernel.lengthscale.squeeze().detach() - X_scaled = self.X_train / ls - dist = torch.cdist(X_scaled, X_scaled, p=2) ** 2 - K = var * torch.exp(-0.5 * dist) + # α²_a – signal variance for GP_a (Eq. 6) + alpha_sq_a = gp_a.covar_module.outputscale.detach() - K_noisy = K + (noise + 1e-6) * torch.eye( - self.X_train.size(0), device=self.device - ) + # σ²_{ε_a} – noise variance for GP_a (Sec. 2.1) + sigma_sq_eps_a = gp_a.likelihood.noise.squeeze().detach() + + ell_list.append(ell_a) + alpha_sq_list.append(alpha_sq_a) + sigma_sq_eps_list.append(sigma_sq_eps_a) + + # K_{a,ij} = α²_a exp(-½ (x̃_i-x̃_j)^T Λ_a^{-1} (x̃_i-x̃_j)) (Eq. 6) + # Dividing X̃ by ℓ_a gives Λ_a^{-1/2}-scaled inputs for cdist. 
+ X_tilde_scaled = self.X_tilde_train / ell_a + sq_dist = torch.cdist(X_tilde_scaled, X_tilde_scaled, p=2) ** 2 + K_a = alpha_sq_a * torch.exp(-0.5 * sq_dist) - L = torch.linalg.cholesky(K_noisy) - eye = torch.eye(L.size(0), dtype=L.dtype, device=L.device) - inv_K = torch.cholesky_solve(eye, L) + # K_{a,noisy} = K_a + σ²_{ε_a} I (denominator in Eq. 7) + K_a_noisy = K_a + (sigma_sq_eps_a + 1e-6) * torch.eye(n, device=self.device) - y = y_train[:, i].unsqueeze(-1) - beta = torch.cholesky_solve(y, L).squeeze(-1) + L_a = torch.linalg.cholesky(K_a_noisy) + eye_n = torch.eye(n, dtype=L_a.dtype, device=L_a.device) - inv_Ks.append(inv_K) - betas.append(beta) + # (K_a + σ²_{ε_a} I)^{-1} (Eq. 7) + inv_K_a_noisy = torch.cholesky_solve(eye_n, L_a) - self.lengthscales = torch.stack(lengthscales) - self.variances = torch.stack(variances).unsqueeze(-1) - self.noises = torch.stack(noises) + # y_a = [Δ_{1,a}, ..., Δ_{n,a}]^T – targets for GP_a (Sec. 2.1) + y_a = Delta_train[:, a].unsqueeze(-1) - self._cached_inv_K = torch.stack(inv_Ks) - self._cached_beta = torch.stack(betas) + # β_a = (K_a + σ²_{ε_a} I)^{-1} y_a (Eq. 7) + beta_a = torch.cholesky_solve(y_a, L_a).squeeze(-1) + + inv_K_noisy_list.append(inv_K_a_noisy) + beta_list.append(beta_a) + + self.ell = torch.stack(ell_list) # [D, D+F] + self.alpha_sq = torch.stack(alpha_sq_list).unsqueeze(-1) # [D, 1] + self.sigma_sq_eps = torch.stack(sigma_sq_eps_list) # [D] + self._cached_inv_K_noisy = torch.stack(inv_K_noisy_list) # [D, n, n] + self._cached_beta = torch.stack(beta_list) # [D, n] def compute_factorizations(self) -> tuple[torch.Tensor, torch.Tensor]: - """Returns the cached kernel inverse and weight vectors. + """Return the cached kernel inverses and GP weight vectors. Returns: - inv_K (Tensor): Inverse kernel matrices, shape ``(obs_dim, N, N)``. - beta (Tensor): Weight vectors ``K^{-1} y``, shape ``(obs_dim, N)``. 
+ tuple[Tensor, Tensor]: A pair ``(inv_K_noisy, beta)`` where + ``inv_K_noisy`` has shape ``(D, n, n)`` and contains + ``(K_a + σ²_{ε_a} I)^{-1}`` for each output dimension (Eq. 7), + and ``beta`` has shape ``(D, n)`` and contains + ``β_a = (K_a + σ²_{ε_a} I)^{-1} y_a`` (Eq. 7). """ - return self._cached_inv_K, self._cached_beta + return self._cached_inv_K_noisy, self._cached_beta - def _gather_gp_params(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Returns the extracted hyperparameters of each per-dimension GP.""" - return self.lengthscales, self.variances, self.noises + def _gather_gp_hyperparams(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Returns (ell, alpha_sq, sigma_sq_eps) — the SE kernel hyper-parameters + # for each GP_a (Eq. 6 / Sec. 2.1): + # ell: ℓ_{a,i}, shape [D, D+F] (ℓ, not ℓ²) + # alpha_sq: α²_a, shape [D, 1] + # sigma_sq_eps: σ²_{ε_a}, shape [D] + return self.ell, self.alpha_sq, self.sigma_sq_eps def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - """Forward pass for the GPWorldModel. + """Predict the next-state distribution given the current state and action. - Routes the request to either the deterministic or uncertain forward pass - depending on whether the observation input contains variance. + Routes to :meth:`uncertain_forward` (moment-matching, Eqs. 10-23) when + the input observation covariance is non-zero, and to + :meth:`deterministic_forward` (Eqs. 7-8) otherwise. Args: - tensordict (TensorDictBase): The input tensordict containing the action and observation. + tensordict (TensorDictBase): Input tensordict containing keys + defined by ``in_keys``. Observation and action tensors may be + unbatched ``(D,)`` / ``(F,)`` or batched ``(B, D)`` / + ``(B, F)``; a leading batch dimension will be added and removed + automatically for unbatched inputs. 
The observation covariance, + when present, must be a full matrix of shape ``(..., D, D)`` + — per-dimension variance vectors are not accepted; use + :func:`torch.diag_embed` to convert them first. Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing the mean and - variance tensors of the next observation. + TensorDictBase: The same tensordict, updated in-place with the + predicted next-state mean and covariance written to ``out_keys``. """ u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys - x_var = tensordict.get(x_var_key, None) - observation_uncertain = False - if x_var is not None: - observation_uncertain = not torch.all( - torch.isclose(x_var, torch.zeros_like(x_var)) + Sigma_x = tensordict.get(x_var_key, None) + if Sigma_x is not None and Sigma_x.dim() < 2: + raise ValueError( + f"Expected observation covariance to have at least 2 dimensions " + f"(..., D, D), got shape {tuple(Sigma_x.shape)}. " + "Convert per-dimension variances with torch.diag_embed() first." ) + observation_uncertain = Sigma_x is not None and not torch.all( + torch.isclose(Sigma_x, torch.zeros_like(Sigma_x)) + ) + if observation_uncertain: return self.uncertain_forward(tensordict) else: return self.deterministic_forward(tensordict) def uncertain_forward(self, tensordict: TensorDictBase) -> TensorDictBase: - """Calculates the forward pass when the observation has uncertainty (non-zero variance). + """Moment-matching forward pass for a Gaussian input belief (Eqs. 10-23). - Propagates uncertainty through the Gaussian Process via exact moment matching. + Propagates the joint Gaussian belief + ``p(x̃_{t-1}) = N(μ̃_{t-1}, Σ̃_{t-1})`` (Sec. 2.2) through the GP + dynamics model and returns a Gaussian approximation to ``p(x_t)`` + via exact moment matching. Args: - tensordict (TensorDictBase): A tensordict containing the action and observation tensors. + tensordict (TensorDictBase): Input tensordict with keys defined by + ``in_keys``. 
Supports unbatched ``(D,)`` inputs or batched + inputs with a single leading batch dimension ``(B, D)``. Returns: - TensorDictBase: Next observation mean and variance matrices. + TensorDictBase: The same tensordict updated with next-state mean + ``μ_t`` (Eq. 10) and covariance ``Σ_t`` (Eq. 11) at ``out_keys``. """ - inv_K, beta = self.compute_factorizations() - lengthscales, variances, noises = self._gather_gp_params() + inv_K_noisy, beta = self.compute_factorizations() + ell, alpha_sq, sigma_sq_eps = self._gather_gp_hyperparams() u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys - m_x, s_x = tensordict.get(x_mean_key), tensordict.get(x_var_key) - m_u, s_u, c_xu = ( - tensordict.get(u_mean_key), - tensordict.get(u_var_key), - tensordict.get(u_cc_key), - ) - - device, dtype = m_x.device, m_x.dtype - - joint_mean = torch.cat([m_x, m_u], dim=-1) - - s_ = s_x @ c_xu - upper = torch.cat([s_x, s_], dim=-1) - lower = torch.cat([s_.transpose(-1, -2), s_u], dim=-1) - - joint_var = torch.cat([upper, lower], dim=-2) - - X_train = self.X_train - num_train_pts = X_train.shape[0] - batch_size = joint_mean.shape[0] - - inp = X_train - joint_mean.unsqueeze(1) - - inv_L = torch.diag_embed(1.0 / lengthscales).to(dtype=dtype, device=device) - inv_N = inp.unsqueeze(1) @ inv_L.unsqueeze(0) - - B_mat = inv_L.unsqueeze(0) @ joint_var.unsqueeze(1) @ inv_L.unsqueeze(0) - B_mat = B_mat + torch.eye( - self.input_dim, dtype=m_x.dtype, device=m_x.device - ).reshape(1, 1, self.input_dim, self.input_dim) - - t = torch.linalg.solve(B_mat, inv_N.transpose(-2, -1)).transpose(-2, -1) - - scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2) - lb = scaled_exp * beta.unsqueeze(0) - - _, log_det_B = torch.linalg.slogdet(B_mat) - c = variances.squeeze(1).unsqueeze(0) * torch.exp(-0.5 * log_det_B) - - pred_mean = torch.sum(lb, dim=-1) * c.squeeze(0) - - t_inv_L = t @ inv_L.unsqueeze(0) - - cross_cov_E_D = torch.matmul( - t_inv_L.transpose(-2, -1), lb.unsqueeze(-1) - ).squeeze(-1) 
* c.unsqueeze(-1) - cross_cov = cross_cov_E_D.transpose(-2, -1) + mu_x = tensordict.get(x_mean_key) # μ_x, shape (B×)D + Sigma_x = tensordict.get(x_var_key) # Σ_x, shape (B×)D×D + mu_u = tensordict.get(u_mean_key) # μ_u, shape (B×)F + Sigma_u = tensordict.get(u_var_key) # Σ_u, shape (B×)F×F + C_xu = tensordict.get(u_cc_key) # cov[x_{t-1}, u_{t-1}], (B×)D×F (Eq. 12) + + # Support unbatched inputs by temporarily adding a leading batch dimension. + unbatched = mu_x.dim() == 1 + if unbatched: + mu_x, Sigma_x, mu_u, Sigma_u, C_xu = ( + mu_x.unsqueeze(0), + Sigma_x.unsqueeze(0), + mu_u.unsqueeze(0), + Sigma_u.unsqueeze(0), + C_xu.unsqueeze(0), + ) - pred_cov = torch.zeros( - batch_size, self.obs_dim, self.obs_dim, dtype=m_x.dtype, device=m_x.device + device, dtype = mu_x.device, mu_x.dtype + B = mu_x.shape[0] # batch size + n = self.X_tilde_train.shape[0] # number of training points + D = self.obs_dim # state dimension + DF = self.state_action_dim # D+F, dimension of x̃ + + # ---- Build joint state-action distribution p(x̃_{t-1}) (Sec. 2.2) ---- + # μ̃_{t-1} = [μ_x; μ_u] ∈ R^{B×(D+F)} + mu_tilde = torch.cat([mu_x, mu_u], dim=-1) + + # Σ̃_{t-1} = [[Σ_x, Σ_x C_xu ], + # [C_xu^T Σ_x^T, Σ_u ]] ∈ R^{B×(D+F)×(D+F)} + Sigma_x_C_xu = Sigma_x @ C_xu # upper-right block [B, D, F] + Sigma_tilde = torch.cat( + [ + torch.cat([Sigma_x, Sigma_x_C_xu], dim=-1), + torch.cat([Sigma_x_C_xu.transpose(-1, -2), Sigma_u], dim=-1), + ], + dim=-2, + ) # [B, D+F, D+F] + + # ---- Compute q_a (mean-prediction kernel vector, Eq. 15) ---- + # ν_i = x̃_i - μ̃_{t-1} (Eq. 16); shape [B, n, D+F] + nu = self.X_tilde_train - mu_tilde.unsqueeze(1) + + # Λ_a^{-1} as diagonal matrices; shape [D, D+F, D+F]. + # ell stores ℓ_a (not ℓ²_a), so 1/ℓ_a gives the diagonal of Λ_a^{-1/2}; + # used here to form the full Λ_a^{-1} = diag(1/ℓ²_a) = diag(1/ℓ_a)². 
+ inv_Lambda_diag_mats = torch.diag_embed(1.0 / ell).to( + device=device, dtype=dtype ) - X_i = X_train.unsqueeze(1) - X_j = X_train.unsqueeze(0) - diff = X_i - X_j - joint_mean_flat = joint_mean.unsqueeze(1).unsqueeze(1) - - for a in range(self.obs_dim): - for b in range(self.obs_dim): - l2_a = lengthscales[a].to(device=device, dtype=dtype) ** 2 - l2_b = lengthscales[b].to(device=device, dtype=dtype) ** 2 - - inv_L_a = 1.0 / l2_a - inv_L_b = 1.0 / l2_b - inv_L_sum = inv_L_a + inv_L_b - Lambda_ab = 1.0 / inv_L_sum - - z_bar = Lambda_ab * (X_i * inv_L_a + X_j * inv_L_b) - z = z_bar.unsqueeze(0) - joint_mean_flat - - z_flat = z.view( - batch_size, num_train_pts * num_train_pts, self.input_dim - ) - - R_ab = joint_var @ torch.diag(inv_L_sum) + torch.eye( - self.input_dim, dtype=m_x.dtype, device=m_x.device - ).unsqueeze(0) - - inv_L_plus = 1.0 / (l2_a + l2_b) - exp1 = -0.5 * torch.sum(diff * inv_L_plus * diff, dim=-1) - - M_ab = joint_var + torch.diag(Lambda_ab).unsqueeze(0) - - solved_z_flat = torch.linalg.solve( - M_ab, z_flat.transpose(-2, -1) - ).transpose(-2, -1) - exp2 = (-0.5 * torch.sum(z_flat * solved_z_flat, dim=-1)).view( - batch_size, num_train_pts, num_train_pts - ) - - _, log_det_R_ab = torch.linalg.slogdet(R_ab) - c_ab = variances[a] * variances[b] * torch.exp(-0.5 * log_det_R_ab) + # Λ_a^{-1} ν_i; shape [B, D, n, D+F] + inv_Lambda_nu = nu.unsqueeze(1) @ inv_Lambda_diag_mats.unsqueeze(0) - Q_ab = c_ab.view(-1, 1, 1) * torch.exp(exp1.unsqueeze(0) + exp2) - - Qb = torch.matmul(Q_ab, beta[b]) - pred_cov[:, a, b] = ( - torch.matmul(beta[a].unsqueeze(0), Qb.unsqueeze(-1)) - .squeeze(-1) - .squeeze(-1) - ) - - if a == b: - invK_Q = torch.matmul(inv_K[a].unsqueeze(0), Q_ab) - trace_val = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1) - - pred_cov[:, a, a] += variances[a] - trace_val + noises[a] - - outer_mean = torch.bmm(pred_mean.unsqueeze(-1), pred_mean.unsqueeze(-2)) - pred_cov = pred_cov - outer_mean - - pred_cov = (pred_cov + pred_cov.transpose(-2, -1)) 
/ 2.0 - - m_dx = pred_mean - s_dx = pred_cov - c_xdx = cross_cov + # R_a = Λ_a^{-1} Σ̃_{t-1} Λ_a^{-1} + I – normalising matrix in Eq. 15; + # shape [B, D, D+F, D+F] + R_a = ( + inv_Lambda_diag_mats.unsqueeze(0) + @ Sigma_tilde.unsqueeze(1) + @ inv_Lambda_diag_mats.unsqueeze(0) + ) + R_a = R_a + torch.eye(DF, device=device, dtype=dtype).view(1, 1, DF, DF) + + # Solve R_a t = (Λ_a^{-1} ν_i)^T → t = R_a^{-1} Λ_a^{-1} ν_i^T + t = torch.linalg.solve(R_a, inv_Lambda_nu.transpose(-2, -1)).transpose(-2, -1) + + # exp(-½ ν_i^T (Σ̃ + Λ_a)^{-1} ν_i) – exponent in Eq. 15; shape [B, D, n] + scaled_exp = torch.exp(-0.5 * torch.sum(inv_Lambda_nu * t, dim=-1)) + + # Scalar prefactor α²_a / sqrt(|Σ̃_{t-1} Λ_a^{-1} + I|) from Eq. 15; shape [B, D] + det_R_a = torch.linalg.det(R_a) + c_a = alpha_sq.squeeze(-1).unsqueeze(0) / torch.sqrt(det_R_a) + + # β_a ⊙ q_a (pointwise); shape [B, D, n] + beta_q_a = scaled_exp * beta.unsqueeze(0) + + # μ^a_Δ = β_a^T q_a (Eq. 14); shape [B, D] + mu_Delta = torch.sum(beta_q_a, dim=-1) * c_a.squeeze(0) + + # ---- Cross-covariance cov[x̃_{t-1}, Δ_t] (used in Eq. 12) ---- + # Derivative of μ_Δ w.r.t. μ̃, contracted with Σ̃ (Deisenroth 2010); + # shape [B, D+F, D] + t_inv_Lambda = t @ inv_Lambda_diag_mats.unsqueeze(0) + cov_xtilde_Delta = ( + torch.matmul( + t_inv_Lambda.transpose(-2, -1), beta_q_a.unsqueeze(-1) + ).squeeze(-1) + * c_a.unsqueeze(-1) + ).transpose(-2, -1) + + # ---- Compute Q_{ab} (cross-kernel matrix, Eqs. 21-22) ---- + X_i = self.X_tilde_train.unsqueeze(1) # [n, 1, D+F] + X_j = self.X_tilde_train.unsqueeze(0) # [1, n, D+F] + diff_ij = X_i - X_j # x̃_i - x̃_j; [n, n, D+F] (Eq. 22) + + # ell stores ℓ_a; ℓ²_a is the diagonal of Λ_a (Eq. 
6) + ell_sq_a = (ell**2)[:, None, :] # [D, 1, D+F] + ell_sq_b = (ell**2)[None, :, :] # [1, D, D+F] + + # Λ_{ab} = (Λ_a^{-1} + Λ_b^{-1})^{-1}, diagonal entries; [D, D, D+F] + inv_ell_sq_sum = 1.0 / ell_sq_a + 1.0 / ell_sq_b + Lambda_ab = 1.0 / inv_ell_sq_sum + + # First exponential in Q_{ab,ij}: kernel product at training inputs (Eq. 22) + # -½ (x̃_i - x̃_j)^T (Λ_a + Λ_b)^{-1} (x̃_i - x̃_j); shape [D, D, n, n] + inv_ell_sq_sum_ab = 1.0 / (ell_sq_a + ell_sq_b) + exp1 = -0.5 * torch.sum( + diff_ij.unsqueeze(0).unsqueeze(0) + * inv_ell_sq_sum_ab.unsqueeze(2).unsqueeze(2) + * diff_ij.unsqueeze(0).unsqueeze(0), + dim=-1, + ) # [D, D, n, n] + + # z̄_{ij} = Λ_{ab} (Λ_a^{-1} x̃_i + Λ_b^{-1} x̃_j) – midpoint (Eq. 22); + # shape [D, D, n, n, D+F] + z_bar = Lambda_ab.unsqueeze(2).unsqueeze(2) * ( + X_i.unsqueeze(0).unsqueeze(0) / ell_sq_a.unsqueeze(2).unsqueeze(2) + + X_j.unsqueeze(0).unsqueeze(0) / ell_sq_b.unsqueeze(2).unsqueeze(2) + ) - cov_xf = upper @ c_xdx + # z_{ij} = z̄_{ij} - μ̃_{t-1}; shape [B, D, D, n, n, D+F] + z_bar = z_bar.unsqueeze(0).expand(B, -1, -1, -1, -1, -1) + z_ij = z_bar - mu_tilde[:, None, None, None, None, :] + z_ij_flat = z_ij.view(B, D, D, n * n, DF) - m_x = m_x + m_dx + # M_{ab} = Σ̃_{t-1} + diag(Λ_{ab}) – matrix in second exp of Eq. 22; + # shape [B, D, D, D+F, D+F] + M_ab = Sigma_tilde[:, None, None] + torch.diag_embed(Lambda_ab) - s_x = s_x + s_dx + cov_xf + cov_xf.transpose(-2, -1) + # Second exponential: -½ z_{ij}^T M_{ab}^{-1} z_{ij}; shape [B, D, D, n, n] + M_ab_solved = torch.linalg.solve(M_ab, z_ij_flat.transpose(-2, -1)).transpose( + -2, -1 + ) + exp2 = (-0.5 * torch.sum(z_ij_flat * M_ab_solved, dim=-1)).view(B, D, D, n, n) + + # R_{ab} = Σ̃_{t-1} (Λ_a^{-1} + Λ_b^{-1}) + I – normalising matrix (Eq. 
22); + # shape [B, D, D, D+F, D+F] + R_ab = Sigma_tilde[:, None, None] @ torch.diag_embed( + inv_ell_sq_sum + ) + torch.eye(DF, device=device, dtype=dtype) + det_R_ab = torch.linalg.det(R_ab) # [B, D, D] + + # Scalar prefactor α²_a α²_b / sqrt(|R_{ab}|) (Eq. 22); shape [B, D, D] + c_ab = (alpha_sq.view(1, D, 1) * alpha_sq.view(1, 1, D)) / torch.sqrt(det_R_ab) + + # Q_{ab,ij} (Eq. 22); shape [B, D, D, n, n] + Q_ab = c_ab.unsqueeze(-1).unsqueeze(-1) * torch.exp(exp1.unsqueeze(0) + exp2) + + # ---- Σ_Δ = predictive covariance of Δ (Eqs. 17-23) ---- + # Off-diagonal entries: σ²_{ab} = β_a^T Q_{ab} β_b - μ^a_Δ μ^b_Δ (Eqs. 18, 20) + beta_a = beta.view(1, D, 1, n) # [1, D, 1, n] + beta_b = beta.view(1, 1, D, n) # [1, 1, D, n] + + Q_ab_beta_b = torch.matmul(Q_ab, beta_b.unsqueeze(-1)).squeeze( + -1 + ) # [B, D, D, n] + Sigma_Delta = ( + torch.matmul(beta_a.unsqueeze(-2), Q_ab_beta_b.unsqueeze(-1)) + .squeeze(-1) + .squeeze(-1) + ) # [B, D, D] – β_a^T Q_{ab} β_b (Eq. 20) + + # Diagonal correction E_{x̃}[var_f[Δ_a | x̃]] = α²_a - tr(K_a^{-1} Q_{aa}) + # added to σ²_{aa} (Eqs. 17, 23) + invK_Q = torch.matmul( + inv_K_noisy.unsqueeze(0).unsqueeze(2), # [1, D, 1, n, n] + Q_ab, # [B, D, D, n, n] + ) # [B, D, D, n, n] + trace_invK_Q = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1) # [B, D, D] + + diag_idx = torch.arange(D, device=device) + alpha_sq_b = alpha_sq.squeeze(-1).unsqueeze(0).expand(B, -1) # [B, D] + sigma_sq_eps_b = sigma_sq_eps.unsqueeze(0).expand(B, -1) # [B, D] + + # Add α²_a - tr(K_a^{-1} Q_{aa}) + σ²_{ε_a} to the diagonal (Eqs. 17, 23) + Sigma_Delta[:, diag_idx, diag_idx] += ( + alpha_sq_b - trace_invK_Q[:, diag_idx, diag_idx] + sigma_sq_eps_b + ) - s_x = (s_x + s_x.transpose(-2, -1)) / 2.0 - s_x = s_x + 1e-8 * torch.eye(self.obs_dim, device=s_x.device).expand( - s_x.shape[0], -1, -1 + # Subtract outer product of means: Σ_Δ -= μ_Δ μ_Δ^T (Eqs. 
17-18) + Sigma_Delta = Sigma_Delta - torch.bmm( + mu_Delta.unsqueeze(-1), mu_Delta.unsqueeze(-2) ) + Sigma_Delta = ( + Sigma_Delta + Sigma_Delta.transpose(-2, -1) + ) / 2 # enforce symmetry + + # ---- Propagate to next-state belief (Eqs. 10-12) ---- + # cov[x_{t-1}, Δ_t] = cov[x_{t-1}, x̃_{t-1}] · cov_xtilde_Delta (Eq. 12) + # cov[x_{t-1}, x̃_{t-1}] is the top-D rows of Σ̃_{t-1}: shape [B, D, D+F]. + # Using only Sigma_x_C_xu ([B, D, F]) here would be wrong — it drops + # the Σ_x block and produces a [B, D, F] @ [B, D+F, D] shape mismatch. + Sigma_x_rows = Sigma_tilde[:, :D, :] # [B, D, D+F] + cov_x_Delta = Sigma_x_rows @ cov_xtilde_Delta # [B, D, D] + + # μ_t = μ_{t-1} + μ_Δ (Eq. 10) + mu_t = mu_x + mu_Delta + + # Σ_t = Σ_{t-1} + Σ_Δ + cov[x_{t-1},Δ_t] + cov[Δ_t,x_{t-1}] (Eq. 11) + Sigma_t = Sigma_x + Sigma_Delta + cov_x_Delta + cov_x_Delta.transpose(-2, -1) + Sigma_t = (Sigma_t + Sigma_t.transpose(-2, -1)) / 2 # enforce symmetry + Sigma_t = Sigma_t + 1e-8 * torch.eye(D, device=device).expand( + B, -1, -1 + ) # jitter + + if unbatched: + mu_t = mu_t.squeeze(0) + Sigma_t = Sigma_t.squeeze(0) out_mean_key, out_var_key = self.out_keys - tensordict.set(out_mean_key, m_x) - tensordict.set(out_var_key, s_x) + tensordict.set(out_mean_key, mu_t) + tensordict.set(out_var_key, Sigma_t) return tensordict def deterministic_forward(self, tensordict: TensorDictBase) -> TensorDictBase: - """Calculates the forward pass when the input observation is deterministic (no variance). + """Deterministic forward pass using GP posterior mean and variance (Eqs. 7-8). + + Used when the input observation is a point estimate with no uncertainty. + Returns the GP posterior mean ``m_f(x̃_*)`` (Eq. 7) and per-dimension + variance ``σ²_f(x̃_*)`` (Eq. 8) for each state dimension. Args: - tensordict (TensorDictBase): A tensordict containing the action and observation tensors. + tensordict (TensorDictBase): Input tensordict with keys defined by + ``in_keys``. 
Supports arbitrary leading batch dimensions + ``(*batch, D)`` / ``(*batch, F)``, as well as unbatched + ``(D,)`` / ``(F,)`` inputs. Returns: - TensorDictBase: Next observation mean and variance matrices. + TensorDictBase: The same tensordict updated with next-state mean + ``μ_t`` and diagonal covariance ``Σ_t = diag(σ²_Δ)`` at + ``out_keys``. """ u_mean_key, u_var_key, u_cc_key, x_mean_key, x_var_key = self.in_keys - observation_mean = tensordict.get(x_mean_key) - action_mean = tensordict.get(u_mean_key) + mu_x = tensordict.get(x_mean_key) # x_{t-1}, shape (*batch, D) or (D,) + mu_u = tensordict.get(u_mean_key) # u_{t-1}, shape (*batch, F) or (F,) - x_flat = observation_mean.view(-1, self.obs_dim) - u_flat = action_mean.view(-1, self.action_dim) + batch_shape = mu_x.shape[:-1] # leading dims; () for unbatched inputs - X_test = torch.cat([x_flat, u_flat], dim=-1) + # Flatten all leading batch dimensions to a single axis for the GP + # posterior call, then restore the original shape afterwards. + x_flat = mu_x.reshape(-1, self.obs_dim) # [B_flat, D] + u_flat = mu_u.reshape(-1, self.action_dim) # [B_flat, F] - means, stds = [], [] + # x̃_* = [x_{t-1}, u_{t-1}] ∈ R^{B_flat×(D+F)} (Sec. 2.1) + X_tilde_test = torch.cat([x_flat, u_flat], dim=-1) + + # GP posterior mean m_f(x̃_*) (Eq. 7) and std σ_f(x̃_*) (Eq. 8) + mu_Delta_list, sigma_Delta_list = [], [] with torch.no_grad(): - for gp in self.model_list.models: - posterior = gp.posterior(X_test) - means.append(posterior.mean.squeeze(-1)) - stds.append(torch.sqrt(posterior.variance).squeeze(-1)) + for gp_a in self.model_list.models: + posterior_a = gp_a.posterior(X_tilde_test) + mu_Delta_list.append(posterior_a.mean.squeeze(-1)) # m_f (Eq. 7) + sigma_Delta_list.append( + torch.sqrt(posterior_a.variance).squeeze(-1) # σ_f (Eq. 
8) + ) - delta_mean_flat = torch.stack(means, dim=-1) - delta_std_flat = torch.stack(stds, dim=-1) + # μ_Δ – predicted residual mean; restore original batch shape + mu_Delta = torch.stack(mu_Delta_list, dim=-1).view(*batch_shape, self.obs_dim) + + # σ_Δ – predicted residual std; restore original batch shape + sigma_Delta = torch.stack(sigma_Delta_list, dim=-1).view( + *batch_shape, self.obs_dim + ) - batch_shape = observation_mean.shape[:-1] - delta_mean = delta_mean_flat.view(*batch_shape, self.obs_dim) - delta_std = delta_std_flat.view(*batch_shape, self.obs_dim) + # μ_t = x_{t-1} + μ_Δ (deterministic version of Eq. 10) + mu_t = mu_x + mu_Delta - m_x = observation_mean + delta_mean - s_x = torch.diag_embed(delta_std**2) + # Σ_t = diag(σ²_Δ) – diagonal covariance from independent GP variances (Eq. 8) + Sigma_t = torch.diag_embed(sigma_Delta**2) out_mean_key, out_var_key = self.out_keys - tensordict.set(out_mean_key, m_x) - tensordict.set(out_var_key, s_x) + tensordict.set(out_mean_key, mu_t) + tensordict.set(out_var_key, Sigma_t) return tensordict