Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
PrioritizedSampler,
PrioritizedSliceSampler,
RandomSampler,
Sampler,
SamplerEnsemble,
SamplerWithoutReplacement,
SliceSampler,
Expand Down Expand Up @@ -4692,6 +4693,130 @@ def test_prb_ndim(self):
(s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all()
)

def test_replacement_kwarg_random(self):
    """Check how ``RandomSampler`` reacts to the ``replacement`` keyword.

    Omitting the keyword or passing ``True`` must yield a plain
    ``RandomSampler``; passing ``False`` must yield a
    ``SamplerWithoutReplacement`` that receives the remaining keyword
    arguments.
    """
    # Default construction and explicit replacement=True are equivalent.
    for ctor_kwargs in ({}, {"replacement": True}):
        with_replacement = RandomSampler(**ctor_kwargs)
        assert type(with_replacement) is RandomSampler

    # replacement=False hands the call over to SamplerWithoutReplacement,
    # which keeps its own defaults.
    dispatched = RandomSampler(replacement=False)
    assert type(dispatched) is SamplerWithoutReplacement
    assert dispatched.drop_last is False
    assert dispatched.shuffle is True

    # Keyword arguments other than ``replacement`` reach the target class.
    configured = RandomSampler(replacement=False, drop_last=True, shuffle=False)
    assert type(configured) is SamplerWithoutReplacement
    assert configured.drop_last is True
    assert configured.shuffle is False

    # The dispatched object still satisfies the usual isinstance checks.
    assert isinstance(configured, Sampler)
    assert isinstance(configured, SamplerWithoutReplacement)

def test_replacement_kwarg_slice(self):
    """Check how ``SliceSampler`` reacts to the ``replacement`` keyword.

    Omitting the keyword or passing ``True`` must yield a plain
    ``SliceSampler``; passing ``False`` must yield a
    ``SliceSamplerWithoutReplacement`` that receives the remaining keyword
    arguments.
    """
    # Default construction and explicit replacement=True are equivalent.
    for ctor_kwargs in ({}, {"replacement": True}):
        with_replacement = SliceSampler(**ctor_kwargs, slice_len=5)
        assert type(with_replacement) is SliceSampler

    # replacement=False hands the call over to SliceSamplerWithoutReplacement.
    dispatched = SliceSampler(replacement=False, slice_len=5)
    assert type(dispatched) is SliceSamplerWithoutReplacement
    assert dispatched.slice_len == 5
    assert dispatched.drop_last is False
    assert dispatched.shuffle is True

    # Without-replacement specific options are forwarded too.
    forwarded = {
        "replacement": False,
        "slice_len": 5,
        "drop_last": True,
        "shuffle": False,
        "traj_key": "episode",
        "strict_length": False,
    }
    configured = SliceSampler(**forwarded)
    assert type(configured) is SliceSamplerWithoutReplacement
    assert configured.slice_len == 5
    assert configured.drop_last is True
    assert configured.shuffle is False
    assert configured.traj_key == "episode"
    assert configured.strict_length is False

    # The result belongs to both the SliceSampler and the
    # SamplerWithoutReplacement hierarchies.
    assert isinstance(configured, SliceSampler)
    assert isinstance(configured, SamplerWithoutReplacement)

def test_replacement_kwarg_subclass_unaffected(self):
    """Check that dispatch does not leak onto classes it should not touch.

    A ``SliceSampler`` subclass built without the keyword keeps its own
    type, and the without-replacement classes accept ``replacement=False``
    without changing type.
    """
    # PrioritizedSliceSampler derives from SliceSampler but must not dispatch.
    prioritized = PrioritizedSliceSampler(
        slice_len=5, max_capacity=10, alpha=0.5, beta=0.5
    )
    assert type(prioritized) is PrioritizedSliceSampler

    # On the without-replacement classes the keyword is simply discarded.
    plain = SamplerWithoutReplacement(replacement=False, drop_last=True)
    assert type(plain) is SamplerWithoutReplacement
    assert plain.drop_last is True

    sliced = SliceSamplerWithoutReplacement(replacement=False, slice_len=5)
    assert type(sliced) is SliceSamplerWithoutReplacement
    assert sliced.slice_len == 5

def test_replacement_kwarg_no_variant_errors(self):
    """Check that ``replacement=False`` is rejected when no counterpart exists."""
    # PrioritizedSampler has no without-replacement counterpart, so asking
    # for one must fail with a TypeError carrying the expected message.
    ctor_kwargs = {
        "max_capacity": 10,
        "alpha": 0.5,
        "beta": 0.5,
        "replacement": False,
    }
    with pytest.raises(TypeError, match="no without-replacement variant"):
        PrioritizedSampler(**ctor_kwargs)

def test_replacement_kwarg_in_replay_buffer(self):
    """End-to-end check of ``RandomSampler(replacement=False)`` in a buffer.

    Four batches of size 3 are drawn from an 11-element buffer; together
    they must contain every stored value.
    """
    torch.manual_seed(0)
    payload = TensorDict({"a": torch.arange(11)}, batch_size=[11])
    storage = LazyTensorStorage(11)
    sampler = RandomSampler(replacement=False, drop_last=False)
    buffer = ReplayBuffer(storage=storage, sampler=sampler, batch_size=3)
    buffer.extend(payload)

    # Collect every value returned over four consecutive sample() calls.
    drawn = {
        value for _ in range(4) for value in buffer.sample()["a"].tolist()
    }
    assert drawn == set(range(11))

def test_replacement_kwarg_slice_in_replay_buffer(self):
    """End-to-end check of ``SliceSampler(replacement=False)`` in a buffer.

    The buffer holds three episodes of 20 steps each. One batch of 10 is
    drawn with ``slice_len=5`` and each of the two resulting slices must
    consist of consecutive ``obs`` values.
    """
    torch.manual_seed(0)
    # Episode ids: twenty 0s, twenty 1s, twenty 2s (int64).
    episode_ids = torch.arange(3).repeat_interleave(20)
    payload = TensorDict(
        {"episode": episode_ids, "obs": torch.arange(60)},
        batch_size=[60],
    )
    storage = LazyTensorStorage(60)
    sampler = SliceSampler(
        replacement=False,
        slice_len=5,
        traj_key="episode",
        strict_length=True,
    )
    buffer = ReplayBuffer(storage=storage, sampler=sampler, batch_size=10)
    buffer.extend(payload)

    batch = buffer.sample()
    # batch_size=10 with slice_len=5 gives 2 slices of 5 steps each.
    obs = batch["obs"].view(2, 5)
    steps = obs[..., 1:] - obs[..., :-1]
    assert (steps == 1).all(), obs


class TestStalenessAwareSampler:
"""Tests for StalenessAwareSampler."""
Expand Down
83 changes: 78 additions & 5 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import textwrap
import warnings
from abc import ABC, abstractmethod
from abc import ABC, ABCMeta, abstractmethod
from collections import OrderedDict
from copy import copy, deepcopy
from multiprocessing.context import get_spawning_popen
Expand Down Expand Up @@ -55,7 +55,48 @@
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."


class Sampler(ABC):
# Maps a "with replacement" sampler class to its "without replacement" counterpart.
# Populated at module import time after the relevant classes are defined.
# Consumed by :class:`_SamplerMeta` to dispatch ``Cls(replacement=False, ...)`` calls
# to ``_REPLACEMENT_DISPATCH[Cls](...)``.
_REPLACEMENT_DISPATCH: dict[type, type] = {}


class _SamplerMeta(ABCMeta):
    """Metaclass enabling ``replacement=False`` dispatch on with-replacement samplers.

    When a class registered as a key of :data:`_REPLACEMENT_DISPATCH` is
    instantiated with ``replacement=False``, the call is redirected to the
    class stored as its value, with all other positional and keyword
    arguments forwarded unchanged.

    The ``replacement`` keyword is always removed before the constructor
    runs, so existing ``__init__`` signatures don't need to accept it. With
    ``replacement=True`` (or when the keyword is absent) instantiation
    behaves like a normal call.

    Passing ``replacement=False`` to a class that is a dispatch target, or a
    subclass of one, is accepted and treated as a no-op: such a class already
    samples without replacement. Any other class raises :class:`TypeError`.
    Dispatch itself is keyed on the exact class, so subclasses of a registered
    with-replacement sampler are never redirected.
    """

    def __call__(cls, *args, **kwargs):
        """Instantiate ``cls``, honouring an optional ``replacement`` keyword.

        Keyword Args:
            replacement (bool, optional): popped from ``kwargs`` before any
                constructor is called. A falsy value requests the
                without-replacement behaviour. Defaults to ``True``.

        Returns:
            An instance of ``cls``, or of its registered counterpart when
            ``replacement`` is falsy and ``cls`` is a key of
            :data:`_REPLACEMENT_DISPATCH`.

        Raises:
            TypeError: if ``replacement`` is falsy and ``cls`` is neither
                registered for dispatch nor a (subclass of a) dispatch target.
        """
        if not kwargs.pop("replacement", True):
            # Exact-class lookup: subclasses of a registered sampler must
            # not be silently swapped for an unrelated class.
            target = _REPLACEMENT_DISPATCH.get(cls)
            if target is not None:
                return target(*args, **kwargs)
            # Subclasses of a without-replacement sampler are themselves
            # without-replacement samplers, so the check uses issubclass
            # rather than identity.
            without_replacement = tuple(_REPLACEMENT_DISPATCH.values())
            if not issubclass(cls, without_replacement):
                raise TypeError(
                    f"{cls.__name__} has no without-replacement variant; "
                    "cannot be instantiated with replacement=False."
                )
        return super().__call__(*args, **kwargs)


class Sampler(ABC, metaclass=_SamplerMeta):
"""A generic sampler base class for composable Replay Buffers."""

# Some samplers - mainly those without replacement -
Expand Down Expand Up @@ -133,9 +174,23 @@ def __getstate__(self):
class RandomSampler(Sampler):
"""A uniformly random sampler for composable replay buffers.

Args:
batch_size (int, optional): if provided, the batch size to be used by
the replay buffer when calling :meth:`ReplayBuffer.sample`.
Keyword Args:
replacement (bool, optional): if ``False``, the call is dispatched to
:class:`SamplerWithoutReplacement`, and any additional keyword
arguments (e.g. ``drop_last``, ``shuffle``) are forwarded to its
constructor. Defaults to ``True``.

Examples:
>>> from torchrl.data import RandomSampler, SamplerWithoutReplacement
>>> isinstance(RandomSampler(), RandomSampler)
True
>>> isinstance(RandomSampler(replacement=False), SamplerWithoutReplacement)
True
>>> isinstance(
... RandomSampler(replacement=False, drop_last=True),
... SamplerWithoutReplacement,
... )
True

"""

Expand Down Expand Up @@ -1154,12 +1209,19 @@ class SliceSampler(Sampler):

This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
Equivalently, ``SliceSampler(replacement=False, ...)`` dispatches to
:class:`SliceSamplerWithoutReplacement` and forwards the remaining keyword
arguments (including ``drop_last`` and ``shuffle``).

.. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
its execution, prefer using `end_key` over `traj_key`, and consider the following
keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.

Keyword Args:
replacement (bool, optional): if ``False``, the call is dispatched to
:class:`SliceSamplerWithoutReplacement` (which accepts the same
keyword arguments as well as ``drop_last`` and ``shuffle``).
Defaults to ``True``.
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
with ``slice_len``.
Expand Down Expand Up @@ -3173,3 +3235,14 @@ def __len__(self):
def __repr__(self):
    """Return ``ClassName(`` followed by the indented ``samplers=...`` text and ``)``."""
    # Every line of the samplers description is shifted right by four spaces.
    inner = textwrap.indent(f"samplers={self._samplers}", "    ")
    cls_name = self.__class__.__name__
    return cls_name + "(\n" + inner + ")"


# Register the without-replacement dispatch targets. Once this module is
# imported, ``RandomSampler(replacement=False)`` dispatches to
# ``SamplerWithoutReplacement`` and ``SliceSampler(replacement=False)`` to
# ``SliceSamplerWithoutReplacement``.
_REPLACEMENT_DISPATCH[RandomSampler] = SamplerWithoutReplacement
_REPLACEMENT_DISPATCH[SliceSampler] = SliceSamplerWithoutReplacement
Loading