diff --git a/test/rb/test_her.py b/test/rb/test_her.py new file mode 100644 index 00000000000..1a1e56f2aea --- /dev/null +++ b/test/rb/test_her.py @@ -0,0 +1,542 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Tests for HERReplayBuffer and HindsightStrategy. + +These tests previously lived in the monolithic ``test/test_rb.py``. +They moved to a dedicated file when the rb test suite was split into +``test/rb/`` upstream. +""" +from __future__ import annotations + +import argparse + +import pytest +import torch +from tensordict import TensorDict + + +class TestHERReplayBuffer: + """Tests for HERReplayBuffer and HindsightStrategy.""" + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + + @staticmethod + def _make_goal_env_data(n_steps: int, goal_dim: int = 3, obs_dim: int = 4): + """Return a TensorDict mimicking a goal-conditioned env rollout.""" + torch.manual_seed(0) + # All transitions belong to a single episode; last step is done. + done = torch.zeros(n_steps, 1, dtype=torch.bool) + done[-1] = True + terminated = torch.zeros(n_steps, 1, dtype=torch.bool) + terminated[-1] = True + + desired_goal = torch.randn(n_steps, goal_dim) + achieved_goal = torch.randn(n_steps, goal_dim) + + return TensorDict( + { + "observation": torch.randn(n_steps, obs_dim), + "desired_goal": desired_goal, + "achieved_goal": achieved_goal, + "action": torch.randn(n_steps, 2), + "next": { + "observation": torch.randn(n_steps, obs_dim), + "desired_goal": desired_goal, + "achieved_goal": achieved_goal, + "reward": torch.zeros(n_steps, 1), + "done": done, + "terminated": terminated, + }, + }, + batch_size=[n_steps], + ) + + @staticmethod + def _sparse_reward_fn(td: TensorDict) -> torch.Tensor: + dist = (td["achieved_goal"] - td["desired_goal"]).norm(dim=-1, keepdim=True) + return (dist < 0.5).float() + + def _make_rb(self, n_steps=20, **kwargs): + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(1000), + batch_size=n_steps, + **kwargs, + ) + data = self._make_goal_env_data(n_steps) + rb.extend(data) + return rb + + # ------------------------------------------------------------------ + # basic API + # ------------------------------------------------------------------ + + def test_import(self): + from torchrl.data import HERReplayBuffer, HindsightStrategy # noqa: F401 + + def test_invalid_her_ratio(self): + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + with pytest.raises(ValueError, match="her_ratio"): + HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(100), + her_ratio=1.5, + ) + + def test_sample_shape(self): + rb = self._make_rb(n_steps=20, her_ratio=0.8) + batch = rb.sample() + assert batch.batch_size == torch.Size([20]) + + def test_her_ratio_zero_unchanged(self): + """her_ratio=0 must return data with the original stored goals.""" + rb = self._make_rb(n_steps=20, her_ratio=0.0) + batch, info = rb.sample(return_info=True) + idx = info["index"] + stored = rb._storage.get(idx) + torch.testing.assert_close(batch["desired_goal"], stored["desired_goal"]) + + # ------------------------------------------------------------------ + # strategy correctness + # ------------------------------------------------------------------ + + 
@pytest.mark.parametrize("strategy", ["future", "final", "episode", "random"]) + def test_strategies_run(self, strategy): + """All four strategies must produce a valid batch without error.""" + rb = self._make_rb(n_steps=20, strategy=strategy, her_ratio=0.8) + batch = rb.sample() + assert batch.batch_size == torch.Size([20]) + + def test_final_strategy_uses_last_achieved(self): + """FINAL strategy: relabeled goal == achieved_goal of the last step.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 10 + data = self._make_goal_env_data(n) + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=1.0, + strategy="final", + ) + rb.extend(data) + batch = rb.sample() + + # All relabeled goals must equal the achieved_goal of the last transition. + last_achieved = data["next", "achieved_goal"][-1] # shape [goal_dim] + for i in range(n): + torch.testing.assert_close(batch["desired_goal"][i], last_achieved) + + def test_future_goal_not_from_past(self): + """FUTURE strategy: goal source index must be >= the sampled index.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 30 + data = self._make_goal_env_data(n) + + # Tag each achieved_goal with a unique step index so we can trace + # which step was used as goal source. + step_ids = torch.arange(n, dtype=torch.float).unsqueeze(1).expand(n, 3) + data["achieved_goal"] = step_ids.clone() + data["next", "achieved_goal"] = step_ids.clone() + # desired_goal starts as all-zeros so we can detect relabeling. + data["desired_goal"] = torch.zeros(n, 3) + data["next", "desired_goal"] = torch.zeros(n, 3) + + rb = HERReplayBuffer( + reward_fn=lambda td: torch.zeros(td.batch_size[0], 1), + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=1.0, + strategy="future", + ) + rb.extend(data) + + # info["index"] gives us the storage indices that were sampled. + batch, info = rb.sample(return_info=True) + sampled_idx = info["index"] + + n_her = n # her_ratio=1.0 + for i in range(n_her): + sampled_step = sampled_idx[i].item() + # The relabeled goal is the step ID of the goal source. + goal_step = batch["desired_goal"][i][0].item() + assert goal_step >= sampled_step, ( + f"FUTURE goal at storage idx {sampled_step} came from " + f"earlier step {goal_step}" + ) + + # ------------------------------------------------------------------ + # reward recomputation + # ------------------------------------------------------------------ + + def test_reward_recomputed_for_her_transitions(self): + """Relabeled transitions must have reward recomputed by reward_fn.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 20 + data = self._make_goal_env_data(n) + # Set all stored rewards to a sentinel value (-99) so we can detect + # which ones were recomputed. 
+ sentinel = -99.0 + data["next", "reward"] = torch.full((n, 1), sentinel) + + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=0.8, + ) + rb.extend(data) + batch = rb.sample() + + n_her = int(n * 0.8) + # HER slice: reward must not be the sentinel (was recomputed) + assert not (batch["next", "reward"][:n_her] == sentinel).all() + # Non-HER slice: reward remains as stored + assert (batch["next", "reward"][n_her:] == sentinel).all() + + # ------------------------------------------------------------------ + # multi-episode correctness + # ------------------------------------------------------------------ + + def test_multi_episode_final_stays_within_episode(self): + """FINAL strategy: each relabeled goal must come from the correct episode.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + ep_lens = [5, 8, 7] + n = sum(ep_lens) + done = torch.zeros(n, 1, dtype=torch.bool) + terminated = torch.zeros(n, 1, dtype=torch.bool) + # Mark episode ends + ends = [4, 12, 19] # 0-indexed last step of each episode + for e in ends: + done[e] = True + terminated[e] = True + + # Tag achieved_goal with the episode index so we can verify + episode_ids = torch.zeros(n, dtype=torch.long) + episode_ids[5:13] = 1 + episode_ids[13:] = 2 + achieved = episode_ids.float().unsqueeze(1).expand(n, 3).clone() + + data = TensorDict( + { + "observation": torch.randn(n, 4), + "desired_goal": torch.zeros(n, 3), + "achieved_goal": achieved, + "action": torch.randn(n, 2), + "next": { + "observation": torch.randn(n, 4), + "desired_goal": torch.zeros(n, 3), + "achieved_goal": achieved, + "reward": torch.zeros(n, 1), + "done": done, + "terminated": terminated, + }, + }, + batch_size=[n], + ) + + rb = HERReplayBuffer( + reward_fn=lambda td: torch.zeros(*td.batch_size, 1), + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=1.0, + strategy="final", + ) + rb.extend(data) + + # Run multiple times to average over randomness in index selection + for _ in range(10): + batch, info = rb.sample(return_info=True) + sampled_idx = info["index"] + for i in range(n): + sid = sampled_idx[i].item() + src_ep = int(episode_ids[sid].item()) + relabeled_ep = int(batch["desired_goal"][i][0].item()) + assert src_ep == relabeled_ep, ( + f"Transition from ep {src_ep} (idx {sid}) got goal " + f"from ep {relabeled_ep}" + ) + + # ------------------------------------------------------------------ + # cache invalidation + # ------------------------------------------------------------------ + + def test_cache_rebuilds_after_extend(self): + """Episode cache must reflect new data after extend.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(200), + batch_size=10, + ) + data1 = self._make_goal_env_data(10) + rb.extend(data1) + # Force a cache build by sampling once + rb.sample() + key1 = rb._last_cache_key + + data2 = self._make_goal_env_data(10) + rb.extend(data2) + # Sample again — cache must rebuild because storage changed + rb.sample() + assert rb._last_cache_key != key1, "Cache key should change after extend" + + def test_cache_rebuilds_after_add(self): + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(200), + batch_size=5, + ) + data = self._make_goal_env_data(5) + rb.extend(data) + rb.sample() + key1 = rb._last_cache_key + + single = 
self._make_goal_env_data(1) + rb.add(single[0]) + rb.sample() + assert rb._last_cache_key != key1 + + # ------------------------------------------------------------------ + # HindsightStrategy enum + # ------------------------------------------------------------------ + + def test_strategy_accepts_string(self): + from torchrl.data import HERReplayBuffer, HindsightStrategy, LazyTensorStorage + + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(100), + strategy="future", + ) + assert rb.strategy is HindsightStrategy.FUTURE + + def test_strategy_invalid(self): + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + with pytest.raises(ValueError): + HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(100), + strategy="invalid_strategy", + ) + + # ------------------------------------------------------------------ + # EPISODE strategy stays within episode + # ------------------------------------------------------------------ + + def test_episode_strategy_stays_within_episode(self): + """EPISODE strategy: goal source must lie within the same episode.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 30 + data = self._make_goal_env_data(n) + step_ids = torch.arange(n, dtype=torch.float).unsqueeze(1).expand(n, 3).clone() + data["achieved_goal"] = step_ids + data["next", "achieved_goal"] = step_ids + data["desired_goal"] = torch.zeros(n, 3) + + rb = HERReplayBuffer( + reward_fn=lambda td: torch.zeros(td.batch_size[0], 1), + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=1.0, + strategy="episode", + ) + rb.extend(data) + + for _ in range(5): + batch = rb.sample() + goal_step_ids = batch["desired_goal"][:, 0] + assert (goal_step_ids >= 0).all() + assert (goal_step_ids <= n - 1).all() + + # ------------------------------------------------------------------ + # her_ratio=1.0 — full batch relabeled + # ------------------------------------------------------------------ + + def test_her_ratio_one_full_relabel(self): + """her_ratio=1.0: every transition must be relabeled.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 15 + data = self._make_goal_env_data(n) + sentinel = -99.0 + data["next", "reward"] = torch.full((n, 1), sentinel) + + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=1.0, + ) + rb.extend(data) + batch = rb.sample() + assert not (batch["next", "reward"] == sentinel).any() + + # ------------------------------------------------------------------ + # custom reward_key + # ------------------------------------------------------------------ + + def test_custom_reward_key(self): + """reward_key parameter controls where recomputed reward is written.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 10 + done = torch.zeros(n, 1, dtype=torch.bool) + done[-1] = True + data = TensorDict( + { + "observation": torch.randn(n, 4), + "desired_goal": torch.randn(n, 3), + "achieved_goal": torch.randn(n, 3), + "action": torch.randn(n, 2), + "shaped_reward": torch.full((n, 1), -99.0), + "next": { + "observation": torch.randn(n, 4), + "desired_goal": torch.randn(n, 3), + "achieved_goal": torch.randn(n, 3), + "done": done, + }, + }, + batch_size=[n], + ) + rb = HERReplayBuffer( + reward_fn=lambda td: torch.ones(td.batch_size[0], 1), + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=1.0, + reward_key="shaped_reward", + ) + rb.extend(data) + batch = rb.sample() + assert 
(batch["shaped_reward"] == 1.0).all() + + # ------------------------------------------------------------------ + # missing key validation + # ------------------------------------------------------------------ + + def test_missing_goal_key_raises(self): + """Clear KeyError when goal_key is absent from storage.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 5 + done = torch.zeros(n, 1, dtype=torch.bool) + done[-1] = True + data = TensorDict( + { + "observation": torch.randn(n, 4), + "achieved_goal": torch.randn(n, 3), + "action": torch.randn(n, 2), + "next": {"reward": torch.zeros(n, 1), "done": done}, + }, + batch_size=[n], + ) + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(100), + batch_size=n, + ) + rb.extend(data) + with pytest.raises(KeyError, match="goal_key"): + rb.sample() + + # ------------------------------------------------------------------ + # repr + # ------------------------------------------------------------------ + + def test_repr(self): + rb = self._make_rb(n_steps=10) + r = repr(rb) + assert "HERReplayBuffer" in r + assert "future" in r + assert "desired_goal" in r + + # ------------------------------------------------------------------ + # state_dict / load_state_dict + # ------------------------------------------------------------------ + + def test_state_dict_round_trip(self): + """state_dict must preserve the episode-boundary cache.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + rb = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(200), + batch_size=10, + ) + rb.extend(self._make_goal_env_data(10)) + rb.sample() + sd = rb.state_dict() + assert "_her" in sd + assert sd["_her"]["episode_ends_cache"] is not None + + rb2 = HERReplayBuffer( + reward_fn=self._sparse_reward_fn, + storage=LazyTensorStorage(200), + batch_size=10, + ) + rb2.load_state_dict(sd) + assert rb2._last_cache_key == rb._last_cache_key + + # ------------------------------------------------------------------ + # nested goal keys + # ------------------------------------------------------------------ + + def test_nested_goal_key(self): + """goal_key and achieved_goal_key can be nested tuples.""" + from torchrl.data import HERReplayBuffer, LazyTensorStorage + + n = 10 + done = torch.zeros(n, 1, dtype=torch.bool) + done[-1] = True + data = TensorDict( + { + "obs": { + "desired_goal": torch.randn(n, 3), + "achieved_goal": torch.randn(n, 3), + "pixels": torch.randn(n, 4), + }, + "action": torch.randn(n, 2), + "next": { + "obs": { + "desired_goal": torch.randn(n, 3), + "achieved_goal": torch.randn(n, 3), + }, + "reward": torch.zeros(n, 1), + "done": done, + }, + }, + batch_size=[n], + ) + rb = HERReplayBuffer( + reward_fn=lambda td: torch.zeros(*td.batch_size, 1), + storage=LazyTensorStorage(100), + batch_size=n, + her_ratio=1.0, + goal_key=("obs", "desired_goal"), + achieved_goal_key=("obs", "achieved_goal"), + ) + rb.extend(data) + # Should not raise + batch = rb.sample() + assert batch.batch_size == torch.Size([n]) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index ffbda6ac08f..216d73d4cd5 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -39,6 +39,8 @@ H5Combine, H5Split, H5StorageCheckpointer, + HERReplayBuffer, + HindsightStrategy, ImmutableDatasetWriter, LazyMemmapStorage, 
LazyStackStorage, @@ -118,6 +120,8 @@ "H5Combine", "H5Split", "H5StorageCheckpointer", + "HERReplayBuffer", + "HindsightStrategy", "HashToInt", "History", "ImmutableDatasetWriter", diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index e6566a0f9df..0a9a7fa2a68 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -13,6 +13,7 @@ StorageEnsembleCheckpointer, TensorStorageCheckpointer, ) +from .her import HERReplayBuffer, HindsightStrategy from .ray_buffer import RayReplayBuffer from .replay_buffers import ( PrioritizedReplayBuffer, @@ -55,6 +56,8 @@ ) __all__ = [ + "HERReplayBuffer", + "HindsightStrategy", "CompressedListStorage", "CompressedListStorageCheckpointer", "FlatStorageCheckpointer", diff --git a/torchrl/data/replay_buffers/her.py b/torchrl/data/replay_buffers/her.py new file mode 100644 index 00000000000..147d848769f --- /dev/null +++ b/torchrl/data/replay_buffers/her.py @@ -0,0 +1,463 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Hindsight Experience Replay (HER) replay buffer. + +Reference: Andrychowicz et al., "Hindsight Experience Replay," NeurIPS 2017. +https://arxiv.org/abs/1707.01495 +""" +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, Literal + +import torch +from tensordict.base import TensorDictBase +from tensordict.utils import NestedKey + +from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import RandomSampler, Sampler +from torchrl.data.replay_buffers.storages import Storage + +try: + from enum import StrEnum # Python >= 3.11 +except ImportError: + from enum import Enum + + class StrEnum(str, Enum): # noqa: D101 - backport of enum.StrEnum + pass + + +class HindsightStrategy(StrEnum): + """Goal selection strategy for :class:`HERReplayBuffer`. + + Attributes: + FUTURE: sample a goal from a future transition in the same episode + (recommended; strongest empirically). + FINAL: use the final achieved state of the episode as the goal. + EPISODE: sample any achieved state from the same episode uniformly. + RANDOM: sample a random achieved state from the entire buffer. + """ + + FUTURE = "future" + FINAL = "final" + EPISODE = "episode" + RANDOM = "random" + + +class HERReplayBuffer(TensorDictReplayBuffer): + """Hindsight Experience Replay (HER) replay buffer. + + Applies goal relabeling at sample time for goal-conditioned RL. For a + fraction ``her_ratio`` of each sampled batch, the desired goal is + replaced with an achieved goal drawn from the same episode (or from the + full buffer when ``strategy="random"``), and the reward is recomputed + via ``reward_fn``. The remaining ``1 - her_ratio`` fraction is returned + with the original goals and rewards. + + Episode boundaries are detected from ``end_key`` (done flag) and the + corresponding ``terminated`` flag in storage. At least one must be + present; if neither is found a :class:`KeyError` is raised at sample time. + See :ref:`Environment-API` for the semantics of ``"done"`` / + ``"terminated"`` / ``"truncated"``. + + .. note:: + Episode-boundary detection currently assumes a 1D storage. If the + buffer is full and a trajectory spans the wrap-around point, the + write cursor is treated as a synthetic episode boundary, so goals + will not be sampled across the wrap. 
``nonzero()`` on the done + signal is a CPU/GPU sync that breaks ``torch.compile`` graphs; + there is no efficient way around this for sparse boundary detection. + + Args: + reward_fn (Callable[[TensorDictBase], Tensor]): receives a tensordict + containing the *relabeled* transition (with the new + ``goal_key`` already set) and must return a reward tensor with + shape ``(*batch, 1)`` or ``(*batch,)``. + + Keyword Args: + her_ratio (float): fraction of sampled transitions to relabel. Must + be in ``[0, 1]``. Default: ``0.8``. + strategy (HindsightStrategy or str): one of ``"future"``, + ``"final"``, ``"episode"``, ``"random"``. Default: ``"future"``. + goal_key (NestedKey): key for the desired goal. Default: + ``"desired_goal"``. + achieved_goal_key (NestedKey): key for the achieved goal. Default: + ``"achieved_goal"``. + reward_key (NestedKey): key where the reward is stored. Defaults to + ``("next", "reward")`` which is the TorchRL convention. + end_key (NestedKey): key indicating episode boundaries (done flag). + The corresponding ``terminated`` key is derived by replacing the + last component with ``"terminated"``. Default: ``("next", "done")``. + sampler (Sampler, optional): index sampler used to draw transitions + (any standard :class:`~torchrl.data.Sampler` works). The HER + relabeling logic is applied on top of the indices returned by + this sampler. Defaults to :class:`~torchrl.data.RandomSampler`. + **kwargs: forwarded to :class:`~torchrl.data.TensorDictReplayBuffer`. + + Example: + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.data import HERReplayBuffer, LazyMemmapStorage + >>> + >>> def reward_fn(td): + ... dist = (td["achieved_goal"] - td["desired_goal"]).norm(dim=-1, keepdim=True) + ... return (dist < 0.05).float() + >>> + >>> rb = HERReplayBuffer( + ... reward_fn=reward_fn, + ... storage=LazyMemmapStorage(10_000), + ... batch_size=256, + ... ) + >>> # Store several transitions, marking the last one as done + >>> for i in range(4): + ... td = TensorDict( + ... { + ... "observation": torch.randn(4), + ... "desired_goal": torch.zeros(3), + ... "achieved_goal": torch.randn(3), + ... "action": torch.randn(2), + ... "next": { + ... "observation": torch.randn(4), + ... "desired_goal": torch.zeros(3), + ... "achieved_goal": torch.randn(3), + ... "reward": torch.zeros(1), + ... "done": torch.tensor([i == 3]), + ... }, + ... }, + ... batch_size=[], + ... ) + ... 
_ = rb.add(td) + >>> batch = rb.sample(4) + >>> batch["desired_goal"].shape + torch.Size([4, 3]) + >>> batch[("next", "reward")].shape + torch.Size([4, 1]) + """ + + def __init__( + self, + reward_fn: Callable[[TensorDictBase], torch.Tensor], + *, + her_ratio: float = 0.8, + strategy: HindsightStrategy + | Literal["future", "final", "episode", "random"] = "future", + goal_key: NestedKey = "desired_goal", + achieved_goal_key: NestedKey = "achieved_goal", + reward_key: NestedKey = ("next", "reward"), + end_key: NestedKey = ("next", "done"), + sampler: Sampler | None = None, + **kwargs: Any, + ) -> None: + if not 0.0 <= her_ratio <= 1.0: + raise ValueError(f"her_ratio must be in [0, 1], got {her_ratio}") + if sampler is None: + sampler = RandomSampler() + super().__init__(sampler=sampler, **kwargs) + self.reward_fn = reward_fn + self.her_ratio = her_ratio + self.strategy = HindsightStrategy(strategy) + self.end_key = end_key + self.goal_key = goal_key + self.achieved_goal_key = achieved_goal_key + self.reward_key = reward_key + self._keys_validated: bool = False + self._episode_ends_cache: torch.Tensor | None = None + self._last_cache_key: tuple = (-1, None, False) + + # -- key helpers --------------------------------------------------------- + + def _terminated_key(self) -> NestedKey: + if isinstance(self.end_key, tuple): + return self.end_key[:-1] + ("terminated",) + return "terminated" + + @staticmethod + def _next_key(key: NestedKey) -> tuple: + if isinstance(key, tuple): + return ("next",) + key + return ("next", key) + + # -- episode cache ------------------------------------------------------- + + def _storage_cache_key(self, storage: Storage) -> tuple: + return ( + len(storage), + getattr(storage, "_last_cursor", None), + bool(getattr(storage, "_is_full", False)), + ) + + def _get_episode_ends(self, storage: Storage) -> torch.Tensor: + key = self._storage_cache_key(storage) + if key != self._last_cache_key: + self._episode_ends_cache = self._build_episode_cache(storage) + self._last_cache_key = key + return self._episode_ends_cache + + def _build_episode_cache(self, storage: Storage) -> torch.Tensor: + n = len(storage) + if n == 0: + return torch.zeros(0, dtype=torch.long) + + # HER assumes a 1D storage; multi-dim storages would require flattening + # logic that is not currently in scope. + if getattr(storage, "ndim", 1) != 1: + raise NotImplementedError( + "HERReplayBuffer currently only supports 1D storages " + f"(got ndim={storage.ndim})." + ) + + # Fetch only the done/terminated fields rather than the whole storage. + # For a TensorDict-backed storage, ``storage[:][key]`` materialises a + # single field, not the full transition data. + try: + view = storage[:] + except Exception as e: + raise RuntimeError( + "HERReplayBuffer could not read from storage to compute " + "episode boundaries." + ) from e + + done = view.get(self.end_key, None) if hasattr(view, "get") else None + terminated = ( + view.get(self._terminated_key(), None) if hasattr(view, "get") else None + ) + + if done is None and terminated is None: + raise KeyError( + f"Neither {self.end_key!r} nor {self._terminated_key()!r} were " + "found in storage. Ensure episode boundaries are stored under " + "one of these keys. See " + "https://pytorch.org/rl/main/reference/envs.html#environment-api " + "for the semantics of done / terminated / truncated." 
+ ) + + if done is not None and terminated is not None: + boundary = done.bool() | terminated.bool() + elif done is not None: + boundary = done.bool() + else: + boundary = terminated.bool() + + # Move bookkeeping to CPU: sample indices live on CPU and the boundary + # tensor is a single binary signal so the transfer is cheap even for + # CUDA storages. ``nonzero`` is a sync op and intentionally breaks the + # compile graph here; there isn't a sparse-boundary alternative. + boundary = boundary.to(device="cpu").reshape(-1)[:n] + + episode_ends = boundary.nonzero(as_tuple=True)[0] + + is_full = bool(getattr(storage, "_is_full", False)) + if is_full: + # When the storage has wrapped around, the write cursor delimits + # the last in-progress trajectory. Treat it as a synthetic + # boundary so goal sampling does not cross the wrap. + cursor = getattr(storage, "_last_cursor", None) + cursor_idx = _resolve_cursor(cursor) + if cursor_idx is not None: + cursor_idx = cursor_idx % n + if cursor_idx not in episode_ends.tolist(): + episode_ends = torch.cat( + [ + episode_ends, + torch.tensor( + [cursor_idx], + dtype=episode_ends.dtype, + device=episode_ends.device, + ), + ] + ) + episode_ends, _ = torch.sort(episode_ends) + else: + # Bound the last (possibly partial) episode at n-1. + if len(episode_ends) == 0 or int(episode_ends[-1]) != n - 1: + episode_ends = torch.cat( + [ + episode_ends, + torch.tensor( + [n - 1], + dtype=episode_ends.dtype, + device=episode_ends.device, + ), + ] + ) + return episode_ends + + # -- episode lookup ------------------------------------------------------ + + def _get_episode_range( + self, + idx: torch.Tensor, + episode_ends: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Both idx and episode_ends live on CPU (see _build_episode_cache). 
+ idx_cpu = idx.to(episode_ends.device) + ep_end_pos = torch.searchsorted(episode_ends, idx_cpu, right=False).clamp( + max=len(episode_ends) - 1 + ) + ep_end = episode_ends[ep_end_pos] + prev_pos = (ep_end_pos - 1).clamp(min=0) + ep_start = torch.where( + ep_end_pos == 0, + torch.zeros_like(ep_end), + episode_ends[prev_pos] + 1, + ) + return ep_start, ep_end + + def _sample_goal_indices( + self, + her_idx: torch.Tensor, + ep_starts: torch.Tensor, + ep_ends: torch.Tensor, + storage_len: int, + ) -> torch.Tensor: + n = len(her_idx) + device = her_idx.device + strategy = self.strategy + + if strategy == HindsightStrategy.FUTURE: + span = (ep_ends - her_idx).clamp(min=0).float() + offsets = ( + (torch.rand(n, device=device) * (span + 1)) + .long() + .clamp(max=span.long()) + ) + return her_idx + offsets + + if strategy == HindsightStrategy.FINAL: + return ep_ends.clone() + + if strategy == HindsightStrategy.EPISODE: + span = (ep_ends - ep_starts + 1).float() + offsets = (torch.rand(n, device=device) * span).long() + return ep_starts + offsets + + if strategy == HindsightStrategy.RANDOM: + return torch.randint(storage_len, (n,), device=device) + + raise ValueError(f"Unknown strategy: {strategy!r}") + + # -- validation ---------------------------------------------------------- + + def _validate_storage_keys(self) -> None: + if self._keys_validated or len(self._storage) == 0: + return + sample = self._storage.get(torch.tensor([0])) + missing = [] + for key, role in ( + (self.goal_key, "goal_key"), + (self.achieved_goal_key, "achieved_goal_key"), + ): + if sample.get(key, None) is None: + missing.append(f" {role}={key!r}") + if missing: + raise KeyError( + "The following keys are not present in storage:\n" + + "\n".join(missing) + + "\nEnsure they are stored in every transition before sampling." 
+ ) + self._keys_validated = True + + # -- sampling ------------------------------------------------------------ + + def _sample(self, batch_size: int) -> tuple[Any, dict]: + self._validate_storage_keys() + data, info = super()._sample(batch_size) + + episode_ends = self._get_episode_ends(self._storage) + if len(episode_ends) == 0: + return data, info + + idx = info["index"] + if isinstance(idx, tuple): + idx = idx[0] + if not isinstance(idx, torch.Tensor): + idx = torch.as_tensor(idx) + + n = idx.shape[0] + n_her = int(n * self.her_ratio) + if n_her == 0: + return data, info + + storage_idx = idx[:n_her] + ep_starts, ep_ends = self._get_episode_range(storage_idx, episode_ends) + goal_src_idx = self._sample_goal_indices( + storage_idx, ep_starts, ep_ends, len(self._storage) + ) + + goal_src_tds = self._storage.get(goal_src_idx) + achieved_goals = goal_src_tds.get(self.achieved_goal_key) + + her_slice = data[:n_her] + with data.unlock_(): + her_slice.set_(self.goal_key, achieved_goals) + next_goal_key = self._next_key(self.goal_key) + if her_slice.get(next_goal_key, None) is not None: + her_slice.set_(next_goal_key, achieved_goals) + new_rewards = self.reward_fn(her_slice) + her_slice.set_(self.reward_key, new_rewards) + + return data, info + + # -- empty / checkpoint -------------------------------------------------- + + def empty(self, empty_write_count: bool = True) -> None: + super().empty(empty_write_count=empty_write_count) + self._episode_ends_cache = None + self._last_cache_key = (-1, None, False) + + def state_dict(self) -> dict: + sd = super().state_dict() + sd["_her"] = { + "goal_key": self.goal_key, + "achieved_goal_key": self.achieved_goal_key, + "reward_key": self.reward_key, + "her_ratio": self.her_ratio, + "strategy": self.strategy.value, + "episode_ends_cache": self._episode_ends_cache, + "last_cache_key": self._last_cache_key, + } + return sd + + def load_state_dict(self, state_dict: dict) -> None: + her_state = state_dict.pop("_her", None) + super().load_state_dict(state_dict) + if her_state is not None: + self._episode_ends_cache = her_state.get("episode_ends_cache") + self._last_cache_key = her_state.get("last_cache_key", (-1, None, False)) + + # -- repr --------------------------------------------------------------- + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"strategy={self.strategy.value!r}, " + f"her_ratio={self.her_ratio}, " + f"goal_key={self.goal_key!r}, " + f"achieved_goal_key={self.achieved_goal_key!r}, " + f"reward_key={self.reward_key!r}, " + f"storage={self._storage})" + ) + + +def _resolve_cursor(cursor: Any) -> int | None: + """Coerce a storage's ``_last_cursor`` attribute into a single int.""" + if cursor is None: + return None + if isinstance(cursor, torch.Tensor): + if cursor.numel() == 0: + return None + return int(cursor.flatten()[-1].item()) + if isinstance(cursor, range): + if len(cursor) == 0: + return None + return int(cursor[-1]) + if isinstance(cursor, (list, tuple)): + if not cursor: + return None + return int(cursor[-1]) + try: + return int(cursor) + except (TypeError, ValueError): + return None
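
For reference, a minimal standalone sketch of the FUTURE-strategy index arithmetic that the buffer implements in ``_get_episode_range`` and ``_sample_goal_indices``: for each sampled transition ``t`` the goal source is drawn uniformly from ``[t, episode_end(t)]`` and the reward is recomputed against the relabeled goal. The tensor names, the two-episode layout and the 0.5 distance threshold below are illustrative assumptions, not part of the buffer's API.

import torch

torch.manual_seed(0)

# Two episodes stored back to back: steps 0-4 and 5-11.
episode_ends = torch.tensor([4, 11])       # index of the last step of each episode
sampled_idx = torch.tensor([1, 3, 7, 11])  # transitions drawn by the index sampler

# Locate the episode end of each sampled transition (same searchsorted trick
# as _get_episode_range).
pos = torch.searchsorted(episode_ends, sampled_idx).clamp(max=len(episode_ends) - 1)
ep_end = episode_ends[pos]

# FUTURE: uniform offset in [0, ep_end - t], so the goal source never lies in the past.
span = (ep_end - sampled_idx).float()
offset = (torch.rand(len(sampled_idx)) * (span + 1)).long().clamp(max=span.long())
goal_src = sampled_idx + offset
assert (goal_src >= sampled_idx).all() and (goal_src <= ep_end).all()

# Relabel the desired goal and recompute a sparse reward, as reward_fn would.
achieved = torch.randn(12, 3)              # achieved_goal of every stored step
relabeled_goal = achieved[goal_src]
dist = (achieved[sampled_idx] - relabeled_goal).norm(dim=-1, keepdim=True)
reward = (dist < 0.5).float()              # shape [4, 1]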
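
Likewise, a standalone sketch of the episode-boundary bookkeeping described in the class docstring note: boundaries come from ``nonzero()`` on the done signal, and when a full storage has wrapped around, the write cursor is inserted as a synthetic boundary so goals are never sampled across the wrap. The done pattern and cursor value are made up for illustration.

import torch

done = torch.tensor([0, 0, 1, 0, 0, 0, 1, 0], dtype=torch.bool)  # per-step done flags
episode_ends = done.nonzero(as_tuple=True)[0]                     # tensor([2, 6])

cursor = 4  # hypothetical write position of a full, wrapped storage
if cursor not in episode_ends.tolist():
    episode_ends = torch.sort(
        torch.cat([episode_ends, torch.tensor([cursor])])
    ).values                                                      # tensor([2, 4, 6])
# With the synthetic boundary at index 4, no relabeled goal crosses the wrap point.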