Skip to content

[Feature] PILCO#3537

Open
PSXBRosa wants to merge 9 commits intopytorch:mainfrom
PSXBRosa:main
Open

[Feature] PILCO#3537
PSXBRosa wants to merge 9 commits intopytorch:mainfrom
PSXBRosa:main

Conversation

@PSXBRosa
Copy link

Description

This PR introduces the implementation of the PILCO (Probabilistic Inference for Learning Control) algorithm to TorchRL.

Key details of the implementation:

  • Gaussian Process Regression: I utilized the external libraries botorch and gpytorch for the GPR components. this avoids the overhead and complexity of maintaining a custom GPR implementation.
  • Moment Matching: I initially considered and experimented with a Monte Carlo approach for moment matching to simplify the underlying mathematics. While I couldn't get the MC approach to stabilize and work correctly (though it remains an interesting alternative), the current working implementation relies on the analytical forms for moment matching. This aligns directly with what was done in the original PILCO paper by Deisenroth and Rasmussen.
  • Credits: I want to give a huge thanks and credit to @nrontsis. I used the code from their repository (https://github.com/nrontsis/PILCO) as a valuable implementation reference and to test/validate different parts of my own PyTorch adaptation.

Motivation and Context

PILCO is a highly sample-efficient model-based reinforcement learning algorithm, making it a valuable addition to the library's algorithm suite.

close #3513

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3537

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 17 Awaiting Approval

As of commit 671b265 with merge base 4d2c3cb (image):

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 27, 2026
@github-actions
Copy link
Contributor

⚠️ PR Title Label Error

Unknown or invalid prefix [Algorithm].

Current title: [Algorithm] PILCO

Supported Prefixes (case-sensitive)

Your PR title must start with exactly one of these prefixes:

Prefix Label Applied Example
[BugFix] BugFix [BugFix] Fix memory leak in collector
[Feature] Feature [Feature] Add new optimizer
[Doc] or [Docs] Documentation [Doc] Update installation guide
[Refactor] Refactoring [Refactor] Clean up module imports
[CI] CI [CI] Fix workflow permissions
[Test] or [Tests] Tests [Tests] Add unit tests for buffer
[Environment] or [Environments] Environments [Environments] Add Gymnasium support
[Data] Data [Data] Fix replay buffer sampling
[Performance] or [Perf] Performance [Performance] Optimize tensor ops
[BC-Breaking] bc breaking [BC-Breaking] Remove deprecated API
[Deprecation] Deprecation [Deprecation] Mark old function
[Quality] Quality [Quality] Fix typos and add codespell

Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).

@PSXBRosa PSXBRosa changed the title [Algorithm] PILCO [Feature] PILCO Feb 27, 2026
@github-actions github-actions bot added the Feature New feature label Feb 27, 2026
Copy link
Collaborator

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent first attempt!

Let's try to move most of this to core!

torchrl/modules/models/
    gp.py                    # BoTorchGPWorldModel (renamed: GPWorldModel?)
    rbf_controller.py        # RBFController

torchrl/objectives/
    pilco.py                 # SaturatingCost (the generic cost module)

sota-implementations/pilco/
    pilco.py                 # Training loop (stays here)
    utils.py                 # make_env, pendulum_cost (thin wrappers, stays here)
    config.yaml              # Config (stays here)

Missing tests:

  • Unit tests for RBFController moment matching (forward pass, squash_sin)
  • Unit tests for BoTorchGPWorldModel (fit, deterministic_forward, uncertain_forward)
  • At minimum a smoke test for the full PILCO loop (see workflow in sota-implementations CI workflow)
  • Numerical validation against the reference implementation (the author credits nrontsis/PILCO) if possible - ok if not

There are no docs. No docstrings on any class or method beyond the one-line pendulum_cost docstring. For core components, all public methods need proper docstrings with shapes documented (especially the moment matching formulas which are dense linear algebra). Docs must be linked in docs/source/reference/...

Avoid single letter variables unless they're indices (for in in range(...)) which are heavily used throughout the moment matching code (m, s, c, B, D, L, U, Q, t, z). These follow the paper notation, which is fine, but in core they should have comments referencing which equation in the paper each block corresponds to.

policy_for_env closure (pilco.py lines 166-200) -- this is an ad-hoc bridge between the Gaussian policy interface and a standard env that expects deterministic actions. In core, this should be a proper transform or wrapper (e.g., MeanActionSelector or similar) rather than a closure rebuilt every epoch.

return (1.0 - det_term * torch.exp(exp_term)).sum(dim=1)


class BoTorchGPWorldModel(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If properly documented i'm happy with having this in core!

return observation_mean + delta_mean, torch.diag_embed(delta_std**2)


class ImaginedEnv(ModelBasedEnvBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto, maybe we want this in core.
How different is it from DreamerEnv? Can we blend the two together? (ok if we want to keep them separated)

return out


class RBFController(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto happy to have it in core

Comment on lines +270 to +271
for a in range(self.obs_dim):
for b in range(self.obs_dim):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot in here can be vectorized

else:
return self.deterministic_forward(action, observation)

def freeze_and_detach(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this needed?

invK_Q = torch.matmul(inv_K[a].unsqueeze(0), Q_ab)
trace_val = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1)

pred_cov[:, a, a] += variances[a] - trace_val + noises[a].item()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid .item()
Call .tolist() if absolutely necessary before. I think a plain tensor should work. .item() breaks compile and requires cuda sync.

t = torch.linalg.solve(B_mat, iN.mT).mT

exp_term = torch.exp(-0.5 * torch.sum(iN * t, dim=-1))
detB = torch.linalg.det(B_mat)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.linalg.slogdet would be numerically more stable (and is already partially used via the Cholesky log-det pattern elsewhere in the same code)

scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2)
lb = scaled_exp * beta.unsqueeze(0)

det_B = torch.linalg.det(B_mat)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto - let's think in logs!

batch_size, num_train_pts, num_train_pts
)

det_R_ab = torch.linalg.det(R_ab)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto let's think in logs!

Comment on lines +5 to +7
from botorch.fit import fit_gpytorch_mll

from botorch.models import ModelListGP, SingleTaskGP
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

botorch / gpytorch need to be added to the repo's dependencies as optional deps

@alektebel
Copy link

alektebel commented Mar 4, 2026

Hi @PSXBRosa ,

Thank you for your work implementing PILCO for torchrl.

A few days ago I opened a discussion about adding MC-PILCO, discussion n 3538, and reading through @vmoens ' review it's clear we're both going to depend on the same core primitives (BoTorchGPWorldModel, RBFController). Since @vmoens is already asking for those to be moved into torchrl/modules/models/, it seems like the cleanest path forward is to coordinate so we don't end up duplicating changes or creating interface conflicts.

A couple of options as I see it:

  1. We collaborate on this PR directly, so the core primitives are designed with both PILCO and MC-PILCO in mind from the start (e.g. making sure BoTorchGPWorldModel exposes both uncertain_forward and a sample_forward path for the MC variant).
  2. I open a separate MC-PILCO PR after this one merges, building on top of your primitives once they're in main.

Happy to go with whatever works best for you. If you want to discuss further feel free to reply here or reach out on discord (cabesamotora) !

@PSXBRosa
Copy link
Author

PSXBRosa commented Mar 5, 2026

Hi @alektebel, thanks for reaching out.

#1 seems like the cleanest approach. I'm in the middle of a move right now so I haven't had much time for the PR this week, but I've started on vmoens' comments and plan to have the current issues resolved by the end of next week.

My only progress so far is moving the loss to core. I can push this as a WIP if you'd like to take a look? How do you envision building on top of the current classes? Do you have a specific inheritance plan in mind? I'll ping you on discord.

@vmoens
Copy link
Collaborator

vmoens commented Mar 12, 2026

@PSXBRosa I can help porting stuff to core if you want me to :)

@PSXBRosa
Copy link
Author

@vmoens, I think that would help! I ended up having less time than I expected to work on the PR this last week.

PSXBRosa and others added 3 commits March 14, 2026 14:26
Port RBFController, ImaginedEnv, and MeanActionSelector from
sota-implementations/pilco/utils.py to torchrl core. Add unit tests,
documentation entries, and a botorch CI job in test-linux-libs.

- torchrl/modules/models/rbf_controller.py: RBF controller for
  moment-matching policy search with full docstrings
- torchrl/envs/model_based/imagined.py: general-purpose imagination
  env for model-based policy search
- torchrl/envs/transforms/mean_action_selector.py: transform bridging
  Gaussian belief-space policies with standard environments
- Improve GPWorldModel: slogdet for numerical stability, remove .item()
- Register GPWorldModel and RBFController in module exports
- Add all new components to docs
- Add 39 unit tests in test/test_objectives.py
- Add botorch CI job to test-linux-libs workflow
- Update sota-implementations/pilco to import from core

Made-with: Cursor
Made-with: Cursor

# Conflicts:
#	sota-implementations/pilco/pilco.py
#	sota-implementations/pilco/utils.py
#	torchrl/modules/models/__init__.py
#	torchrl/modules/models/rbf_controller.py
@github-actions github-actions bot added Documentation Improvements or additions to documentation CI Has to do with CI setup (e.g. wheels & builds, tests...) Transforms labels Mar 14, 2026
Comment on lines +275 to +276
for a in range(self.obs_dim):
for b in range(self.obs_dim):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's optimize this when we can!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good! I'll work on this after work today

vmoens and others added 2 commits March 15, 2026 14:41
…mpatibility

- Added optional `pilco` dependency group (botorch, gpytorch) to `pyproject.toml`.

- Refactored `GPWorldModel` to follow the TensorDict module interface:
  - `forward` now accepts and returns a `TensorDict` instead of tuples.
  - Added configurable `in_keys` and `out_keys` for flexible integration.
  - Default keys support probabilistic inputs (`action.mean`, `action.var`,
    `action.cross_covariance`, `observation.mean`, `observation.var`).
  - Deterministic and uncertain forward passes now write results directly into
    the TensorDict.
  - Removed `freeze_and_detach` utility as it is no longer required.
  - Updated internal logic to read/write through TensorDict keys.

- Updated `ImaginedEnv`:
  - Added `next_observation_key` argument to specify where the world model writes
    predicted observations.
  - Default key is `("next", "observation")`.
  - Adjusted environment step to read observations using this configurable key.
  - Updated documentation examples to reflect new TensorDict conventions.

- Simplified PILCO implementation:
  - World model is now used directly instead of wrapping it with `TensorDictModule`.
  - Replaced `freeze_and_detach()` with `eval()` mode.
  - Adapted rollout handling to align keys between real and imagined trajectories
    (selecting mean values for observation/action and matching rollout structure).
  - Ensured compatibility when concatenating evaluation rollouts into the dataset.
def forward(
self, action: TensorDictBase, observation: TensorDictBase
) -> tuple[torch.Tensor, torch.Tensor]:
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on a previous comment in the PR, I thought changing this method signature would be better.


test_rollout = test_rollout.select(
*rollout.keys(include_nested=True, leaves_only=True)
)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Due to the changes to the GP class, this was necessary. I'm not sure if there's a cleaner way to handle this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI Has to do with CI setup (e.g. wheels & builds, tests...) CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. Documentation Improvements or additions to documentation Feature New feature Modules Objectives sota-implementations/ Transforms

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature Request] Implement PILCO (Probabilistic Inference for Learning Control)

3 participants