🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3537
Note: Links to docs will display an error until the docs builds have been completed.
| Prefix | Label Applied | Example |
|---|---|---|
| [BugFix] | BugFix | [BugFix] Fix memory leak in collector |
| [Feature] | Feature | [Feature] Add new optimizer |
| [Doc] or [Docs] | Documentation | [Doc] Update installation guide |
| [Refactor] | Refactoring | [Refactor] Clean up module imports |
| [CI] | CI | [CI] Fix workflow permissions |
| [Test] or [Tests] | Tests | [Tests] Add unit tests for buffer |
| [Environment] or [Environments] | Environments | [Environments] Add Gymnasium support |
| [Data] | Data | [Data] Fix replay buffer sampling |
| [Performance] or [Perf] | Performance | [Performance] Optimize tensor ops |
| [BC-Breaking] | bc breaking | [BC-Breaking] Remove deprecated API |
| [Deprecation] | Deprecation | [Deprecation] Mark old function |
| [Quality] | Quality | [Quality] Fix typos and add codespell |
Note: Common variations like singular/plural are supported (e.g., [Doc] or [Docs]).
vmoens left a comment:
Excellent first attempt!
Let's try to move most of this to core!
```
torchrl/modules/models/
    gp.py              # BoTorchGPWorldModel (renamed: GPWorldModel?)
    rbf_controller.py  # RBFController
torchrl/objectives/
    pilco.py           # SaturatingCost (the generic cost module)
sota-implementations/pilco/
    pilco.py           # Training loop (stays here)
    utils.py           # make_env, pendulum_cost (thin wrappers, stay here)
    config.yaml        # Config (stays here)
```
Missing tests:
- Unit tests for RBFController moment matching (forward pass, squash_sin)
- Unit tests for BoTorchGPWorldModel (fit, deterministic_forward, uncertain_forward)
- At minimum, a smoke test for the full PILCO loop (see the sota-implementations CI workflow)
- Numerical validation against the reference implementation (the author credits nrontsis/PILCO) if possible; ok if not
There are no docs: no docstrings on any class or method beyond the one-line pendulum_cost docstring. For core components, all public methods need proper docstrings with shapes documented (especially the moment-matching formulas, which are dense linear algebra). Docs must be linked in docs/source/reference/...
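For illustration, a hypothetical helper documented in the style the review asks for, with every tensor shape spelled out (the name and formula below are placeholders, not the PR's actual SaturatingCost):

```python
import torch


def saturating_cost(state: torch.Tensor, target: torch.Tensor, width: float = 0.25) -> torch.Tensor:
    """Saturating cost ``1 - exp(-||s - t||^2 / (2 * width**2))``.

    Args:
        state: batch of states. Shape: ``[batch, obs_dim]``.
        target: goal state. Shape: ``[obs_dim]``.
        width: length-scale of the cost bowl; smaller means sharper.

    Returns:
        Per-sample cost in ``[0, 1)``. Shape: ``[batch]``.
    """
    sq_dist = ((state - target) ** 2).sum(dim=-1)
    return 1.0 - torch.exp(-sq_dist / (2 * width**2))
```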
Avoid single-letter variables unless they're indices (for i in range(...)), which are heavily used throughout the moment-matching code (m, s, c, B, D, L, U, Q, t, z). These follow the paper notation, which is fine, but in core they should have comments referencing which equation in the paper each block corresponds to.
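For instance, each block could carry a descriptive name plus a pointer to the derivation it implements (the function name and the equation reference below are placeholders):

```python
import torch


def predictive_mean(beta: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    # Moment-matched GP predictive mean, mu* = beta^T q
    # (reference the exact equation number of the PILCO paper here).
    # beta: [num_train], q: [batch, num_train] -> [batch]
    return q @ beta
```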
The policy_for_env closure (pilco.py lines 166-200) is an ad-hoc bridge between the Gaussian policy interface and a standard env that expects deterministic actions. In core, this should be a proper transform or wrapper (e.g., MeanActionSelector or similar) rather than a closure rebuilt every epoch.
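In torchrl this would likely live as an `envs.transforms.Transform`; a framework-free sketch of the idea, with all names hypothetical:

```python
import torch
from torch import nn


class MeanActionSelector(nn.Module):
    """Hypothetical wrapper: expose a Gaussian (mean, var) policy to an
    env that expects deterministic actions by forwarding only the mean."""

    def __init__(self, gaussian_policy: nn.Module):
        super().__init__()
        self.gaussian_policy = gaussian_policy

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        mean, _var = self.gaussian_policy(observation)
        return mean  # the env consumes only the deterministic mean action


class ToyGaussianPolicy(nn.Module):
    # stand-in for the PR's RBF policy: returns (mean, var)
    def forward(self, obs: torch.Tensor):
        return 2.0 * obs, torch.ones_like(obs)


policy = MeanActionSelector(ToyGaussianPolicy())
```

Built once, it can be handed to any collector or rollout loop instead of being rebuilt every epoch.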
sota-implementations/pilco/utils.py (outdated)

```python
    return (1.0 - det_term * torch.exp(exp_term)).sum(dim=1)
...
class BoTorchGPWorldModel(nn.Module):
```
If properly documented I'm happy with having this in core!
sota-implementations/pilco/utils.py (outdated)

```python
    return observation_mean + delta_mean, torch.diag_embed(delta_std**2)
...
class ImaginedEnv(ModelBasedEnvBase):
```
Ditto, maybe we want this in core.
How different is it from DreamerEnv? Can we blend the two together? (ok if we want to keep them separated)
sota-implementations/pilco/utils.py (outdated)

```python
    return out
...
class RBFController(nn.Module):
```
Ditto, happy to have it in core.
sota-implementations/pilco/utils.py (outdated)

```python
for a in range(self.obs_dim):
    for b in range(self.obs_dim):
```
A lot in here can be vectorized.
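As a sketch of the kind of rewrite meant here (the toy reduction below is not the PR's actual moment-matching body, but the loop-to-einsum pattern is the same):

```python
import torch

torch.manual_seed(0)
obs_dim, n = 3, 5
K = torch.randn(obs_dim, n, n)

# Loop version: one (n, n) elementwise product per (a, b) pair.
loop_out = torch.empty(obs_dim, obs_dim)
for a in range(obs_dim):
    for b in range(obs_dim):
        loop_out[a, b] = (K[a] * K[b]).sum()

# Vectorized: broadcast both output dimensions through a single einsum.
vec_out = torch.einsum("aij,bij->ab", K, K)
```

Both produce the same `[obs_dim, obs_dim]` matrix, but the einsum dispatches one batched kernel instead of `obs_dim**2` small ones.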
sota-implementations/pilco/utils.py (outdated)

```python
else:
    return self.deterministic_forward(action, observation)
...
def freeze_and_detach(self) -> None:
```
sota-implementations/pilco/utils.py (outdated)

```python
invK_Q = torch.matmul(inv_K[a].unsqueeze(0), Q_ab)
trace_val = torch.diagonal(invK_Q, dim1=-2, dim2=-1).sum(-1)
...
pred_cov[:, a, a] += variances[a] - trace_val + noises[a].item()
```
Avoid .item(). Call .tolist() first if absolutely necessary; I think a plain tensor should work here. .item() breaks compile and requires a CUDA sync.
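A minimal sketch of the suggested change, using placeholder values rather than the PR's real intermediates:

```python
import torch

pred_cov = torch.zeros(2, 3, 3)
variances = torch.tensor([0.1, 0.2, 0.3])
noises = torch.tensor([0.01, 0.02, 0.03])
trace_val = torch.tensor(0.05)
a = 0

# Before (breaks torch.compile, forces a device sync on CUDA):
# pred_cov[:, a, a] += variances[a] - trace_val + noises[a].item()

# After: keep everything as tensors; indexing a 1-D tensor already
# yields a 0-dim tensor that broadcasts into the batched diagonal.
pred_cov[:, a, a] += variances[a] - trace_val + noises[a]
```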
sota-implementations/pilco/utils.py (outdated)

```python
t = torch.linalg.solve(B_mat, iN.mT).mT
...
exp_term = torch.exp(-0.5 * torch.sum(iN * t, dim=-1))
detB = torch.linalg.det(B_mat)
```
torch.linalg.slogdet would be numerically more stable (and the Cholesky log-det pattern is already partially used elsewhere in the same code).
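A sketch of the suggested rewrite, assuming det(B) enters as the det^(-1/2) factor of a Gaussian normalizer as in the surrounding moment-matching code (the matrices here are stand-ins):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4)
B_mat = A @ A.mT + 4.0 * torch.eye(4)  # SPD stand-in for the PILCO B matrix
iN = torch.randn(2, 4)

t = torch.linalg.solve(B_mat, iN.mT).mT

# det-based version: det(B) can under/overflow long before the result does.
naive = torch.exp(-0.5 * torch.sum(iN * t, dim=-1)) / torch.sqrt(torch.linalg.det(B_mat))

# slogdet-based version: stay in log space until the single final exp.
sign, logdet = torch.linalg.slogdet(B_mat)
stable = torch.exp(-0.5 * torch.sum(iN * t, dim=-1) - 0.5 * logdet)
```

For an SPD matrix the sign is always positive, so only `logdet` is needed.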
sota-implementations/pilco/utils.py (outdated)

```python
scaled_exp = torch.exp(-torch.sum(inv_N * t, dim=-1) / 2)
lb = scaled_exp * beta.unsqueeze(0)
...
det_B = torch.linalg.det(B_mat)
```
Ditto: let's think in logs!
sota-implementations/pilco/utils.py (outdated)

```python
    batch_size, num_train_pts, num_train_pts
)
...
det_R_ab = torch.linalg.det(R_ab)
```
sota-implementations/pilco/utils.py (outdated)

```python
from botorch.fit import fit_gpytorch_mll
...
from botorch.models import ModelListGP, SingleTaskGP
```
botorch / gpytorch need to be added to the repo's dependencies as optional deps
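This later lands in the PR as an optional `pilco` dependency group; a sketch of what the pyproject.toml entry could look like (version pins not specified here):

```toml
[project.optional-dependencies]
pilco = [
    "botorch",
    "gpytorch",
]
```

Installed via `pip install torchrl[pilco]`; the botorch imports should then be guarded so the rest of the library works without it.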
Hi @PSXBRosa, thank you for your work implementing PILCO for TorchRL. A few days ago I opened a discussion about adding MC-PILCO (discussion #3538), and reading through @vmoens' review it's clear we're both going to depend on the same core primitives (…). A couple of options as I see it:
Happy to go with whatever works best for you. If you want to discuss further, feel free to reply here or reach out on Discord (cabesamotora)!
Hi @alektebel, thanks for reaching out. #1 seems like the cleanest approach. I'm in the middle of a move right now, so I haven't had much time for the PR this week, but I've started on vmoens' comments and plan to have the current issues resolved by the end of next week. My only progress so far is moving the loss to core; I can push this as a WIP if you'd like to take a look. How do you envision building on top of the current classes? Do you have a specific inheritance plan in mind? I'll ping you on Discord.
@PSXBRosa I can help port stuff to core if you want me to :)
@vmoens, I think that would help! I ended up having less time than I expected to work on the PR this last week. |
Port RBFController, ImaginedEnv, and MeanActionSelector from sota-implementations/pilco/utils.py to torchrl core. Add unit tests, documentation entries, and a botorch CI job in test-linux-libs.

- torchrl/modules/models/rbf_controller.py: RBF controller for moment-matching policy search with full docstrings
- torchrl/envs/model_based/imagined.py: general-purpose imagination env for model-based policy search
- torchrl/envs/transforms/mean_action_selector.py: transform bridging Gaussian belief-space policies with standard environments
- Improve GPWorldModel: slogdet for numerical stability, remove .item()
- Register GPWorldModel and RBFController in module exports
- Add all new components to docs
- Add 39 unit tests in test/test_objectives.py
- Add botorch CI job to test-linux-libs workflow
- Update sota-implementations/pilco to import from core

Made-with: Cursor
```
# Conflicts:
#   sota-implementations/pilco/pilco.py
#   sota-implementations/pilco/utils.py
#   torchrl/modules/models/__init__.py
#   torchrl/modules/models/rbf_controller.py
```
torchrl/modules/models/gp.py (outdated)

```python
for a in range(self.obs_dim):
    for b in range(self.obs_dim):
```
Let's optimize this when we can!
Sounds good! I'll work on this after work today.
…mpatibility
- Added optional `pilco` dependency group (botorch, gpytorch) to `pyproject.toml`.
- Refactored `GPWorldModel` to follow the TensorDict module interface:
- `forward` now accepts and returns a `TensorDict` instead of tuples.
- Added configurable `in_keys` and `out_keys` for flexible integration.
- Default keys support probabilistic inputs (`action.mean`, `action.var`,
`action.cross_covariance`, `observation.mean`, `observation.var`).
- Deterministic and uncertain forward passes now write results directly into
the TensorDict.
- Removed `freeze_and_detach` utility as it is no longer required.
- Updated internal logic to read/write through TensorDict keys.
- Updated `ImaginedEnv`:
- Added `next_observation_key` argument to specify where the world model writes
predicted observations.
- Default key is `("next", "observation")`.
- Adjusted environment step to read observations using this configurable key.
- Updated documentation examples to reflect new TensorDict conventions.
- Simplified PILCO implementation:
- World model is now used directly instead of wrapping it with `TensorDictModule`.
- Replaced `freeze_and_detach()` with `eval()` mode.
- Adapted rollout handling to align keys between real and imagined trajectories
(selecting mean values for observation/action and matching rollout structure).
- Ensured compatibility when concatenating evaluation rollouts into the dataset.
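The in_keys/out_keys pattern described above can be sketched framework-free, with a plain dict standing in for the TensorDict and toy dynamics in place of the GP (all names here are illustrative, not the PR's API):

```python
import torch
from torch import nn


class DictWorldModel(nn.Module):
    """Toy stand-in for the TensorDict interface: read inputs from
    configurable keys, write outputs back into the same mapping."""

    def __init__(self, in_keys=("observation.mean", "action.mean"),
                 out_keys=("next.observation.mean",)):
        super().__init__()
        self.in_keys, self.out_keys = in_keys, out_keys

    def forward(self, td: dict) -> dict:
        obs, act = (td[k] for k in self.in_keys)
        td[self.out_keys[0]] = obs + act  # toy dynamics: next = obs + act
        return td


td = {"observation.mean": torch.ones(2), "action.mean": torch.ones(2)}
td = DictWorldModel()(td)
```

Because keys are configurable, the same module can target `("next", "observation")` or any other destination without changing its body.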
```diff
-    def forward(
-        self, action: TensorDictBase, observation: TensorDictBase
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
```
Based on a previous comment in the PR, I thought changing this method signature would be better.
```python
test_rollout = test_rollout.select(
    *rollout.keys(include_nested=True, leaves_only=True)
)
```
Due to the changes to the GP class, this was necessary. I'm not sure if there's a cleaner way to handle this.
Description
This PR introduces the implementation of the PILCO (Probabilistic Inference for Learning Control) algorithm to TorchRL.
Key details of the implementation:
Motivation and Context
PILCO is a highly sample-efficient model-based reinforcement learning algorithm, making it a valuable addition to the library's algorithm suite.
close #3513
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!