From 5c1207154d8223a7a2efa08ceea8e48fec1657a6 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 13 May 2026 21:57:17 +0100
Subject: [PATCH] [Feature] Dispatch RandomSampler/SliceSampler to
 without-replacement variant via replacement=False

A metaclass on Sampler intercepts the ``replacement`` keyword: when False,
``RandomSampler(...)`` is dispatched to ``SamplerWithoutReplacement(...)``
and ``SliceSampler(...)`` to ``SliceSamplerWithoutReplacement(...)``.
Additional kwargs (``drop_last``, ``shuffle``) are forwarded to the variant
constructor.

Samplers without a registered variant raise ``TypeError`` when invoked
with ``replacement=False``. The without-replacement variants themselves,
and any subclass of them, accept ``replacement=False`` as a no-op.

Co-Authored-By: Claude Opus 4.7
---
Note: this patch was edited by hand after export; the blob ids on the
``index`` lines no longer match the post-image.

 test/test_rb.py                         | 132 ++++++++++++++++++++++++
 torchrl/data/replay_buffers/samplers.py |  86 +++++++++++++++++++++++++++-
 2 files changed, 213 insertions(+), 5 deletions(-)

diff --git a/test/test_rb.py b/test/test_rb.py
index 13778eac894..ccaa553b84b 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -62,6 +62,7 @@
     PrioritizedSampler,
     PrioritizedSliceSampler,
     RandomSampler,
+    Sampler,
     SamplerEnsemble,
     SamplerWithoutReplacement,
     SliceSampler,
@@ -4692,6 +4693,137 @@ def test_prb_ndim(self):
             (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all()
         )
 
+    def test_replacement_kwarg_random(self):
+        # RandomSampler(replacement=True) is a regular RandomSampler
+        s = RandomSampler()
+        assert type(s) is RandomSampler
+        s = RandomSampler(replacement=True)
+        assert type(s) is RandomSampler
+
+        # RandomSampler(replacement=False) dispatches to SamplerWithoutReplacement
+        s = RandomSampler(replacement=False)
+        assert type(s) is SamplerWithoutReplacement
+        # default kwargs propagate
+        assert s.drop_last is False
+        assert s.shuffle is True
+
+        # Extra kwargs are forwarded to SamplerWithoutReplacement
+        s = RandomSampler(replacement=False, drop_last=True, shuffle=False)
+        assert type(s) is SamplerWithoutReplacement
+        assert s.drop_last is True
+        assert s.shuffle is False
+
+        # isinstance is preserved
+        assert isinstance(s, Sampler)
+        assert isinstance(s, SamplerWithoutReplacement)
+
+    def test_replacement_kwarg_slice(self):
+        # SliceSampler(replacement=True) is a regular SliceSampler
+        s = SliceSampler(slice_len=5)
+        assert type(s) is SliceSampler
+        s = SliceSampler(replacement=True, slice_len=5)
+        assert type(s) is SliceSampler
+
+        # SliceSampler(replacement=False) dispatches to SliceSamplerWithoutReplacement
+        s = SliceSampler(replacement=False, slice_len=5)
+        assert type(s) is SliceSamplerWithoutReplacement
+        assert s.slice_len == 5
+        assert s.drop_last is False
+        assert s.shuffle is True
+
+        # Extra without-replacement kwargs forward correctly
+        s = SliceSampler(
+            replacement=False,
+            slice_len=5,
+            drop_last=True,
+            shuffle=False,
+            traj_key="episode",
+            strict_length=False,
+        )
+        assert type(s) is SliceSamplerWithoutReplacement
+        assert s.slice_len == 5
+        assert s.drop_last is True
+        assert s.shuffle is False
+        assert s.traj_key == "episode"
+        assert s.strict_length is False
+
+        # isinstance preserves the SliceSampler hierarchy
+        assert isinstance(s, SliceSampler)
+        assert isinstance(s, SamplerWithoutReplacement)
+
+    def test_replacement_kwarg_subclass_unaffected(self):
+        # PrioritizedSliceSampler inherits from SliceSampler but should NOT dispatch
+        s = PrioritizedSliceSampler(
+            slice_len=5, max_capacity=10, alpha=0.5, beta=0.5
+        )
+        assert type(s) is PrioritizedSliceSampler
+
+        # SamplerWithoutReplacement(replacement=...) is a no-op pop
+        s = SamplerWithoutReplacement(replacement=False, drop_last=True)
+        assert type(s) is SamplerWithoutReplacement
+        assert s.drop_last is True
+        s = SliceSamplerWithoutReplacement(replacement=False, slice_len=5)
+        assert type(s) is SliceSamplerWithoutReplacement
+        assert s.slice_len == 5
+
+        # User-defined subclasses of a without-replacement variant accept it too
+        class _CustomSampler(SamplerWithoutReplacement):
+            pass
+
+        s = _CustomSampler(replacement=False)
+        assert type(s) is _CustomSampler
+
+    def test_replacement_kwarg_no_variant_errors(self):
+        # PrioritizedSampler has no without-replacement variant -> TypeError
+        with pytest.raises(TypeError, match="no without-replacement variant"):
+            PrioritizedSampler(
+                max_capacity=10, alpha=0.5, beta=0.5, replacement=False
+            )
+
+    def test_replacement_kwarg_in_replay_buffer(self):
+        # End-to-end: a buffer using RandomSampler(replacement=False) should
+        # exhaust the storage without duplicate indices (like SamplerWithoutReplacement).
+        torch.manual_seed(0)
+        data = TensorDict({"a": torch.arange(11)}, batch_size=[11])
+        rb = ReplayBuffer(
+            storage=LazyTensorStorage(11),
+            sampler=RandomSampler(replacement=False, drop_last=False),
+            batch_size=3,
+        )
+        rb.extend(data)
+        seen = set()
+        for _ in range(4):
+            seen.update(rb.sample()["a"].tolist())
+        assert seen == set(range(11))
+
+    def test_replacement_kwarg_slice_in_replay_buffer(self):
+        # End-to-end: SliceSampler(replacement=False) returns sub-trajectories
+        torch.manual_seed(0)
+        episodes = torch.zeros(60, dtype=torch.long)
+        episodes[:20] = 0
+        episodes[20:40] = 1
+        episodes[40:] = 2
+        data = TensorDict(
+            {"episode": episodes, "obs": torch.arange(60)},
+            batch_size=[60],
+        )
+        rb = ReplayBuffer(
+            storage=LazyTensorStorage(60),
+            sampler=SliceSampler(
+                replacement=False,
+                slice_len=5,
+                traj_key="episode",
+                strict_length=True,
+            ),
+            batch_size=10,
+        )
+        rb.extend(data)
+        sample = rb.sample()
+        # batch_size=10, slice_len=5 -> 2 slices of 5 contiguous obs each
+        obs = sample["obs"].view(2, 5)
+        diffs = obs[:, 1:] - obs[:, :-1]
+        assert (diffs == 1).all(), obs
+
 
 class TestStalenessAwareSampler:
     """Tests for StalenessAwareSampler."""
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index dd051cf5f58..56f946d39dc 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -7,7 +7,7 @@
 import json
 import textwrap
 import warnings
-from abc import ABC, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from collections import OrderedDict
 from copy import copy, deepcopy
 from multiprocessing.context import get_spawning_popen
@@ -55,7 +55,51 @@
 _EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
 
 
-class Sampler(ABC):
+# Maps a "with replacement" sampler class to its "without replacement" counterpart.
+# Populated at module import time after the relevant classes are defined.
+# Consumed by :class:`_SamplerMeta` to dispatch ``Cls(replacement=False, ...)`` calls
+# to ``_REPLACEMENT_DISPATCH[Cls](...)``.
+_REPLACEMENT_DISPATCH: dict[type, type] = {}
+
+
+class _SamplerMeta(ABCMeta):
+    """Metaclass enabling ``replacement=False`` dispatch on with-replacement samplers.
+
+    When a class registered in :data:`_REPLACEMENT_DISPATCH` (e.g.
+    :class:`RandomSampler`, :class:`SliceSampler`) is instantiated with
+    ``replacement=False``, the call is dispatched to its without-replacement
+    counterpart (:class:`SamplerWithoutReplacement` or
+    :class:`SliceSamplerWithoutReplacement`).
+
+    Calls with ``replacement=True`` (the default) behave exactly like a normal
+    instantiation: the ``replacement`` kwarg is popped before the constructor
+    runs, so existing ``__init__`` signatures don't need to be changed.
+
+    Passing ``replacement=False`` to a sampler that has no without-replacement
+    variant raises :class:`TypeError`. Passing ``replacement=False`` to a
+    without-replacement variant, or to a subclass of one, is allowed and
+    treated as a no-op.
+    """
+
+    def __call__(cls, *args, **kwargs):
+        # Always strip ``replacement`` so it never reaches an ``__init__``.
+        if "replacement" in kwargs:
+            replacement = kwargs.pop("replacement")
+            if not replacement:
+                alt = _REPLACEMENT_DISPATCH.get(cls)
+                if alt is not None:
+                    return alt(*args, **kwargs)
+                # ``issubclass`` rather than membership, so that user-defined
+                # subclasses of a without-replacement variant are accepted too.
+                if not issubclass(cls, tuple(_REPLACEMENT_DISPATCH.values())):
+                    raise TypeError(
+                        f"{cls.__name__} has no without-replacement variant; "
+                        "cannot be instantiated with replacement=False."
+                    )
+        return super().__call__(*args, **kwargs)
+
+
+class Sampler(ABC, metaclass=_SamplerMeta):
     """A generic sampler base class for composable Replay Buffers."""
 
     # Some samplers - mainly those without replacement -
@@ -133,9 +177,23 @@ def __getstate__(self):
 class RandomSampler(Sampler):
     """A uniformly random sampler for composable replay buffers.
 
-    Args:
-        batch_size (int, optional): if provided, the batch size to be used by
-            the replay buffer when calling :meth:`ReplayBuffer.sample`.
+    Keyword Args:
+        replacement (bool, optional): if ``False``, the call is dispatched to
+            :class:`SamplerWithoutReplacement`, and any additional keyword
+            arguments (e.g. ``drop_last``, ``shuffle``) are forwarded to its
+            constructor. Defaults to ``True``.
+
+    Examples:
+        >>> from torchrl.data import RandomSampler, SamplerWithoutReplacement
+        >>> isinstance(RandomSampler(), RandomSampler)
+        True
+        >>> isinstance(RandomSampler(replacement=False), SamplerWithoutReplacement)
+        True
+        >>> isinstance(
+        ...     RandomSampler(replacement=False, drop_last=True),
+        ...     SamplerWithoutReplacement,
+        ... )
+        True
 
     """
 
@@ -1154,12 +1212,19 @@ class SliceSampler(Sampler):
 
     This class samples sub-trajectories with replacement. For a version without
     replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
+    Equivalently, ``SliceSampler(replacement=False, ...)`` dispatches to
+    :class:`SliceSamplerWithoutReplacement` and forwards the remaining keyword
+    arguments (including ``drop_last`` and ``shuffle``).
 
     .. note:: `SliceSampler` can be slow to retrieve the trajectory indices. To accelerate
         its execution, prefer using `end_key` over `traj_key`, and consider the following
         keyword arguments: :attr:`compile`, :attr:`cache_values` and :attr:`use_gpu`.
 
     Keyword Args:
+        replacement (bool, optional): if ``False``, the call is dispatched to
+            :class:`SliceSamplerWithoutReplacement` (which accepts the same
+            keyword arguments as well as ``drop_last`` and ``shuffle``).
+            Defaults to ``True``.
         num_slices (int): the number of slices to be sampled. The batch-size
             must be greater or equal to the ``num_slices`` argument. Exclusive
             with ``slice_len``.
@@ -3173,3 +3238,14 @@ def __len__(self):
     def __repr__(self):
         samplers = textwrap.indent(f"samplers={self._samplers}", " " * 4)
         return f"{self.__class__.__name__}(\n{samplers})"
+
+
+# Register without-replacement dispatch targets. Importing this module makes
+# ``RandomSampler(replacement=False)`` dispatch to ``SamplerWithoutReplacement``
+# and ``SliceSampler(replacement=False)`` to ``SliceSamplerWithoutReplacement``.
+_REPLACEMENT_DISPATCH.update(
+    {
+        RandomSampler: SamplerWithoutReplacement,
+        SliceSampler: SliceSamplerWithoutReplacement,
+    }
+)