1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -119,6 +119,7 @@ Intermediate
tutorials/collector_trajectory_assembly
tutorials/evaluator
tutorials/rb_tutorial
tutorials/memory_efficient_rl
tutorials/export

Advanced
10 changes: 6 additions & 4 deletions torchrl/collectors/_multi_base.py
@@ -324,10 +324,12 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
compact_obs (bool, optional): if ``True``, each worker drops the
observation and state keys from the ``("next", ...)`` sub-tensordict
before stacking. See
:class:`~torchrl.collectors.SyncDataCollector` for details and the
pairing with
:class:`~torchrl.envs.transforms.rb_transforms.NextStateReconstructor`
at sampling time.
:class:`~torchrl.collectors.SyncDataCollector` for the full
explanation and trade-offs (most notably,
:class:`~torchrl.envs.transforms.MultiStepTransform` cannot be
used in compact mode) and for the pairing with
:class:`~torchrl.envs.transforms.NextStateReconstructor` at
sampling time; see also the *Memory-efficient RL training* tutorial.
Defaults to ``False``.
worker_idx (int, optional): the index of the worker.

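The memory claim above can be sketched in plain Python. This is a conceptual sketch only: the dict-based step and the key names are illustrative stand-ins, not the TensorDict API. Compact storage keeps the root observation plus the ``("next", ...)`` keys that cannot be reconstructed (reward, done, truncated).

```python
# Conceptual sketch of compact storage: drop the "next" observation-like
# keys that duplicate the following step's root observation, keeping only
# the non-reconstructible "next" keys (reward / done / truncated).
def compact_step(step):
    kept_next = {
        k: v for k, v in step["next"].items()
        if k in ("reward", "done", "truncated")
    }
    return {**{k: v for k, v in step.items() if k != "next"}, "next": kept_next}

step = {
    "observation": [0.1, 0.2],
    "action": 1,
    "next": {
        "observation": [0.3, 0.4],  # duplicated copy: dropped by compaction
        "reward": 1.0,
        "done": False,
        "truncated": False,
    },
}
compact = compact_step(step)
```

For large image observations, dropping the duplicated copy roughly halves the per-step observation footprint while leaving the scalar ``next`` keys untouched.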
26 changes: 18 additions & 8 deletions torchrl/collectors/_single.py
@@ -268,14 +268,24 @@ class Collector(BaseCollector):
compact_obs (bool, optional): if ``True``, the collector drops the
observation and state keys from the ``("next", ...)`` sub-tensordict
before stacking per-step data. Those keys are bit-for-bit identical
to the root keys of the next step (modulo the last frame of the
rollout), so storing both copies wastes memory. ``("next", "reward")``,
``("next", "done")`` and ``("next", "truncated")`` are preserved
because they cannot be reconstructed from the root keys. The dropped
keys can be re-hydrated at sampling time with
:class:`~torchrl.envs.transforms.rb_transforms.NextStateReconstructor`
when consuming a :class:`~torchrl.data.SliceSampler`-backed replay
buffer. Defaults to ``False``.
to the root keys of the next step (modulo the last frame of each
trajectory), so storing both copies roughly doubles the observation
footprint. ``("next", "reward")``, ``("next", "done")``
and ``("next", "truncated")`` are preserved because they cannot be
reconstructed from the root keys. The dropped keys can be
re-hydrated at sampling time with
:class:`~torchrl.envs.transforms.NextStateReconstructor`; trajectory
ends then carry ``NaN`` for the missing ``("next", obs)``, and the
value-estimator forward pass substitutes a finite placeholder so that
GAE / TD targets stay numerically defined (see
:meth:`~torchrl.objectives.value.ValueEstimatorBase._sanitize_next_obs_nan`).
The canonical ``("next", obs)`` is still required by some downstream
components, most notably
:class:`~torchrl.envs.transforms.MultiStepTransform`, which reads the
n-step ``("next", obs)`` (falling back to in-trajectory neighbours at
the last ``n - 1`` frames) and cannot reconstruct it from the root
obs alone; see the *Memory-efficient RL training* tutorial for an
end-to-end pipeline. Defaults to ``False``.

Examples:
>>> from torchrl.envs.libs.gym import GymEnv
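The duplication claim above ("bit-for-bit identical modulo the last frame") can be checked in a minimal sketch, again using illustrative plain-Python steps rather than the TensorDict API: within a trajectory, the next observation at step ``t`` is the root observation at step ``t + 1``.

```python
# Conceptual check: ("next", "observation") at step t duplicates the root
# "observation" at step t + 1, so only the final frame's next-obs carries
# information that the root keys cannot recover.
rollout = [
    {"observation": obs, "next": {"observation": nxt}}
    for obs, nxt in [(0, 1), (1, 2), (2, 3)]
]

duplicated = all(
    rollout[t]["next"]["observation"] == rollout[t + 1]["observation"]
    for t in range(len(rollout) - 1)
)
# Only rollout[-1]["next"]["observation"] has no root twin.
```

This identity is exactly what makes the compact path lossless everywhere except at trajectory ends.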
15 changes: 15 additions & 0 deletions torchrl/envs/transforms/rb_transforms.py
@@ -331,6 +331,21 @@ class NextStateReconstructor(Transform):
... }, batch_size=[8])
>>> rb.extend(data)
>>> sample = rb.sample() # ('next', 'observation') is reconstructed

.. seealso::

:class:`~torchrl.collectors.SyncDataCollector`'s ``compact_obs`` flag
is the producer side of this transform: it drops the duplicated
``("next", obs)`` before stacking. Trajectory ends carry ``NaN`` after
rehydration; the value-estimator pipeline keeps GAE / TD targets
numerically defined via
:meth:`~torchrl.objectives.value.ValueEstimatorBase._sanitize_next_obs_nan`.
:class:`~torchrl.envs.transforms.MultiStepTransform` is **not**
compatible with the compact path: it needs the canonical ``("next", obs)``
to read the n-step neighbour (and to keep working at the last
``n - 1`` frames of every trajectory, where the n-step lookup falls
back to the in-trajectory neighbours). See the *Memory-efficient RL
training* tutorial for an end-to-end pipeline.
"""

def __init__(
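The reconstruction step itself can be sketched without TorchRL. The function below is a hypothetical stand-in for what the transform does conceptually, assuming dict-based steps: re-hydrate ``("next", "observation")`` by shifting the root observations forward by one, and mark trajectory ends, where no successor frame was stored, with ``NaN``.

```python
import math

def reconstruct_next_obs(traj):
    """Rebuild ("next", "observation") from the shifted root observations.

    The last frame of the trajectory has no stored successor, so its
    next-observation is filled with NaN (to be sanitized downstream).
    """
    out = []
    for t, step in enumerate(traj):
        nxt = dict(step.get("next", {}))
        if t + 1 < len(traj):
            nxt["observation"] = traj[t + 1]["observation"]
        else:
            nxt["observation"] = math.nan  # trajectory end: no successor stored
        out.append({**step, "next": nxt})
    return out

traj = [{"observation": float(i), "next": {"reward": 1.0}} for i in range(3)]
rebuilt = reconstruct_next_obs(traj)
```

This is why the transform pairs with a trajectory-aware sampler such as :class:`~torchrl.data.SliceSampler`: the shift is only valid inside a single trajectory.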
8 changes: 8 additions & 0 deletions torchrl/objectives/value/advantages.py
@@ -454,6 +454,14 @@ def _sanitize_next_obs_nan(

Operates on a shallow copy so the caller's ``tensordict`` is not
mutated.

.. seealso::

:class:`~torchrl.collectors.SyncDataCollector` (``compact_obs``)
and :class:`~torchrl.envs.transforms.NextStateReconstructor` are
the typical producers of ``NaN`` next-observations at trajectory
ends. The *Memory-efficient RL training* tutorial wires the three
together end-to-end.
"""
copied = False
for k in in_keys:
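Why the NaN substitution is safe can be shown with a scalar sketch. Everything here is illustrative (scalar observations, a toy critic), not the real tensor-based implementation: the ``done`` flag masks the bootstrapped value term, so the finite placeholder fed through the critic never reaches the target at trajectory ends.

```python
import math

def sanitize_next_obs(obs, placeholder=0.0):
    """Replace a NaN next-observation with a finite placeholder.

    Scalar sketch of the _sanitize_next_obs_nan idea; the real method
    operates on tensordict entries.
    """
    return placeholder if math.isnan(obs) else obs

def value_fn(obs):
    # Toy stand-in for the critic network.
    return 2.0 * obs

def td_target(reward, done, next_obs, gamma=0.99):
    # `done` zeroes the bootstrap term, so at trajectory ends the target
    # is just the reward and the placeholder value never leaks in.
    v_next = value_fn(sanitize_next_obs(next_obs))
    return reward + gamma * (0.0 if done else v_next)
```

At a trajectory end (``done=True``, ``next_obs=NaN``) the target is exactly the reward; mid-trajectory the bootstrapped value passes through unchanged, so GAE / TD computations stay numerically defined on compact data.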