diff --git a/.github/labeler.yml b/.github/labeler.yml index 5fbcec3e732..b0197f8c69b 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -152,9 +152,7 @@ - changed-files: - any-glob-to-any-file: - 'torchrl/modules/**' - - 'test/test_modules.py' - - 'test/test_tensordictmodules.py' - - 'test/test_actors.py' + - 'test/modules/**' "Transforms": - changed-files: @@ -183,8 +181,7 @@ - changed-files: - any-glob-to-any-file: - 'torchrl/data/replay_buffers/**' - - 'test/test_rb.py' - - 'test/test_storage_map.py' + - 'test/rb/**' "Services": - changed-files: diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 0d301a0b3be..8d6475f42c0 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -344,9 +344,9 @@ run_distributed_tests() { local json_report_dir="${RUNNER_ARTIFACT_DIR:-${root_dir}}" local json_report_args="--json-report --json-report-file=${json_report_dir}/test-results-distributed.json --json-report-indent=2" - # Run both test_distributed.py and test_rb_distributed.py (both use torch.distributed) + # Run both test/test_distributed.py and test/rb/test_rb_distributed.py (both use torch.distributed) # Note: distributed tests always run on GPU, no need for GPU_MARKER_FILTER here - python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py test/test_rb_distributed.py \ + python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py test/rb/test_rb_distributed.py \ ${json_report_args} \ --instafail --durations 200 -vv --capture no \ --timeout=120 --mp_fork_if_no_cuda @@ -362,7 +362,7 @@ run_non_distributed_tests() { # - Shard 2: test/envs/, test_collectors.py (multiprocessing-heavy) # - Shard 3: Everything else (can use pytest-xdist for parallelism) local shard="${TORCHRL_TEST_SHARD:-all}" - local common_ignores="--ignore test/test_rlhf.py --ignore test/test_distributed.py --ignore 
test/test_rb_distributed.py --ignore test/llm --ignore test/test_setup.py" + local common_ignores="--ignore test/test_rlhf.py --ignore test/test_distributed.py --ignore test/rb/test_rb_distributed.py --ignore test/llm --ignore test/test_setup.py" local common_args="--instafail --durations 200 -vv --capture no --timeout=120 --mp_fork_if_no_cuda" # JSON report output for flaky test tracking diff --git a/setup.cfg b/setup.cfg index bb68b316728..617fc382724 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,8 +15,7 @@ per-file-ignores = test/smoke_test_deps.py: F401 test_*.py: F841, E731, E266 test/opengl_rendering.py: F401 - test/test_modules.py: F841, E731, E266, TOR101 - test/test_tensordictmodules.py: F841, E731, E266, TOR101 + test/modules/test_*.py: F841, E731, E266, TOR101 torchrl/objectives/cql.py: TOR101 torchrl/objectives/deprecated.py: TOR101 torchrl/objectives/iql.py: TOR101 diff --git a/test/modules/_modules_common.py b/test/modules/_modules_common.py new file mode 100644 index 00000000000..c6d6ea98b71 --- /dev/null +++ b/test/modules/_modules_common.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib.util +import sys + +import torch +from packaging import version + +_has_transformers = importlib.util.find_spec("transformers") is not None +_has_vllm = importlib.util.find_spec("vllm") is not None + +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) +IS_WINDOWS = sys.platform == "win32" + + +def _has_triton_backend() -> bool: + """Mirror of the triton-availability check inside the RNN backend. + + Triton must be installed, CUDA must be available, and the Triton build + must expose the ``triton.language.extra.libdevice`` submodule + (Triton >= 2.2). 
Older Triton installations are routed to scan/pad + backends, so the triton-specific tests are skipped there. + """ + if importlib.util.find_spec("triton") is None or not torch.cuda.is_available(): + return False + return importlib.util.find_spec("triton.language.extra.libdevice") is not None + + +_has_triton = _has_triton_backend() +_triton_skip_reason = "requires triton (>= 2.2) and CUDA" + +_has_functorch = False +try: + try: + from torch import vmap as vmap # noqa: F401 + except ImportError: + from functorch import vmap as vmap # noqa: F401 + + _has_functorch = True +except ImportError: + pass diff --git a/test/modules/conftest.py b/test/modules/conftest.py new file mode 100644 index 00000000000..8d74c0c6081 --- /dev/null +++ b/test/modules/conftest.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import pytest +import torch + + +@pytest.fixture +def double_prec_fixture(): + dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + yield + torch.set_default_dtype(dtype) diff --git a/test/modules/test_actor.py b/test/modules/test_actor.py new file mode 100644 index 00000000000..36e2948715f --- /dev/null +++ b/test/modules/test_actor.py @@ -0,0 +1,726 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import argparse + +import numpy as np +import pytest +import torch +from _modules_common import _has_transformers +from tensordict import NonTensorData, TensorDict +from tensordict.nn import CompositeDistribution, TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor + +from torch import distributions as dist, nn + +from torchrl.data import Bounded, Composite +from torchrl.envs import CatFrames, Compose, InitTracker, SerialEnv, TransformedEnv +from torchrl.modules import ( + DiffusionActor, + MultiStepActorWrapper, + ProbabilisticActor, + SafeModule, + TanhDelta, + TanhModule, + TanhNormal, + ValueOperator, +) +from torchrl.modules.distributions.utils import safeatanh, safetanh +from torchrl.modules.models import NoisyLazyLinear, NoisyLinear +from torchrl.modules.tensordict_module.actors import ( + ActorValueOperator, + LMHeadActorValueOperator, +) + +from torchrl.testing import get_default_devices +from torchrl.testing.mocking_classes import CountingEnv, NestedCountingEnv + + +@pytest.mark.parametrize( + "log_prob_key", + [ + None, + "sample_log_prob", + ("nested", "sample_log_prob"), + ("data", "sample_log_prob"), + ], +) +def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=1): + env = NestedCountingEnv(nested_dim=nested_dim) + action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) + policy_module = TensorDictModule( + nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] + ) + policy = ProbabilisticActor( + module=policy_module, + spec=action_spec, + in_keys=[("data", "param")], + out_keys=[("data", "action")], + distribution_class=TanhDelta, + distribution_kwargs={ + "low": action_spec.space.low, + "high": action_spec.space.high, + }, + log_prob_key=log_prob_key, + return_log_prob=True, + ) + + td = env.reset() + td["data", "states"] = td["data", "states"].to(torch.float) + td_out = policy(td) + assert td_out["data", 
"action"].shape == (5, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (5,) + else: + assert td_out["data", "action_log_prob"].shape == (5,) + + policy = ProbabilisticActor( + module=policy_module, + spec=action_spec, + in_keys={"param": ("data", "param")}, + out_keys=[("data", "action")], + distribution_class=TanhDelta, + distribution_kwargs={ + "low": action_spec.space.low, + "high": action_spec.space.high, + }, + log_prob_key=log_prob_key, + return_log_prob=True, + ) + td_out = policy(td) + assert td_out["data", "action"].shape == (5, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (5,) + else: + assert td_out["data", "action_log_prob"].shape == (5,) + + +@pytest.mark.parametrize( + "log_prob_key", + [ + None, + "sample_log_prob", + ("nested", "sample_log_prob"), + ("data", "sample_log_prob"), + ], +) +def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions=3): + env = NestedCountingEnv(nested_dim=nested_dim) + action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) + actor_net = nn.Sequential( + nn.Linear(1, 2), + NormalParamExtractor(), + ) + policy_module = TensorDictModule( + actor_net, + in_keys=[("data", "states")], + out_keys=[("data", "loc"), ("data", "scale")], + ) + policy = ProbabilisticActor( + module=policy_module, + spec=action_spec, + in_keys=[("data", "loc"), ("data", "scale")], + out_keys=[("data", "action")], + distribution_class=TanhNormal, + distribution_kwargs={ + "low": action_spec.space.low, + "high": action_spec.space.high, + }, + log_prob_key=log_prob_key, + return_log_prob=True, + ) + + td = env.reset() + td["data", "states"] = td["data", "states"].to(torch.float) + td_out = policy(td) + assert td_out["data", "action"].shape == (5, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (5,) + else: + assert td_out["data", "action_log_prob"].shape == (5,) + + policy = ProbabilisticActor( + module=policy_module, + spec=action_spec, + in_keys={"loc": ("data", 
"loc"), "scale": ("data", "scale")}, + out_keys=[("data", "action")], + distribution_class=TanhNormal, + distribution_kwargs={ + "low": action_spec.space.low, + "high": action_spec.space.high, + }, + log_prob_key=log_prob_key, + return_log_prob=True, + ) + td_out = policy(td) + assert td_out["data", "action"].shape == (5, 1) + if log_prob_key: + assert td_out[log_prob_key].shape == (5,) + else: + assert td_out["data", "action_log_prob"].shape == (5,) + + +class TestProbabilisticActorGenerator: + """Tests for the ``generator`` kwarg on ``ProbabilisticActor``. + + The actual sampling logic lives in ``tensordict.nn`` and is exhaustively tested there; + these tests just verify the kwarg threads through ``ProbabilisticActor`` → + ``SafeProbabilisticModule`` → ``ProbabilisticTensorDictModule``. + """ + + @staticmethod + def _make_actor(generator=None): + module = TensorDictModule( + lambda x: (x, torch.ones_like(x)), + in_keys=["obs"], + out_keys=["loc", "scale"], + ) + return ProbabilisticActor( + module=module, + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=dist.Normal, + default_interaction_type="random", + generator=generator, + ) + + def test_generator_object(self): + """Two same-seeded Generators must produce identical actions.""" + a1 = self._make_actor(torch.Generator().manual_seed(0)) + a2 = self._make_actor(torch.Generator().manual_seed(0)) + # Set the global RNG to a different state to make sure it's not consulted. 
+ torch.manual_seed(999) + s1 = a1(TensorDict(obs=torch.zeros(4)))["action"].clone() + s2 = a2(TensorDict(obs=torch.zeros(4)))["action"].clone() + assert torch.equal(s1, s2) + + def test_generator_int_seed(self): + """Module-level int is shorthand for ``Generator().manual_seed(int)``.""" + a_int = self._make_actor(generator=0) + a_gen = self._make_actor(generator=torch.Generator().manual_seed(0)) + s_int = a_int(TensorDict(obs=torch.zeros(4)))["action"].clone() + s_gen = a_gen(TensorDict(obs=torch.zeros(4)))["action"].clone() + assert torch.equal(s_int, s_gen) + + def test_generator_isolates_global_rng(self): + """Sampling with a generator must not advance the global RNG.""" + a = self._make_actor(torch.Generator().manual_seed(0)) + torch.manual_seed(1234) + before = torch.get_rng_state() + a(TensorDict(obs=torch.zeros(4))) + after = torch.get_rng_state() + assert torch.equal(before, after) + + def test_generator_advances_in_place(self): + a = self._make_actor(torch.Generator().manual_seed(0)) + s1 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() + s2 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() + assert not torch.equal(s1, s2) + + def test_generator_td_key_int_writeback(self): + """Int seed in the input tensordict is treated as a stream-key (JAX-style).""" + a = self._make_actor(generator="rng") + + def run(seed, n_steps): + td = TensorDict(obs=torch.zeros(4)) + td["rng"] = NonTensorData(seed) + samples = [] + for _ in range(n_steps): + samples.append(a(td)["action"].clone()) + return samples + + traj_a = run(42, 3) + traj_b = run(42, 3) + for x, y in zip(traj_a, traj_b): + assert torch.equal(x, y) + assert not torch.equal(traj_a[0], traj_a[1]) + + def test_generator_td_key_generator_form(self): + """A Generator placed in the input tensordict is used in place.""" + a = self._make_actor(generator="rng") + td = TensorDict(obs=torch.zeros(4)) + td["rng"] = NonTensorData(torch.Generator().manual_seed(0)) + s_key = a(td)["action"].clone() + a_ref = 
self._make_actor(torch.Generator().manual_seed(0)) + s_ref = a_ref(TensorDict(obs=torch.zeros(4)))["action"].clone() + assert torch.equal(s_key, s_ref) + + def test_generator_default_unchanged(self): + """generator=None preserves existing global-RNG behaviour.""" + a = self._make_actor(generator=None) + torch.manual_seed(0) + s1 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() + torch.manual_seed(0) + s2 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() + assert torch.equal(s1, s2) + + +@pytest.mark.parametrize( + "layer_class", + [ + NoisyLinear, + NoisyLazyLinear, + ], +) +@pytest.mark.parametrize("device", get_default_devices()) +def test_noisy(layer_class, device, seed=0): + torch.manual_seed(seed) + layer = layer_class(3, 4, device=device) + x = torch.randn(10, 3, device=device) + y1 = layer(x) + layer.reset_noise() + y2 = layer(x) + y3 = layer(x) + torch.testing.assert_close(y2, y3) + with pytest.raises(AssertionError): + torch.testing.assert_close(y1, y2) + + +class TestTanh: + def test_errors(self): + with pytest.raises( + ValueError, match="in_keys and out_keys should have the same length" + ): + TanhModule(in_keys=["a", "b"], out_keys=["a"]) + with pytest.raises(ValueError, match=r"The minimum value \(-2\) provided"): + spec = Bounded(-1, 1, shape=()) + TanhModule(in_keys=["act"], low=-2, spec=spec) + with pytest.raises(ValueError, match=r"The maximum value \(-2\) provided to"): + spec = Bounded(-1, 1, shape=()) + TanhModule(in_keys=["act"], high=-2, spec=spec) + with pytest.raises(ValueError, match="Got high < low"): + TanhModule(in_keys=["act"], high=-2, low=-1) + + def test_minmax(self): + mod = TanhModule( + in_keys=["act"], + high=2, + ) + assert isinstance(mod.act_high, torch.Tensor) + mod = TanhModule( + in_keys=["act"], + low=-2, + ) + assert isinstance(mod.act_low, torch.Tensor) + mod = TanhModule( + in_keys=["act"], + high=np.ones((1,)), + ) + assert isinstance(mod.act_high, torch.Tensor) + mod = TanhModule( + in_keys=["act"], + 
low=-np.ones((1,)), + ) + assert isinstance(mod.act_low, torch.Tensor) + + @pytest.mark.parametrize("clamp", [True, False]) + def test_boundaries(self, clamp): + torch.manual_seed(0) + eps = torch.finfo(torch.float).resolution + for _ in range(10): + min, max = (5 * torch.randn(2)).sort()[0] + mod = TanhModule(in_keys=["act"], low=min, high=max, clamp=clamp) + assert mod.non_trivial + td = TensorDict({"act": (2 * torch.rand(100) - 1) * 10}, []) + mod(td) + # we should have a good proportion of samples close to the boundaries + assert torch.isclose(td["act"], max).any() + assert torch.isclose(td["act"], min).any() + if not clamp: + assert (td["act"] <= max + eps).all() + assert (td["act"] >= min - eps).all() + else: + assert (td["act"] < max + eps).all() + assert (td["act"] > min - eps).all() + + @pytest.mark.parametrize("out_keys", [[("a", "c"), "b"], None]) + @pytest.mark.parametrize("has_spec", [[True, True], [True, False], [False, False]]) + def test_multi_inputs(self, out_keys, has_spec): + in_keys = [("x", "z"), "y"] + real_out_keys = out_keys if out_keys is not None else in_keys + + if any(has_spec): + spec = {} + if has_spec[0]: + spec.update({real_out_keys[0]: Bounded(-2.0, 2.0, shape=())}) + low, high = -2.0, 2.0 + if has_spec[1]: + spec.update({real_out_keys[1]: Bounded(-3.0, 3.0, shape=())}) + low, high = None, None + spec = Composite(spec) + else: + spec = None + low, high = -2.0, 2.0 + + mod = TanhModule( + in_keys=in_keys, + out_keys=out_keys, + low=low, + high=high, + spec=spec, + clamp=False, + ) + data = TensorDict({in_key: torch.randn(100) * 100 for in_key in in_keys}, []) + mod(data) + assert all(out_key in data.keys(True, True) for out_key in real_out_keys) + eps = torch.finfo(torch.float).resolution + + for out_key in real_out_keys: + key = out_key if isinstance(out_key, str) else "_".join(out_key) + low_key = f"{key}_low" + high_key = f"{key}_high" + min, max = getattr(mod, low_key), getattr(mod, high_key) + assert torch.isclose(data[out_key], 
max).any() + assert torch.isclose(data[out_key], min).any() + assert (data[out_key] <= max + eps).all() + assert (data[out_key] >= min - eps).all() + + +@pytest.mark.skipif(torch.__version__ < "2.0", reason="torch 2.0 is required") +@pytest.mark.parametrize("use_vmap", [False, True]) +@pytest.mark.parametrize("scale", range(10)) +def test_tanh_atanh(use_vmap, scale): + if use_vmap: + try: + from torch import vmap + except ImportError: + try: + from functorch import vmap + except ImportError: + raise pytest.skip("functorch not found") + + torch.manual_seed(0) + x = (torch.randn(10, dtype=torch.double) * scale).requires_grad_(True) + if not use_vmap: + y = safetanh(x, 1e-6) + else: + y = vmap(safetanh, (0, None))(x, 1e-6) + + if not use_vmap: + xp = safeatanh(y, 1e-6) + else: + xp = vmap(safeatanh, (0, None))(y, 1e-6) + + xp.sum().backward() + torch.testing.assert_close(x.grad, torch.ones_like(x)) + + +class TestDiffusionActor: + def test_output_shape(self): + actor = DiffusionActor(action_dim=2, obs_dim=3, num_steps=5) + td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4]) + td = actor(td) + assert td["action"].shape == torch.Size([4, 2]) + + def test_unbatched(self): + actor = DiffusionActor(action_dim=4, obs_dim=6, num_steps=3) + td = TensorDict({"observation": torch.randn(6)}, batch_size=[]) + td = actor(td) + assert td["action"].shape == torch.Size([4]) + + def test_custom_in_out_keys(self): + actor = DiffusionActor( + action_dim=2, + obs_dim=3, + num_steps=3, + in_keys=["obs"], + out_keys=["act"], + ) + assert actor.in_keys == ["obs"] + assert actor.out_keys == ["act"] + td = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4]) + td = actor(td) + assert td["act"].shape == torch.Size([4, 2]) + + def test_custom_score_network(self): + score_net = nn.Linear(2 + 3 + 1, 2) + actor = DiffusionActor( + action_dim=2, obs_dim=3, score_network=score_net, num_steps=3 + ) + td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4]) + td = 
actor(td) + assert td["action"].shape == torch.Size([4, 2]) + + def test_spec_wrapping(self): + spec = Bounded(low=-1.0, high=1.0, shape=(2,)) + actor = DiffusionActor(action_dim=2, obs_dim=3, num_steps=3, spec=spec) + assert actor.spec is not None + + def test_gradients_flow(self): + actor = DiffusionActor(action_dim=2, obs_dim=3, num_steps=3) + obs = torch.randn(4, 3) + td = TensorDict({"observation": obs}, batch_size=[4]) + td = actor(td) + td["action"].sum().backward() + for p in actor.parameters(): + assert p.grad is not None + + +@pytest.mark.parametrize("device", get_default_devices()) +def test_actorcritic(device): + common_module = SafeModule( + module=nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"], spec=None + ).to(device) + module = SafeModule(nn.Linear(4, 5), in_keys=["hidden"], out_keys=["param"]) + policy_operator = ProbabilisticActor( + module=module, in_keys=["param"], spec=None, return_log_prob=True + ).to(device) + value_operator = ValueOperator(nn.Linear(4, 1), in_keys=["hidden"]).to(device) + op = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_operator, + value_operator=value_operator, + ).to(device) + td = TensorDict( + source={"obs": torch.randn(4, 3)}, + batch_size=[ + 4, + ], + ).to(device) + td_total = op(td.clone()) + policy_op = op.get_policy_operator() + td_policy = policy_op(td.clone()) + value_op = op.get_value_operator() + td_value = value_op(td) + torch.testing.assert_close(td_total.get("action"), td_policy.get("action")) + torch.testing.assert_close( + td_total.get("sample_log_prob"), td_policy.get("sample_log_prob") + ) + torch.testing.assert_close(td_total.get("state_value"), td_value.get("state_value")) + + value_params = set( + list(op.get_value_operator().parameters()) + list(op.module[0].parameters()) + ) + value_params2 = set(value_op.parameters()) + assert len(value_params.difference(value_params2)) == 0 and len( + value_params.intersection(value_params2) + ) == len(value_params) + + 
policy_params = set( + list(op.get_policy_operator().parameters()) + list(op.module[0].parameters()) + ) + policy_params2 = set(policy_op.parameters()) + assert len(policy_params.difference(policy_params2)) == 0 and len( + policy_params.intersection(policy_params2) + ) == len(policy_params) + + +@pytest.mark.parametrize("name_map", [True, False]) +def test_compound_actor(name_map): + class Module(nn.Module): + def forward(self, x): + return x[..., :3], x[..., 3:6], x[..., 6:] + + module = TensorDictModule( + Module(), + in_keys=["x"], + out_keys=[ + ("params", "normal", "loc"), + ("params", "normal", "scale"), + ("params", "categ", "logits"), + ], + ) + distribution_kwargs = { + "distribution_map": {"normal": dist.Normal, "categ": dist.Categorical} + } + if name_map: + distribution_kwargs.update( + { + "name_map": { + "normal": ("action", "normal"), + "categ": ("action", "categ"), + }, + } + ) + actor = ProbabilisticActor( + module, + in_keys=["params"], + distribution_class=CompositeDistribution, + distribution_kwargs=distribution_kwargs, + ) + if not name_map: + assert actor.out_keys == module.out_keys + ["normal", "categ"] + else: + assert actor.out_keys == module.out_keys + [ + ("action", "normal"), + ("action", "categ"), + ] + + data = TensorDict({"x": torch.rand(10)}, []) + actor(data) + assert set(data.keys(True, True)) == { + "categ" if not name_map else ("action", "categ"), + "normal" if not name_map else ("action", "normal"), + ("params", "categ", "logits"), + ("params", "normal", "loc"), + ("params", "normal", "scale"), + "x", + } + + +@pytest.mark.skipif(not _has_transformers, reason="missing dependencies") +@pytest.mark.parametrize("device", get_default_devices()) +def test_lmhead_actorvalueoperator(device): + from transformers import AutoModelForCausalLM, GPT2Config + + config = GPT2Config(return_dict=False) + base_model = AutoModelForCausalLM.from_config(config).eval() + aco = LMHeadActorValueOperator(base_model).to(device) + + # check common + 
assert aco.module[0][0].module is base_model.transformer + assert aco.module[0][1].in_keys == ["x"] + assert aco.module[0][1].out_keys == ["x"] + + # check actor + assert aco.module[1].in_keys == ["x"] + assert aco.module[1].out_keys == ["logits", "action", "action_log_prob"] + assert aco.module[1][0].module is base_model.lm_head + + # check critic + assert aco.module[2].in_keys == ["x"] + assert aco.module[2].out_keys == ["state_value"] + assert isinstance(aco.module[2].module, nn.Linear) + assert aco.module[2].module.in_features == base_model.transformer.embed_dim + assert aco.module[2].module.out_features == 1 + + td = TensorDict( + source={ + "input_ids": torch.randint(50257, (4, 3)), + "attention_mask": torch.ones((4, 3)), + }, + batch_size=[ + 4, + ], + device=device, + ) + td_total = aco(td.clone()) + policy_op = aco.get_policy_operator() + td_policy = policy_op(td.clone()) + value_op = aco.get_value_operator() + td_value = value_op(td) + torch.testing.assert_close(td_total.get("action"), td_policy.get("action")) + torch.testing.assert_close( + td_total.get("sample_log_prob"), td_policy.get("sample_log_prob") + ) + torch.testing.assert_close(td_total.get("state_value"), td_value.get("state_value")) + + value_params = set( + list(aco.get_value_operator().parameters()) + list(aco.module[0].parameters()) + ) + value_params2 = set(value_op.parameters()) + assert len(value_params.difference(value_params2)) == 0 and len( + value_params.intersection(value_params2) + ) == len(value_params) + + policy_params = set( + list(aco.get_policy_operator().parameters()) + list(aco.module[0].parameters()) + ) + policy_params2 = set(policy_op.parameters()) + assert len(policy_params.difference(policy_params2)) == 0 and len( + policy_params.intersection(policy_params2) + ) == len(policy_params) + + +class TestBatchedActor: + def test_batched_actor_exceptions(self): + time_steps = 5 + actor_base = TensorDictModule( + lambda x: torch.ones( + x.shape[0], time_steps, 1, 
device=x.device, dtype=x.dtype + ), + in_keys=["observation_cat"], + out_keys=["action"], + ) + with pytest.raises(ValueError, match="Only a single init_key can be passed"): + MultiStepActorWrapper(actor_base, n_steps=time_steps, init_key=["init_key"]) + + batch = 2 + + # The second env has frequent resets, the first none + base_env = SerialEnv( + batch, + [lambda: CountingEnv(max_steps=5000), lambda: CountingEnv(max_steps=5)], + ) + env = TransformedEnv( + base_env, + CatFrames( + N=time_steps, + in_keys=["observation"], + out_keys=["observation_cat"], + dim=-1, + ), + ) + actor = MultiStepActorWrapper(actor_base, n_steps=time_steps) + with pytest.raises(KeyError, match="No init key was passed"): + env.rollout(2, actor) + + env = TransformedEnv( + base_env, + Compose( + InitTracker(), + CatFrames( + N=time_steps, + in_keys=["observation"], + out_keys=["observation_cat"], + dim=-1, + ), + ), + ) + td = env.rollout(10)[..., -1]["next"] + actor = MultiStepActorWrapper(actor_base, n_steps=time_steps) + with pytest.raises(RuntimeError, match="Cannot initialize the wrapper"): + env.rollout(10, actor, tensordict=td, auto_reset=False) + + actor = MultiStepActorWrapper(actor_base, n_steps=time_steps - 1) + with pytest.raises(RuntimeError, match="The action's time dimension"): + env.rollout(10, actor) + + @pytest.mark.parametrize("time_steps", [3, 5]) + def test_batched_actor_simple(self, time_steps): + + batch = 2 + + # The second env has frequent resets, the first none + base_env = SerialEnv( + batch, + [lambda: CountingEnv(max_steps=5000), lambda: CountingEnv(max_steps=5)], + ) + env = TransformedEnv( + base_env, + Compose( + InitTracker(), + CatFrames( + N=time_steps, + in_keys=["observation"], + out_keys=["observation_cat"], + dim=-1, + ), + ), + ) + + actor_base = TensorDictModule( + lambda x: torch.ones( + x.shape[0], time_steps, 1, device=x.device, dtype=x.dtype + ), + in_keys=["observation_cat"], + out_keys=["action"], + ) + actor = 
MultiStepActorWrapper(actor_base, n_steps=time_steps) + # rollout = env.rollout(100, break_when_any_done=False) + rollout = env.rollout(50, actor, break_when_any_done=False) + unique = rollout[0]["observation"].unique() + predicted = torch.arange(unique.numel()) + assert (unique == predicted).all() + assert ( + rollout[1]["observation"] + == (torch.arange(50) % 6).reshape_as(rollout[1]["observation"]) + ).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/modules/test_decision_transformer.py b/test/modules/test_decision_transformer.py new file mode 100644 index 00000000000..766700e87fd --- /dev/null +++ b/test/modules/test_decision_transformer.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import argparse + +import pytest +import torch +from _modules_common import _has_transformers +from tensordict import TensorDict, unravel_key_list +from tensordict.nn import TensorDictModule +from torchrl.modules import ( + DecisionTransformerInferenceWrapper, + DTActor, + OnlineDTActor, + ProbabilisticActor, + TanhDelta, + TanhNormal, +) +from torchrl.modules.models.decision_transformer import DecisionTransformer + + +@pytest.mark.skipif( + not _has_transformers, reason="transformers needed for TestDecisionTransformer" +) +class TestDecisionTransformer: + def test_init(self): + DecisionTransformer( + 3, + 4, + ) + with pytest.raises(TypeError): + DecisionTransformer(3, 4, config="some_str") + DecisionTransformer( + 3, + 4, + config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_exec(self, batch_dims, T=5): + observations = 
torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + model = DecisionTransformer( + 3, + 4, + config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + out = model(observations, actions, r2go) + assert out.shape == torch.Size([*batch_dims, T, 16]) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_dtactor(self, batch_dims, T=5): + dtactor = DTActor( + 3, + 4, + transformer_config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + out = dtactor(observations, actions, r2go) + assert out.shape == torch.Size([*batch_dims, T, 4]) + + @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) + def test_onlinedtactor(self, batch_dims, T=5): + dtactor = OnlineDTActor( + 3, + 4, + transformer_config=DecisionTransformer.DTConfig( + n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 + ), + ) + observations = torch.randn(*batch_dims, T, 3) + actions = torch.randn(*batch_dims, T, 4) + r2go = torch.randn(*batch_dims, T, 1) + mu, sig = dtactor(observations, actions, r2go) + assert mu.shape == torch.Size([*batch_dims, T, 4]) + assert sig.shape == torch.Size([*batch_dims, T, 4]) + assert (dtactor.log_std_min < sig.log()).all() + assert (dtactor.log_std_max > sig.log()).all() + + +@pytest.mark.skipif( + not _has_transformers, reason="transformers needed to test DT classes" +) +class TestDecisionTransformerInferenceWrapper: + @pytest.mark.parametrize("online", [True, False]) + def test_dt_inference_wrapper(self, online): + action_key = ("nested", ("action",)) + if online: + dtactor = OnlineDTActor( + state_dim=4, action_dim=2, transformer_config=DTActor.default_config() + ) + in_keys = ["loc", "scale"] + actor_module = TensorDictModule( + dtactor, + 
in_keys=["observation", action_key, "return_to_go"], + out_keys=in_keys, + ) + dist_class = TanhNormal + else: + dtactor = DTActor( + state_dim=4, action_dim=2, transformer_config=DTActor.default_config() + ) + in_keys = ["param"] + actor_module = TensorDictModule( + dtactor, + in_keys=["observation", action_key, "return_to_go"], + out_keys=in_keys, + ) + dist_class = TanhDelta + dist_kwargs = { + "low": -1.0, + "high": 1.0, + } + actor = ProbabilisticActor( + in_keys=in_keys, + out_keys=[action_key], + module=actor_module, + distribution_class=dist_class, + distribution_kwargs=dist_kwargs, + ) + inference_actor = DecisionTransformerInferenceWrapper(actor) + sequence_length = 20 + td = TensorDict( + { + "observation": torch.randn(1, sequence_length, 4), + action_key: torch.randn(1, sequence_length, 2), + "return_to_go": torch.randn(1, sequence_length, 1), + }, + [1], + ) + with pytest.raises( + ValueError, + match="The value of out_action_key", + ): + result = inference_actor(td) + inference_actor.set_tensor_keys(action=action_key, out_action=action_key) + result = inference_actor(td) + # checks that the seq length has disappeared + assert result.get(action_key).shape == torch.Size([1, 2]) + assert inference_actor.out_keys == unravel_key_list( + sorted([action_key, *in_keys, "observation", "return_to_go"], key=str) + ) + assert set(result.keys(True, True)) - set(td.keys(True, True)) == set( + inference_actor.out_keys + ) - set(inference_actor.in_keys) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/modules/test_dreamer_components.py b/test/modules/test_dreamer_components.py new file mode 100644 index 00000000000..deec83a7924 --- /dev/null +++ b/test/modules/test_dreamer_components.py @@ -0,0 +1,264 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import argparse + +import pytest +import torch +from packaging import version +from tensordict import TensorDict +from torchrl.data.tensor_specs import Bounded +from torchrl.modules import SafeModule +from torchrl.modules.models.model_based import ( + DreamerActor, + ObsDecoder, + ObsEncoder, + RSSMPosterior, + RSSMPrior, + RSSMRollout, +) + +from torchrl.testing import get_default_devices + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("batch_size", [[], [3], [5]]) +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse("1.11.0"), + reason="""Dreamer works with batches of null to 2 dimensions. Torch < 1.11 +requires one-dimensional batches (for RNN and Conv nets for instance). If you'd like +to see torch < 1.11 supported for dreamer, please submit an issue.""", +) +class TestDreamerComponents: + @pytest.mark.parametrize("out_features", [3, 5]) + @pytest.mark.parametrize("temporal_size", [[], [2], [4]]) + def test_dreamer_actor(self, device, batch_size, temporal_size, out_features): + actor = DreamerActor( + out_features, + ).to(device) + emb = torch.randn(*batch_size, *temporal_size, 15, device=device) + state = torch.randn(*batch_size, *temporal_size, 2, device=device) + loc, scale = actor(emb, state) + assert loc.shape == (*batch_size, *temporal_size, out_features) + assert scale.shape == (*batch_size, *temporal_size, out_features) + assert torch.all(scale > 0) + + @pytest.mark.parametrize("depth", [32, 64]) + @pytest.mark.parametrize("temporal_size", [[], [2], [4]]) + def test_dreamer_encoder(self, device, temporal_size, batch_size, depth): + encoder = ObsEncoder(channels=depth).to(device) + obs = torch.randn(*batch_size, *temporal_size, 3, 64, 64, device=device) + emb = encoder(obs) + assert emb.shape == (*batch_size, *temporal_size, 
depth * 8 * 4) + + @pytest.mark.parametrize("depth", [32, 64]) + @pytest.mark.parametrize("stoch_size", [10, 20]) + @pytest.mark.parametrize("deter_size", [20, 30]) + @pytest.mark.parametrize("temporal_size", [[], [2], [4]]) + def test_dreamer_decoder( + self, device, batch_size, temporal_size, depth, stoch_size, deter_size + ): + decoder = ObsDecoder(channels=depth).to(device) + stoch_state = torch.randn( + *batch_size, *temporal_size, stoch_size, device=device + ) + det_state = torch.randn(*batch_size, *temporal_size, deter_size, device=device) + obs = decoder(stoch_state, det_state) + assert obs.shape == (*batch_size, *temporal_size, 3, 64, 64) + + @pytest.mark.parametrize("depth", [32, 64]) + @pytest.mark.parametrize("out_channels", [1, 3]) + @pytest.mark.parametrize("stoch_size", [10]) + @pytest.mark.parametrize("deter_size", [20]) + def test_dreamer_decoder_out_channels( + self, device, batch_size, depth, out_channels, stoch_size, deter_size + ): + decoder = ObsDecoder(channels=depth, out_channels=out_channels).to(device) + stoch_state = torch.randn(*batch_size, stoch_size, device=device) + det_state = torch.randn(*batch_size, deter_size, device=device) + obs = decoder(stoch_state, det_state) + assert obs.shape == (*batch_size, out_channels, 64, 64) + + @pytest.mark.parametrize("stoch_size", [10, 20]) + @pytest.mark.parametrize("deter_size", [20, 30]) + @pytest.mark.parametrize("action_size", [3, 6]) + def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size): + action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) + rssm_prior = RSSMPrior( + action_spec, + hidden_dim=stoch_size, + rnn_hidden_dim=stoch_size, + state_dim=deter_size, + ).to(device) + state = torch.randn(*batch_size, deter_size, device=device) + action = torch.randn(*batch_size, action_size, device=device) + belief = torch.randn(*batch_size, stoch_size, device=device) + prior_mean, prior_std, next_state, belief = rssm_prior(state, belief, 
action) + assert prior_mean.shape == (*batch_size, deter_size) + assert prior_std.shape == (*batch_size, deter_size) + assert next_state.shape == (*batch_size, deter_size) + assert belief.shape == (*batch_size, stoch_size) + assert torch.all(prior_std > 0) + + @pytest.mark.parametrize("stoch_size", [10, 20]) + @pytest.mark.parametrize("deter_size", [20, 30]) + def test_rssm_posterior(self, device, batch_size, stoch_size, deter_size): + rssm_posterior = RSSMPosterior( + hidden_dim=stoch_size, + state_dim=deter_size, + ).to(device) + belief = torch.randn(*batch_size, stoch_size, device=device) + obs_emb = torch.randn(*batch_size, 1024, device=device) + # Init of lazy linears + _ = rssm_posterior(belief.clone(), obs_emb.clone()) + + torch.manual_seed(0) + posterior_mean, posterior_std, next_state = rssm_posterior( + belief.clone(), obs_emb.clone() + ) + assert posterior_mean.shape == (*batch_size, deter_size) + assert posterior_std.shape == (*batch_size, deter_size) + assert next_state.shape == (*batch_size, deter_size) + assert torch.all(posterior_std > 0) + + torch.manual_seed(0) + posterior_mean_bis, posterior_std_bis, next_state_bis = rssm_posterior( + belief.clone(), obs_emb.clone() + ) + assert torch.allclose(posterior_mean, posterior_mean_bis) + assert torch.allclose(posterior_std, posterior_std_bis) + assert torch.allclose(next_state, next_state_bis) + + @pytest.mark.parametrize("stoch_size", [10, 20]) + @pytest.mark.parametrize("deter_size", [20, 30]) + @pytest.mark.parametrize("temporal_size", [2, 4]) + @pytest.mark.parametrize("action_size", [3, 6]) + def test_rssm_rollout( + self, device, batch_size, temporal_size, stoch_size, deter_size, action_size + ): + action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) + rssm_prior = RSSMPrior( + action_spec, + hidden_dim=stoch_size, + rnn_hidden_dim=stoch_size, + state_dim=deter_size, + ).to(device) + rssm_posterior = RSSMPosterior( + hidden_dim=stoch_size, + state_dim=deter_size, + 
).to(device) + + rssm_rollout = RSSMRollout( + SafeModule( + rssm_prior, + in_keys=["state", "belief", "action"], + out_keys=[ + ("next", "prior_mean"), + ("next", "prior_std"), + "_", + ("next", "belief"), + ], + ), + SafeModule( + rssm_posterior, + in_keys=[("next", "belief"), ("next", "encoded_latents")], + out_keys=[ + ("next", "posterior_mean"), + ("next", "posterior_std"), + ("next", "state"), + ], + ), + ) + + state = torch.randn(*batch_size, temporal_size, deter_size, device=device) + belief = torch.randn(*batch_size, temporal_size, stoch_size, device=device) + action = torch.randn(*batch_size, temporal_size, action_size, device=device) + obs_emb = torch.randn(*batch_size, temporal_size, 1024, device=device) + + tensordict = TensorDict( + { + "state": state.clone(), + "action": action.clone(), + "next": { + "encoded_latents": obs_emb.clone(), + "belief": belief.clone(), + }, + }, + device=device, + batch_size=torch.Size([*batch_size, temporal_size]), + ) + ## Init of lazy linears + _ = rssm_rollout(tensordict.clone()) + torch.manual_seed(0) + rollout = rssm_rollout(tensordict) + assert rollout["next", "prior_mean"].shape == ( + *batch_size, + temporal_size, + deter_size, + ) + assert rollout["next", "prior_std"].shape == ( + *batch_size, + temporal_size, + deter_size, + ) + assert rollout["next", "state"].shape == ( + *batch_size, + temporal_size, + deter_size, + ) + assert rollout["next", "belief"].shape == ( + *batch_size, + temporal_size, + stoch_size, + ) + assert rollout["next", "posterior_mean"].shape == ( + *batch_size, + temporal_size, + deter_size, + ) + assert rollout["next", "posterior_std"].shape == ( + *batch_size, + temporal_size, + deter_size, + ) + assert torch.all(rollout["next", "prior_std"] > 0) + assert torch.all(rollout["next", "posterior_std"] > 0) + + state[..., 1:, :] = 0 + belief[..., 1:, :] = 0 + # Only the first state is used for the prior. 
The rest are recomputed + + tensordict_bis = TensorDict( + { + "state": state.clone(), + "action": action.clone(), + "next": {"encoded_latents": obs_emb.clone(), "belief": belief.clone()}, + }, + device=device, + batch_size=torch.Size([*batch_size, temporal_size]), + ) + torch.manual_seed(0) + rollout_bis = rssm_rollout(tensordict_bis) + + assert torch.allclose( + rollout["next", "prior_mean"], rollout_bis["next", "prior_mean"] + ), (rollout["next", "prior_mean"] - rollout_bis["next", "prior_mean"]).norm() + assert torch.allclose( + rollout["next", "prior_std"], rollout_bis["next", "prior_std"] + ) + assert torch.allclose(rollout["next", "state"], rollout_bis["next", "state"]) + assert torch.allclose(rollout["next", "belief"], rollout_bis["next", "belief"]) + assert torch.allclose( + rollout["next", "posterior_mean"], rollout_bis["next", "posterior_mean"] + ) + assert torch.allclose( + rollout["next", "posterior_std"], rollout_bis["next", "posterior_std"] + ) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/modules/test_mlp_conv.py b/test/modules/test_mlp_conv.py new file mode 100644 index 00000000000..7cedfd66ed9 --- /dev/null +++ b/test/modules/test_mlp_conv.py @@ -0,0 +1,338 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import argparse +from numbers import Number + +import pytest +import torch +from torch import nn +from torchrl.modules.models import BatchRenorm1d, Conv3dNet, ConvNet, MLP, NoisyLinear +from torchrl.modules.models.recipes.impala import _ConvNetBlock +from torchrl.modules.models.utils import SquashDims + +from torchrl.testing import get_default_devices + + +class TestMLP: + @pytest.mark.parametrize("in_features", [3, 10, None]) + @pytest.mark.parametrize("out_features", [3, (3, 10)]) + @pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))]) + @pytest.mark.parametrize( + "activation_class, activation_kwargs", + [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], + ) + @pytest.mark.parametrize( + "norm_class, norm_kwargs", + [ + (nn.LazyBatchNorm1d, {}), + (nn.BatchNorm1d, {"num_features": 32}), + (nn.LayerNorm, {"normalized_shape": 32}), + ], + ) + @pytest.mark.parametrize("dropout", [0.0, 0.5]) + @pytest.mark.parametrize("bias_last_layer", [True, False]) + @pytest.mark.parametrize("single_bias_last_layer", [True, False]) + @pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_mlp( + self, + in_features, + out_features, + depth, + num_cells, + activation_class, + activation_kwargs, + dropout, + bias_last_layer, + norm_class, + norm_kwargs, + single_bias_last_layer, + layer_class, + device, + seed=0, + ): + torch.manual_seed(seed) + batch = 2 + mlp = MLP( + in_features=in_features, + out_features=out_features, + depth=depth, + num_cells=num_cells, + activation_class=activation_class, + activation_kwargs=activation_kwargs, + norm_class=norm_class, + norm_kwargs=norm_kwargs, + dropout=dropout, + bias_last_layer=bias_last_layer, + single_bias_last_layer=False, + layer_class=layer_class, + device=device, + ) + if in_features is None: + in_features = 5 + x = torch.randn(batch, in_features, device=device) + y = mlp(x) + 
out_features = ( + [out_features] if isinstance(out_features, Number) else out_features + ) + assert y.shape == torch.Size([batch, *out_features]) + + def test_kwargs(self): + def make_activation(shift): + return lambda x: x + shift + + def layer(*args, **kwargs): + linear = nn.Linear(*args, **kwargs) + linear.weight.data.copy_(torch.eye(4)) + return linear + + in_features = 4 + out_features = 4 + num_cells = [4, 4, 4] + mlp = MLP( + in_features=in_features, + out_features=out_features, + num_cells=num_cells, + activation_class=make_activation, + activation_kwargs=[{"shift": 0}, {"shift": 1}, {"shift": 2}], + layer_class=layer, + layer_kwargs=[{"bias": False}] * 4, + bias_last_layer=False, + ) + x = torch.zeros(4) + y = mlp(x) + for i, module in enumerate(mlp.modules()): + if isinstance(module, nn.Linear): + assert (module.weight == torch.eye(4)).all(), i + assert module.bias is None, i + assert (y == 3).all() + + +@pytest.mark.parametrize("in_features", [3, 10, None]) +@pytest.mark.parametrize( + "input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features", + [(100, None, None, 3, 1, 0, 32 * 94 * 94), (100, 3, 32, 3, 1, 1, 32 * 100 * 100)], +) +@pytest.mark.parametrize( + "activation_class, activation_kwargs", + [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], +) +@pytest.mark.parametrize( + "norm_class, norm_kwargs", + [(None, None), (nn.LazyBatchNorm2d, {}), (nn.BatchNorm2d, {"num_features": 32})], +) +@pytest.mark.parametrize("bias_last_layer", [True, False]) +@pytest.mark.parametrize( + "aggregator_class, aggregator_kwargs", + [(SquashDims, {})], +) +@pytest.mark.parametrize("squeeze_output", [False]) +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("batch", [(2,), (2, 2)]) +def test_convnet( + batch, + in_features, + depth, + num_cells, + kernel_sizes, + strides, + paddings, + activation_class, + activation_kwargs, + norm_class, + norm_kwargs, + bias_last_layer, + aggregator_class, + 
aggregator_kwargs, + squeeze_output, + device, + input_size, + expected_features, + seed=0, +): + torch.manual_seed(seed) + convnet = ConvNet( + in_features=in_features, + depth=depth, + num_cells=num_cells, + kernel_sizes=kernel_sizes, + strides=strides, + paddings=paddings, + activation_class=activation_class, + activation_kwargs=activation_kwargs, + norm_class=norm_class, + norm_kwargs=norm_kwargs, + bias_last_layer=bias_last_layer, + aggregator_class=aggregator_class, + aggregator_kwargs=aggregator_kwargs, + squeeze_output=squeeze_output, + device=device, + ) + if in_features is None: + in_features = 5 + x = torch.randn(*batch, in_features, input_size, input_size, device=device) + y = convnet(x) + assert y.shape == torch.Size([*batch, expected_features]) + + +class TestConv3d: + @pytest.mark.parametrize("in_features", [3, 10, None]) + @pytest.mark.parametrize( + "input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features", + [ + (10, None, None, 3, 1, 0, 32 * 4 * 4 * 4), + (10, 3, 32, 3, 1, 1, 32 * 10 * 10 * 10), + ], + ) + @pytest.mark.parametrize( + "activation_class, activation_kwargs", + [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], + ) + @pytest.mark.parametrize( + "norm_class, norm_kwargs", + [ + (None, None), + (nn.LazyBatchNorm3d, {}), + (nn.BatchNorm3d, {"num_features": 32}), + ], + ) + @pytest.mark.parametrize("bias_last_layer", [True, False]) + @pytest.mark.parametrize( + "aggregator_class, aggregator_kwargs", + [(SquashDims, None)], + ) + @pytest.mark.parametrize("squeeze_output", [False]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("batch", [(2,), (2, 2)]) + def test_conv3dnet( + self, + batch, + in_features, + depth, + num_cells, + kernel_sizes, + strides, + paddings, + activation_class, + activation_kwargs, + norm_class, + norm_kwargs, + bias_last_layer, + aggregator_class, + aggregator_kwargs, + squeeze_output, + device, + input_size, + expected_features, + seed=0, 
+ ): + torch.manual_seed(seed) + conv3dnet = Conv3dNet( + in_features=in_features, + depth=depth, + num_cells=num_cells, + kernel_sizes=kernel_sizes, + strides=strides, + paddings=paddings, + activation_class=activation_class, + activation_kwargs=activation_kwargs, + norm_class=norm_class, + norm_kwargs=norm_kwargs, + bias_last_layer=bias_last_layer, + aggregator_class=aggregator_class, + aggregator_kwargs=aggregator_kwargs, + squeeze_output=squeeze_output, + device=device, + ) + if in_features is None: + in_features = 5 + x = torch.randn( + *batch, in_features, input_size, input_size, input_size, device=device + ) + y = conv3dnet(x) + assert y.shape == torch.Size([*batch, expected_features]) + with pytest.raises(ValueError, match="must have at least 4 dimensions"): + conv3dnet(torch.randn(3, 16, 16)) + + def test_errors(self): + with pytest.raises( + ValueError, match="Null depth is not permitted with Conv3dNet" + ): + conv3dnet = Conv3dNet( + in_features=5, + num_cells=32, + depth=0, + ) + with pytest.raises( + ValueError, match="depth=None requires one of the input args" + ): + conv3dnet = Conv3dNet( + in_features=5, + num_cells=32, + depth=None, + ) + with pytest.raises( + ValueError, match="consider matching or specifying a constant num_cells" + ): + conv3dnet = Conv3dNet( + in_features=5, + num_cells=[32], + depth=None, + kernel_sizes=[3, 3], + ) + + +class TestBatchRenorm: + @pytest.mark.parametrize("num_steps", [0, 5]) + @pytest.mark.parametrize("smooth", [False, True]) + def test_batchrenorm(self, num_steps, smooth): + torch.manual_seed(0) + bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5) + brn = BatchRenorm1d( + 5, + momentum=0.1, + eps=1e-5, + warmup_steps=num_steps, + max_d=10000, + max_r=10000, + smooth=smooth, + ) + bn.train() + brn.train() + data_train = torch.randn(100, 5).split(25) + data_test = torch.randn(100, 5) + for i, d in enumerate(data_train): + b = bn(d) + a = brn(d) + if num_steps > 0 and ( + (i < num_steps and not smooth) or (i == 0 
and smooth) + ): + torch.testing.assert_close(a, b) + else: + assert not torch.isclose(a, b).all(), i + + bn.eval() + brn.eval() + torch.testing.assert_close(bn(data_test), brn(data_test)) + + +def test_convnetblock_uses_both_resnets(): + """Regression test for https://github.com/pytorch/rl/issues/3519.""" + block = _ConvNetBlock(num_ch=16) + x = torch.randn(2, 3, 8, 8) + out = block(x).mean() + out.backward() + + resnet1_grad = sum(p.grad.abs().sum() for p in block.resnet1.parameters()) + resnet2_grad = sum(p.grad.abs().sum() for p in block.resnet2.parameters()) + assert resnet1_grad > 0, "resnet1 parameters received no gradients" + assert resnet2_grad > 0, "resnet2 parameters received no gradients" + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/modules/test_multiagent_models.py b/test/modules/test_multiagent_models.py new file mode 100644 index 00000000000..e64ee916d15 --- /dev/null +++ b/test/modules/test_multiagent_models.py @@ -0,0 +1,617 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import argparse +import re + +import pytest +import torch +from tensordict import TensorDict +from torch import nn +from torchrl.modules import MultiAgentConvNet, MultiAgentMLP, QMixer, VDNMixer +from torchrl.modules.models.multiagent import MultiAgentNetBase + +from torchrl.testing import retry + + +class TestMultiAgent: + def _get_mock_input_td( + self, n_agents, n_agents_inputs, state_shape=(64, 64, 3), T=None, batch=(2,) + ): + if T is not None: + batch = batch + (T,) + obs = torch.randn(*batch, n_agents, n_agents_inputs) + state = torch.randn(*batch, *state_shape) + + td = TensorDict( + { + "agents": TensorDict( + {"observation": obs}, + [*batch, n_agents], + ), + "state": state, + }, + batch_size=batch, + ) + return td + + @retry(AssertionError, 5) + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralized", [True, False]) + @pytest.mark.parametrize("n_agent_inputs", [6, None]) + @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) + def test_multiagent_mlp( + self, + n_agents, + centralized, + share_params, + batch, + n_agent_inputs, + n_agent_outputs=2, + ): + torch.manual_seed(1) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + depth=2, + ) + if n_agent_inputs is None: + n_agent_inputs = 6 + td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) + obs = td.get(("agents", "observation")) + + out = mlp(obs) + assert out.shape == (*batch, n_agents, n_agent_outputs) + for i in range(n_agents): + if centralized and share_params: + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + assert not torch.allclose(out[..., i, :], out[..., j, :]) + + obs[..., 0, 0] += 1 + out2 = mlp(obs) + for i in range(n_agents): + if centralized: + # a modification to the input of 
agent 0 will impact all agents + assert not torch.allclose(out[..., i, :], out2[..., i, :]) + elif i > 0: + assert torch.allclose(out[..., i, :], out2[..., i, :]) + + obs = ( + torch.randn(*batch, 1, n_agent_inputs) + .expand(*batch, n_agents, n_agent_inputs) + .clone() + ) + out = mlp(obs) + for i in range(n_agents): + if share_params: + # same input same output + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + # same input different output + assert not torch.allclose(out[..., i, :], out[..., j, :]) + pattern = rf"""MultiAgentMLP\( + MLP\( + \(0\): Linear\(in_features=\d+, out_features=32, bias=True\) + \(1\): Tanh\(\) + \(2\): Linear\(in_features=32, out_features=32, bias=True\) + \(3\): Tanh\(\) + \(4\): Linear\(in_features=32, out_features=2, bias=True\) + \), + n_agents={n_agents}, + share_params={share_params}, + centralized={centralized}, + agent_dim={-2}\)""" + assert re.match(pattern, str(mlp), re.DOTALL) + + @retry(AssertionError, 5) + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralized", [True, False]) + @pytest.mark.parametrize("n_agent_inputs", [6, None]) + @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) + def test_multiagent_mlp_init( + self, + n_agents, + centralized, + share_params, + batch, + n_agent_inputs, + n_agent_outputs=2, + ): + torch.manual_seed(1) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + depth=2, + ) + for m in mlp.modules(): + if isinstance(m, nn.Linear): + assert not isinstance(m.weight, nn.Parameter) + assert m.weight.device == torch.device("meta") + break + else: + raise RuntimeError("could not find a Linear module") + if n_agent_inputs is None: + n_agent_inputs = 6 + td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) + obs = td.get(("agents", 
"observation")) + mlp(obs) + snet = mlp.get_stateful_net() + assert snet is not mlp._empty_net + + def zero_inplace(mod): + if hasattr(mod, "weight"): + mod.weight.data *= 0 + if hasattr(mod, "bias"): + mod.bias.data *= 0 + + snet.apply(zero_inplace) + assert (mlp.params == 0).all() + + def one_outofplace(mod): + if hasattr(mod, "weight"): + mod.weight = nn.Parameter(torch.ones_like(mod.weight.data)) + if hasattr(mod, "bias"): + mod.bias = nn.Parameter(torch.ones_like(mod.bias.data)) + + snet.apply(one_outofplace) + assert (mlp.params == 0).all() + mlp.from_stateful_net(snet) + assert (mlp.params == 1).all() + + @retry(AssertionError, 5) + @pytest.mark.parametrize("n_agents", [3]) + @pytest.mark.parametrize("share_params", [True]) + @pytest.mark.parametrize("centralized", [True]) + @pytest.mark.parametrize("n_agent_inputs", [6]) + @pytest.mark.parametrize("batch", [(4,)]) + @pytest.mark.parametrize("tdparams", [True, False]) + def test_multiagent_mlp_tdparams( + self, + n_agents, + centralized, + share_params, + batch, + n_agent_inputs, + tdparams, + n_agent_outputs=2, + ): + torch.manual_seed(1) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + depth=2, + use_td_params=tdparams, + ) + if tdparams: + assert list(mlp._empty_net.parameters()) == [] + assert list(mlp.params.parameters()) == list(mlp.parameters()) + else: + assert list(mlp._empty_net.parameters()) == list(mlp.parameters()) + assert not hasattr(mlp.params, "parameters") + if torch.backends.mps.is_available(): + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + return + mlp = nn.Sequential(mlp) + mlp.to(device) + param_set = set(mlp.parameters()) + for p in mlp[0].params.values(True, True): + assert p in param_set + + def test_multiagent_mlp_lazy(self): + torch.manual_seed(0) + mlp = MultiAgentMLP( + n_agent_inputs=None, + 
n_agent_outputs=6, + n_agents=3, + centralized=True, + share_params=False, + depth=2, + ) + optim = torch.optim.SGD(mlp.parameters(), lr=1e-3) + for p in mlp.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for _ in range(2): + td = self._get_mock_input_td(3, 4, batch=(10,)) + obs = td.get(("agents", "observation")) + out = mlp(obs) + assert ( + not mlp.params[0] + .apply(lambda x, y: torch.isclose(x, y), mlp.params[1]) + .any() + ) + out.mean().backward() + optim.step() + for p in mlp.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralized", [True, False]) + def test_multiagent_reset_mlp( + self, + n_agents, + centralized, + share_params, + ): + actor_net = MultiAgentMLP( + n_agent_inputs=4, + n_agent_outputs=6, + num_cells=(4, 4), + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + ) + params_before = actor_net.params.clone() + actor_net.reset_parameters() + params_after = actor_net.params + assert not params_before.apply( + lambda x, y: torch.isclose(x, y), params_after, batch_size=[] + ).any() + if params_after.numel() > 1: + assert ( + not params_after[0] + .apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[]) + .any() + ) + + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("agent_dim", [1, -3]) + def 
test_multiagent_custom_agent_dim(self, share_params, agent_dim): + """Test that custom agent_dim values work correctly. + + Regression test for https://github.com/pytorch/rl/issues/3288 + """ + n_agents = 3 + obs_dim = 5 + seq_len = 6 + output_dim = 4 + + class SingleAgentMLP(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 32), + nn.Tanh(), + nn.Linear(32, out_dim), + ) + + def forward(self, x): + return self.net(x) + + class MultiAgentPolicyNet(MultiAgentNetBase): + def __init__( + self, + obs_dim, + output_dim, + n_agents, + share_params, + agent_dim, + device=None, + ): + self.obs_dim = obs_dim + self.output_dim = output_dim + self._agent_dim = agent_dim + + super().__init__( + n_agents=n_agents, + centralized=False, + share_params=share_params, + agent_dim=agent_dim, + device=device, + ) + + def _build_single_net(self, *, device, **kwargs): + net = SingleAgentMLP(self.obs_dim, self.output_dim) + return net.to(device) if device is not None else net + + def _pre_forward_check(self, inputs): + if inputs.shape[self._agent_dim] != self.n_agents: + raise ValueError( + f"Multi-agent network expected input with shape[{self._agent_dim}]={self.n_agents}," + f" but got {inputs.shape}" + ) + return inputs + + policy_net = MultiAgentPolicyNet( + obs_dim=obs_dim, + output_dim=output_dim, + n_agents=n_agents, + share_params=share_params, + agent_dim=agent_dim, + ) + + # Input shape: (batch, n_agents, seq_len, obs_dim) with agents at dim 1 + batch_size = 4 + obs = torch.randn(batch_size, n_agents, seq_len, obs_dim) + out = policy_net(obs) + + # Output should preserve agent dimension position + expected_shape = (batch_size, n_agents, seq_len, output_dim) + assert ( + out.shape == expected_shape + ), f"Expected {expected_shape}, got {out.shape}" + + # Verify different agents produce different outputs (unless share_params with same input) + if not share_params: + for i in range(n_agents): + for j in range(i + 
1, n_agents): + assert not torch.allclose(out[:, i], out[:, j]) + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralized", [True, False]) + @pytest.mark.parametrize("channels", [3, None]) + @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) + def test_multiagent_cnn( + self, + n_agents, + centralized, + share_params, + batch, + channels, + x=15, + y=15, + ): + torch.manual_seed(0) + cnn = MultiAgentConvNet( + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + in_features=channels, + kernel_sizes=3, + ) + if channels is None: + channels = 3 + td = TensorDict( + { + "agents": TensorDict( + {"observation": torch.randn(*batch, n_agents, channels, x, y)}, + [*batch, n_agents], + ) + }, + batch_size=batch, + ) + obs = td[("agents", "observation")] + out = cnn(obs) + assert out.shape[:-1] == (*batch, n_agents) + if centralized and share_params: + torch.testing.assert_close(out, out[..., :1, :].expand_as(out)) + else: + for i in range(n_agents): + for j in range(i + 1, n_agents): + assert not torch.allclose(out[..., i, :], out[..., j, :]) + obs[..., 0, 0, 0, 0] += 1 + out2 = cnn(obs) + if centralized: + # a modification to the input of agent 0 will impact all agents + assert not torch.isclose(out, out2).all() + elif n_agents > 1: + assert not torch.isclose(out[..., 0, :], out2[..., 0, :]).all() + torch.testing.assert_close(out[..., 1:, :], out2[..., 1:, :]) + + obs = torch.randn(*batch, 1, channels, x, y).expand( + *batch, n_agents, channels, x, y + ) + out = cnn(obs) + for i in range(n_agents): + if share_params: + # same input same output + assert torch.allclose(out[..., i, :], out[..., 0, :]) + else: + for j in range(i + 1, n_agents): + # same input different output + assert not torch.allclose(out[..., i, :], out[..., j, :]) + + def test_multiagent_cnn_lazy(self): + torch.manual_seed(42) + n_agents = 5 + n_channels = 3 + cnn = MultiAgentConvNet( + 
n_agents=n_agents, + centralized=False, + share_params=False, + in_features=None, + kernel_sizes=3, + ) + optim = torch.optim.SGD(cnn.parameters(), lr=1e-3) + for p in cnn.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + break + else: + raise AssertionError("No UninitializedParameter found") + for _ in range(2): + td = TensorDict( + { + "agents": TensorDict( + {"observation": torch.randn(4, n_agents, n_channels, 15, 15)}, + [4, 5], + ) + }, + batch_size=[4], + ) + obs = td[("agents", "observation")] + out = cnn(obs) + assert ( + not cnn.params[0] + .apply(lambda x, y: torch.isclose(x, y), cnn.params[1]) + .any() + ) + out.mean().backward() + optim.step() + for p in cnn.parameters(): + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + for p in optim.param_groups[0]["params"]: + if isinstance(p, torch.nn.parameter.UninitializedParameter): + raise AssertionError("UninitializedParameter found") + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("share_params", [True, False]) + @pytest.mark.parametrize("centralized", [True, False]) + def test_multiagent_reset_cnn( + self, + n_agents, + centralized, + share_params, + ): + torch.manual_seed(42) + actor_net = MultiAgentConvNet( + in_features=4, + num_cells=[5, 5], + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + ) + params_before = actor_net.params.clone() + actor_net.reset_parameters() + params_after = actor_net.params + assert not params_before.apply( + lambda x, y: torch.isclose(x, y), params_after, batch_size=[] + ).any() + if params_after.numel() > 1: + assert ( + not params_after[0] + .apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[]) + .any() + ) + + 
@pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) + def test_vdn(self, n_agents, batch): + torch.manual_seed(0) + mixer = VDNMixer(n_agents=n_agents, device="cpu") + + td = self._get_mock_input_td(n_agents, batch=batch, n_agents_inputs=1) + obs = td.get(("agents", "observation")) + assert obs.shape == (*batch, n_agents, 1) + out = mixer(obs) + assert out.shape == (*batch, 1) + assert torch.equal(obs.sum(-2), out) + + @pytest.mark.parametrize("n_agents", [1, 3]) + @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) + @pytest.mark.parametrize("state_shape", [(64, 64, 3), (10,)]) + def test_qmix(self, n_agents, batch, state_shape): + torch.manual_seed(0) + mixer = QMixer( + n_agents=n_agents, + state_shape=state_shape, + mixing_embed_dim=32, + device="cpu", + ) + + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape + ) + obs = td.get(("agents", "observation")) + state = td.get("state") + assert obs.shape == (*batch, n_agents, 1) + assert state.shape == (*batch, *state_shape) + out = mixer(obs, state) + assert out.shape == (*batch, 1) + + @pytest.mark.parametrize("mixer", ["qmix", "vdn"]) + def test_mixer_malformed_input( + self, mixer, n_agents=3, batch=(32,), state_shape=(64, 64, 3) + ): + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=3, state_shape=state_shape + ) + if mixer == "qmix": + mixer = QMixer( + n_agents=n_agents, + state_shape=state_shape, + mixing_embed_dim=32, + device="cpu", + ) + else: + mixer = VDNMixer(n_agents=n_agents, device="cpu") + obs = td.get(("agents", "observation")) + state = td.get("state") + + if mixer.needs_state: + with pytest.raises( + ValueError, + match="Mixer that needs state was passed more than 2 inputs", + ): + mixer(obs) + else: + with pytest.raises( + ValueError, + match="Mixer that doesn't need state was passed more than 1 input", + ): + mixer(obs, state) + + in_put = [obs, state] if 
mixer.needs_state else [obs] + with pytest.raises( + ValueError, + match="Mixer network expected chosen_action_value with last 2 dimensions", + ): + mixer(*in_put) + if mixer.needs_state: + state_diff = state.unsqueeze(-1) + with pytest.raises( + ValueError, + match="Mixer network expected state with ending shape", + ): + mixer(obs, state_diff) + + td = self._get_mock_input_td( + n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape + ) + obs = td.get(("agents", "observation")) + state = td.get("state") + obs = obs.sum(-2) + in_put = [obs, state] if mixer.needs_state else [obs] + with pytest.raises( + ValueError, + match="Mixer network expected chosen_action_value with last 2 dimensions", + ): + mixer(*in_put) + + obs = td.get(("agents", "observation")) + state = td.get("state") + in_put = [obs, state] if mixer.needs_state else [obs] + mixer(*in_put) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/modules/test_planners.py b/test/modules/test_planners.py new file mode 100644 index 00000000000..40d5f5a0077 --- /dev/null +++ b/test/modules/test_planners.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import argparse + +import pytest +import torch +from tensordict import TensorDict +from torch import nn +from torchrl.modules import CEMPlanner, ValueOperator +from torchrl.modules.planners.mppi import MPPIPlanner +from torchrl.objectives.value import TDLambdaEstimator + +from torchrl.testing import get_default_devices +from torchrl.testing.mocking_classes import MockBatchedUnLockedEnv + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("batch_size", [3, 5]) +class TestPlanner: + def test_CEM_model_free_env(self, device, batch_size, seed=1): + env = MockBatchedUnLockedEnv(device=device) + torch.manual_seed(seed) + planner = CEMPlanner( + env, + planning_horizon=10, + optim_steps=2, + num_candidates=100, + top_k=2, + ) + td = env.reset(TensorDict(batch_size=batch_size).to(device)) + td_copy = td.clone() + td = planner(td) + assert ( + td.get("action").shape[-len(env.action_spec.shape) :] + == env.action_spec.shape + ) + assert env.action_spec.is_in(td.get("action")) + + for key in td.keys(): + if key != "action": + assert torch.allclose(td[key], td_copy[key]) + + def test_MPPI(self, device, batch_size, seed=1): + torch.manual_seed(seed) + env = MockBatchedUnLockedEnv(device=device) + value_net = nn.LazyLinear(1, device=device) + value_net = ValueOperator(value_net, in_keys=["observation"]) + advantage_module = TDLambdaEstimator( + gamma=0.99, + lmbda=0.95, + value_network=value_net, + ) + value_net(env.reset()) + planner = MPPIPlanner( + env, + advantage_module, + temperature=1.0, + planning_horizon=10, + optim_steps=2, + num_candidates=100, + top_k=2, + ) + td = env.reset(TensorDict(batch_size=batch_size).to(device)) + td_copy = td.clone() + td = planner(td) + assert ( + td.get("action").shape[-len(env.action_spec.shape) :] + == env.action_spec.shape + ) + assert env.action_spec.is_in(td.get("action")) + + for key in td.keys(): + if key != "action": + assert torch.allclose(td[key], 
td_copy[key]) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_actors.py b/test/modules/test_qvalue_actor.py similarity index 64% rename from test/test_actors.py rename to test/modules/test_qvalue_actor.py index 67e5429a409..4d8d2fd6949 100644 --- a/test/test_actors.py +++ b/test/modules/test_qvalue_actor.py @@ -5,258 +5,30 @@ from __future__ import annotations import argparse -import importlib.util import warnings import pytest import torch -from tensordict import NonTensorData, TensorDict -from tensordict.nn import CompositeDistribution, TensorDictModule -from tensordict.nn.distributions import NormalParamExtractor +from tensordict import TensorDict +from tensordict.nn import TensorDictModule -from torch import distributions as dist, nn +from torch import nn -from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot -from torchrl.data.llm.dataset import _has_transformers -from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal +from torchrl.data import Binary, Categorical, Composite, MultiOneHot, OneHot +from torchrl.modules import MLP from torchrl.modules.tensordict_module.actors import ( _process_action_space_spec, - ActorValueOperator, DistributionalQValueActor, DistributionalQValueHook, DistributionalQValueModule, - LMHeadActorValueOperator, - ProbabilisticActor, QValueActor, QValueHook, QValueModule, - ValueOperator, ) from torchrl.testing import get_default_devices from torchrl.testing.mocking_classes import NestedCountingEnv -_has_vllm = importlib.util.find_spec("vllm") is not None - - -@pytest.mark.parametrize( - "log_prob_key", - [ - None, - "sample_log_prob", - ("nested", "sample_log_prob"), - ("data", "sample_log_prob"), - ], -) -def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=1): - env = NestedCountingEnv(nested_dim=nested_dim) - action_spec = 
Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) - policy_module = TensorDictModule( - nn.Linear(1, 1), in_keys=[("data", "states")], out_keys=[("data", "param")] - ) - policy = ProbabilisticActor( - module=policy_module, - spec=action_spec, - in_keys=[("data", "param")], - out_keys=[("data", "action")], - distribution_class=TanhDelta, - distribution_kwargs={ - "low": action_spec.space.low, - "high": action_spec.space.high, - }, - log_prob_key=log_prob_key, - return_log_prob=True, - ) - - td = env.reset() - td["data", "states"] = td["data", "states"].to(torch.float) - td_out = policy(td) - assert td_out["data", "action"].shape == (5, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (5,) - else: - assert td_out["data", "action_log_prob"].shape == (5,) - - policy = ProbabilisticActor( - module=policy_module, - spec=action_spec, - in_keys={"param": ("data", "param")}, - out_keys=[("data", "action")], - distribution_class=TanhDelta, - distribution_kwargs={ - "low": action_spec.space.low, - "high": action_spec.space.high, - }, - log_prob_key=log_prob_key, - return_log_prob=True, - ) - td_out = policy(td) - assert td_out["data", "action"].shape == (5, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (5,) - else: - assert td_out["data", "action_log_prob"].shape == (5,) - - -@pytest.mark.parametrize( - "log_prob_key", - [ - None, - "sample_log_prob", - ("nested", "sample_log_prob"), - ("data", "sample_log_prob"), - ], -) -def test_probabilistic_actor_nested_normal(log_prob_key, nested_dim=5, n_actions=3): - env = NestedCountingEnv(nested_dim=nested_dim) - action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1) - actor_net = nn.Sequential( - nn.Linear(1, 2), - NormalParamExtractor(), - ) - policy_module = TensorDictModule( - actor_net, - in_keys=[("data", "states")], - out_keys=[("data", "loc"), ("data", "scale")], - ) - policy = ProbabilisticActor( - module=policy_module, - spec=action_spec, - 
in_keys=[("data", "loc"), ("data", "scale")], - out_keys=[("data", "action")], - distribution_class=TanhNormal, - distribution_kwargs={ - "low": action_spec.space.low, - "high": action_spec.space.high, - }, - log_prob_key=log_prob_key, - return_log_prob=True, - ) - - td = env.reset() - td["data", "states"] = td["data", "states"].to(torch.float) - td_out = policy(td) - assert td_out["data", "action"].shape == (5, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (5,) - else: - assert td_out["data", "action_log_prob"].shape == (5,) - - policy = ProbabilisticActor( - module=policy_module, - spec=action_spec, - in_keys={"loc": ("data", "loc"), "scale": ("data", "scale")}, - out_keys=[("data", "action")], - distribution_class=TanhNormal, - distribution_kwargs={ - "low": action_spec.space.low, - "high": action_spec.space.high, - }, - log_prob_key=log_prob_key, - return_log_prob=True, - ) - td_out = policy(td) - assert td_out["data", "action"].shape == (5, 1) - if log_prob_key: - assert td_out[log_prob_key].shape == (5,) - else: - assert td_out["data", "action_log_prob"].shape == (5,) - - -class TestProbabilisticActorGenerator: - """Tests for the ``generator`` kwarg on ``ProbabilisticActor``. - - The actual sampling logic lives in ``tensordict.nn`` and is exhaustively tested there; - these tests just verify the kwarg threads through ``ProbabilisticActor`` → - ``SafeProbabilisticModule`` → ``ProbabilisticTensorDictModule``. 
- """ - - @staticmethod - def _make_actor(generator=None): - module = TensorDictModule( - lambda x: (x, torch.ones_like(x)), - in_keys=["obs"], - out_keys=["loc", "scale"], - ) - return ProbabilisticActor( - module=module, - in_keys=["loc", "scale"], - out_keys=["action"], - distribution_class=dist.Normal, - default_interaction_type="random", - generator=generator, - ) - - def test_generator_object(self): - """Two same-seeded Generators must produce identical actions.""" - a1 = self._make_actor(torch.Generator().manual_seed(0)) - a2 = self._make_actor(torch.Generator().manual_seed(0)) - # Set the global RNG to a different state to make sure it's not consulted. - torch.manual_seed(999) - s1 = a1(TensorDict(obs=torch.zeros(4)))["action"].clone() - s2 = a2(TensorDict(obs=torch.zeros(4)))["action"].clone() - assert torch.equal(s1, s2) - - def test_generator_int_seed(self): - """Module-level int is shorthand for ``Generator().manual_seed(int)``.""" - a_int = self._make_actor(generator=0) - a_gen = self._make_actor(generator=torch.Generator().manual_seed(0)) - s_int = a_int(TensorDict(obs=torch.zeros(4)))["action"].clone() - s_gen = a_gen(TensorDict(obs=torch.zeros(4)))["action"].clone() - assert torch.equal(s_int, s_gen) - - def test_generator_isolates_global_rng(self): - """Sampling with a generator must not advance the global RNG.""" - a = self._make_actor(torch.Generator().manual_seed(0)) - torch.manual_seed(1234) - before = torch.get_rng_state() - a(TensorDict(obs=torch.zeros(4))) - after = torch.get_rng_state() - assert torch.equal(before, after) - - def test_generator_advances_in_place(self): - a = self._make_actor(torch.Generator().manual_seed(0)) - s1 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() - s2 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() - assert not torch.equal(s1, s2) - - def test_generator_td_key_int_writeback(self): - """Int seed in the input tensordict is treated as a stream-key (JAX-style).""" - a = 
self._make_actor(generator="rng") - - def run(seed, n_steps): - td = TensorDict(obs=torch.zeros(4)) - td["rng"] = NonTensorData(seed) - samples = [] - for _ in range(n_steps): - samples.append(a(td)["action"].clone()) - return samples - - traj_a = run(42, 3) - traj_b = run(42, 3) - for x, y in zip(traj_a, traj_b): - assert torch.equal(x, y) - assert not torch.equal(traj_a[0], traj_a[1]) - - def test_generator_td_key_generator_form(self): - """A Generator placed in the input tensordict is used in place.""" - a = self._make_actor(generator="rng") - td = TensorDict(obs=torch.zeros(4)) - td["rng"] = NonTensorData(torch.Generator().manual_seed(0)) - s_key = a(td)["action"].clone() - a_ref = self._make_actor(torch.Generator().manual_seed(0)) - s_ref = a_ref(TensorDict(obs=torch.zeros(4)))["action"].clone() - assert torch.equal(s_key, s_ref) - - def test_generator_default_unchanged(self): - """generator=None preserves existing global-RNG behaviour.""" - a = self._make_actor(generator=None) - torch.manual_seed(0) - s1 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() - torch.manual_seed(0) - s2 = a(TensorDict(obs=torch.zeros(4)))["action"].clone() - assert torch.equal(s1, s2) - class TestQValue: def test_qvalue_hook_wrong_action_space(self): @@ -902,172 +674,6 @@ def make_net(): assert (0 <= action).all() and (action < action_dim).all() -@pytest.mark.parametrize("device", get_default_devices()) -def test_actorcritic(device): - common_module = SafeModule( - module=nn.Linear(3, 4), in_keys=["obs"], out_keys=["hidden"], spec=None - ).to(device) - module = SafeModule(nn.Linear(4, 5), in_keys=["hidden"], out_keys=["param"]) - policy_operator = ProbabilisticActor( - module=module, in_keys=["param"], spec=None, return_log_prob=True - ).to(device) - value_operator = ValueOperator(nn.Linear(4, 1), in_keys=["hidden"]).to(device) - op = ActorValueOperator( - common_operator=common_module, - policy_operator=policy_operator, - value_operator=value_operator, - ).to(device) - td = 
TensorDict( - source={"obs": torch.randn(4, 3)}, - batch_size=[ - 4, - ], - ).to(device) - td_total = op(td.clone()) - policy_op = op.get_policy_operator() - td_policy = policy_op(td.clone()) - value_op = op.get_value_operator() - td_value = value_op(td) - torch.testing.assert_close(td_total.get("action"), td_policy.get("action")) - torch.testing.assert_close( - td_total.get("sample_log_prob"), td_policy.get("sample_log_prob") - ) - torch.testing.assert_close(td_total.get("state_value"), td_value.get("state_value")) - - value_params = set( - list(op.get_value_operator().parameters()) + list(op.module[0].parameters()) - ) - value_params2 = set(value_op.parameters()) - assert len(value_params.difference(value_params2)) == 0 and len( - value_params.intersection(value_params2) - ) == len(value_params) - - policy_params = set( - list(op.get_policy_operator().parameters()) + list(op.module[0].parameters()) - ) - policy_params2 = set(policy_op.parameters()) - assert len(policy_params.difference(policy_params2)) == 0 and len( - policy_params.intersection(policy_params2) - ) == len(policy_params) - - -@pytest.mark.parametrize("name_map", [True, False]) -def test_compound_actor(name_map): - class Module(nn.Module): - def forward(self, x): - return x[..., :3], x[..., 3:6], x[..., 6:] - - module = TensorDictModule( - Module(), - in_keys=["x"], - out_keys=[ - ("params", "normal", "loc"), - ("params", "normal", "scale"), - ("params", "categ", "logits"), - ], - ) - distribution_kwargs = { - "distribution_map": {"normal": dist.Normal, "categ": dist.Categorical} - } - if name_map: - distribution_kwargs.update( - { - "name_map": { - "normal": ("action", "normal"), - "categ": ("action", "categ"), - }, - } - ) - actor = ProbabilisticActor( - module, - in_keys=["params"], - distribution_class=CompositeDistribution, - distribution_kwargs=distribution_kwargs, - ) - if not name_map: - assert actor.out_keys == module.out_keys + ["normal", "categ"] - else: - assert actor.out_keys == 
module.out_keys + [ - ("action", "normal"), - ("action", "categ"), - ] - - data = TensorDict({"x": torch.rand(10)}, []) - actor(data) - assert set(data.keys(True, True)) == { - "categ" if not name_map else ("action", "categ"), - "normal" if not name_map else ("action", "normal"), - ("params", "categ", "logits"), - ("params", "normal", "loc"), - ("params", "normal", "scale"), - "x", - } - - -@pytest.mark.skipif(not _has_transformers, reason="missing dependencies") -@pytest.mark.parametrize("device", get_default_devices()) -def test_lmhead_actorvalueoperator(device): - from transformers import AutoModelForCausalLM, GPT2Config - - config = GPT2Config(return_dict=False) - base_model = AutoModelForCausalLM.from_config(config).eval() - aco = LMHeadActorValueOperator(base_model).to(device) - - # check common - assert aco.module[0][0].module is base_model.transformer - assert aco.module[0][1].in_keys == ["x"] - assert aco.module[0][1].out_keys == ["x"] - - # check actor - assert aco.module[1].in_keys == ["x"] - assert aco.module[1].out_keys == ["logits", "action", "action_log_prob"] - assert aco.module[1][0].module is base_model.lm_head - - # check critic - assert aco.module[2].in_keys == ["x"] - assert aco.module[2].out_keys == ["state_value"] - assert isinstance(aco.module[2].module, nn.Linear) - assert aco.module[2].module.in_features == base_model.transformer.embed_dim - assert aco.module[2].module.out_features == 1 - - td = TensorDict( - source={ - "input_ids": torch.randint(50257, (4, 3)), - "attention_mask": torch.ones((4, 3)), - }, - batch_size=[ - 4, - ], - device=device, - ) - td_total = aco(td.clone()) - policy_op = aco.get_policy_operator() - td_policy = policy_op(td.clone()) - value_op = aco.get_value_operator() - td_value = value_op(td) - torch.testing.assert_close(td_total.get("action"), td_policy.get("action")) - torch.testing.assert_close( - td_total.get("sample_log_prob"), td_policy.get("sample_log_prob") - ) - 
torch.testing.assert_close(td_total.get("state_value"), td_value.get("state_value")) - - value_params = set( - list(aco.get_value_operator().parameters()) + list(aco.module[0].parameters()) - ) - value_params2 = set(value_op.parameters()) - assert len(value_params.difference(value_params2)) == 0 and len( - value_params.intersection(value_params2) - ) == len(value_params) - - policy_params = set( - list(aco.get_policy_operator().parameters()) + list(aco.module[0].parameters()) - ) - policy_params2 = set(policy_op.parameters()) - assert len(policy_params.difference(policy_params2)) == 0 and len( - policy_params.intersection(policy_params2) - ) == len(policy_params) - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensordictmodules.py b/test/modules/test_rnn.py similarity index 73% rename from test/test_tensordictmodules.py rename to test/modules/test_rnn.py index 557e936fa58..661f63dd49a 100644 --- a/test/test_tensordictmodules.py +++ b/test/modules/test_rnn.py @@ -10,18 +10,21 @@ import pytest import torch - import torchrl.modules - +from _modules_common import ( + _has_functorch, + _has_triton, + _triton_skip_reason, + TORCH_VERSION, +) from packaging import version -from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list -from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential +from tensordict import pad, TensorDict +from tensordict.nn import TensorDictModule, TensorDictSequential from tensordict.utils import assert_close from torch import nn + from torchrl.collectors import SyncDataCollector -from torchrl.data.tensor_specs import Bounded, Composite, Unbounded from torchrl.envs import ( - CatFrames, Compose, EnvCreator, InitTracker, @@ -30,657 +33,233 @@ TensorDictPrimer, TransformedEnv, ) -from torchrl.envs.utils import set_exploration_type, step_mdp +from torchrl.envs.utils import 
step_mdp from torchrl.modules import ( - AdditiveGaussianModule, ConsistentDropoutModule, - DecisionTransformerInferenceWrapper, - DTActor, + GRU, + GRUCell, GRUModule, + LSTM, + LSTMCell, LSTMModule, MLP, - MultiStepActorWrapper, - NormalParamExtractor, - OnlineDTActor, ProbabilisticActor, - SafeModule, set_recurrent_mode, - TanhDelta, - TanhNormal, - ValueOperator, ) -from torchrl.modules.models.decision_transformer import _has_transformers -from torchrl.modules.tensordict_module.common import ( - ensure_tensordict_compatible, - is_tensordict_compatible, - VmapModule, -) -from torchrl.modules.tensordict_module.probabilistic import ( - SafeProbabilisticModule, - SafeProbabilisticTensorDictSequential, -) -from torchrl.modules.tensordict_module.sequence import SafeSequential from torchrl.modules.utils import ( get_env_transforms_from_module, get_primers_from_module, ) from torchrl.modules.utils.utils import _compute_missing_env_transforms -from torchrl.objectives import DDPGLoss +from torchrl.testing import get_default_devices from torchrl.testing.mocking_classes import CountingEnv, DiscreteActionVecMockEnv -TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) - -import importlib.util as _importlib_util # noqa: E402 - - -def _has_triton_backend() -> bool: - """Mirror of the triton-availability check inside the RNN backend. - - Triton must be installed, CUDA must be available, and the Triton build - must expose the ``triton.language.extra.libdevice`` submodule - (Triton >= 2.2). Older Triton installations are routed to scan/pad - backends, so the triton-specific tests are skipped there. 
- """ - if _importlib_util.find_spec("triton") is None or not torch.cuda.is_available(): - return False - return _importlib_util.find_spec("triton.language.extra.libdevice") is not None - - -_has_triton = _has_triton_backend() -_triton_skip_reason = "requires triton (>= 2.2) and CUDA" - -_has_functorch = False -try: +if _has_functorch: try: from torch import vmap except ImportError: from functorch import vmap - _has_functorch = True -except ImportError: - pass - - -class TestTDModule: - def test_multiple_output(self): - class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2, out_3): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) - self.linear_2 = nn.Linear(in_1, out_2) - self.linear_3 = nn.Linear(in_1, out_3) - - def forward(self, x): - return self.linear_1(x), self.linear_2(x), self.linear_3(x) - - tensordict_module = SafeModule( - MultiHeadLinear(5, 4, 3, 2), - in_keys=["input"], - out_keys=["out_1", "out_2", "out_3"], - ) - td = TensorDict({"input": torch.randn(3, 5)}, batch_size=[3]) - td = tensordict_module(td) - assert td.shape == torch.Size([3]) - assert "input" in td.keys() - assert "out_1" in td.keys() - assert "out_2" in td.keys() - assert "out_3" in td.keys() - assert td.get("out_3").shape == torch.Size([3, 2]) - - # Using "_" key to ignore some output - tensordict_module = SafeModule( - MultiHeadLinear(5, 4, 3, 2), - in_keys=["input"], - out_keys=["_", "_", "out_3"], - ) - td = TensorDict({"input": torch.randn(3, 5)}, batch_size=[3]) - td = tensordict_module(td) - assert td.shape == torch.Size([3]) - assert "input" in td.keys() - assert "out_3" in td.keys() - assert "_" not in td.keys() - assert td.get("out_3").shape == torch.Size([3, 2]) - - def test_spec_key_warning(self): - class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) - self.linear_2 = nn.Linear(in_1, out_2) - - def forward(self, x): - return self.linear_1(x), 
self.linear_2(x) - - spec_dict = { - "_": Unbounded((4,)), - "out_2": Unbounded((3,)), - } - # warning due to "_" in spec keys - with pytest.warns(UserWarning, match='got a spec with key "_"'): - tensordict_module = SafeModule( - MultiHeadLinear(5, 4, 3), - in_keys=["input"], - out_keys=["_", "out_2"], - spec=Composite(**spec_dict), - ) +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +def test_python_lstm_cell(device, bias): + lstm_cell1 = LSTMCell(10, 20, device=device, bias=bias) + lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias) - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 1 - if lazy: - net = nn.LazyLinear(4 * param_multiplier) - else: - net = nn.Linear(3, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = Bounded(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = Unbounded(4) - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - tensordict_module = SafeModule( - module=net, - spec=spec, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) - return - else: - tensordict_module = SafeModule( - module=net, - spec=spec, - in_keys=["in"], - out_keys=["out"], - safe=safe, - ) + lstm_cell1.load_state_dict(lstm_cell2.state_dict()) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tensordict_module(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any(), td.get("out") - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > 
-0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]]) - @pytest.mark.parametrize("lazy", [True, False]) - @pytest.mark.parametrize( - "exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None] + # Make sure parameters match + for (k1, v1), (k2, v2) in zip( + lstm_cell1.named_parameters(), lstm_cell2.named_parameters() + ): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + # Run loop + input = torch.randn(2, 3, 10, device=device) + h0 = torch.randn(3, 20, device=device) + c0 = torch.randn(3, 20, device=device) + with torch.no_grad(): + for i in range(input.size()[0]): + h1, c1 = lstm_cell1(input[i], (h0, c0)) + h2, c2 = lstm_cell2(input[i], (h0, c0)) + + # Make sure the final hidden states have the same shape + assert h1.shape == h2.shape + assert c1.shape == c2.shape + torch.testing.assert_close(h1, h2) + torch.testing.assert_close(c1, c2) + h0 = h1 + c0 = c1 + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +def test_python_gru_cell(device, bias): + gru_cell1 = GRUCell(10, 20, device=device, bias=bias) + gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias) + + gru_cell2.load_state_dict(gru_cell1.state_dict()) + + # Make sure parameters match + for (k1, v1), (k2, v2) in zip( + gru_cell1.named_parameters(), gru_cell2.named_parameters() + ): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert (v1 == v2).all() + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + # Run loop + input = torch.randn(2, 3, 10, device=device) + h0 = torch.zeros(3, 20, device=device) + with torch.no_grad(): + for i in 
range(input.size()[0]): + h1 = gru_cell1(input[i], h0) + h2 = gru_cell2(input[i], h0) + + # Make sure the final hidden states have the same shape + assert h1.shape == h2.shape + torch.testing.assert_close(h1, h2) + h0 = h1 + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("batch_first", [True, False]) +@pytest.mark.parametrize("dropout", [0.0, 0.5]) +@pytest.mark.parametrize("num_layers", [1, 2]) +def test_python_lstm(device, bias, dropout, batch_first, num_layers): + B = 5 + T = 3 + lstm1 = LSTM( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, ) - def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys): - torch.manual_seed(0) - param_multiplier = 2 - if lazy: - net = nn.LazyLinear(4 * param_multiplier) - else: - net = nn.Linear(3, 4 * param_multiplier) - - in_keys = ["in"] - net = SafeModule( - module=nn.Sequential(net, NormalParamExtractor()), - spec=None, - in_keys=in_keys, - out_keys=out_keys, - ) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = Bounded(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = Unbounded(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - if out_keys == ["loc", "scale"]: - dist_in_keys = ["loc", "scale"] - elif out_keys == ["loc_1", "scale_1"]: - dist_in_keys = {"loc": "loc_1", "scale": "scale_1"} - else: - raise NotImplementedError - - if safe and spec is None: - with pytest.raises( - RuntimeError, - match="is not a valid configuration as the tensor specs are not " - "specified", - ): - prob_module = SafeProbabilisticModule( - in_keys=dist_in_keys, - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - return - else: - prob_module = SafeProbabilisticModule( - in_keys=dist_in_keys, - out_keys=["out"], - spec=spec, - safe=safe, - **kwargs, - ) - - tensordict_module = 
SafeProbabilisticTensorDictSequential(net, prob_module) - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - with set_exploration_type(exp_mode): - tensordict_module(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - -class TestTDSequence: - # Temporarily disabling this test until 473 is merged in tensordict - # def test_in_key_warning(self): - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] - # ) - # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): - # tensordict_module = SafeModule( - # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] - # ) - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 1 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = Bounded(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = Unbounded(4) - - kwargs = {} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - spec=None, - in_keys=["in"], - out_keys=["hidden"], - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - spec=None, - in_keys=["hidden"], - out_keys=["hidden"], - safe=False, - ) - tdmodule2 = SafeModule( - spec=spec, - 
module=net2, - in_keys=["hidden"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 3 - tdmodule[1] = tdmodule2 - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 3 - del tdmodule[2] - assert len(tdmodule) == 2 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - @pytest.mark.parametrize("safe", [True, False]) - @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) - @pytest.mark.parametrize("lazy", [True, False]) - def test_stateful_probabilistic(self, safe, spec_type, lazy): - torch.manual_seed(0) - param_multiplier = 2 - if lazy: - net1 = nn.LazyLinear(4) - dummy_net = nn.LazyLinear(4) - net2 = nn.LazyLinear(4 * param_multiplier) - else: - net1 = nn.Linear(3, 4) - dummy_net = nn.Linear(4, 4) - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = nn.Sequential(net2, NormalParamExtractor()) - - if spec_type is None: - spec = None - elif spec_type == "bounded": - spec = Bounded(-0.1, 0.1, 4) - elif spec_type == "unbounded": - spec = Unbounded(4) - else: - raise NotImplementedError - - kwargs = {"distribution_class": TanhNormal} - - if safe and spec is None: - pytest.skip("safe and spec is None is checked elsewhere") - else: - tdmodule1 = SafeModule( - net1, - in_keys=["in"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - dummy_tdmodule = SafeModule( - dummy_net, - in_keys=["hidden"], - out_keys=["hidden"], - spec=None, - 
safe=False, - ) - tdmodule2 = SafeModule( - module=net2, - in_keys=["hidden"], - out_keys=["loc", "scale"], - spec=None, - safe=False, - ) - - prob_module = SafeProbabilisticModule( - spec=spec, - in_keys=["loc", "scale"], - out_keys=["out"], - safe=False, - **kwargs, - ) - tdmodule = SafeProbabilisticTensorDictSequential( - tdmodule1, dummy_tdmodule, tdmodule2, prob_module - ) - - assert hasattr(tdmodule, "__setitem__") - assert len(tdmodule) == 4 - tdmodule[1] = tdmodule2 - tdmodule[2] = prob_module - assert len(tdmodule) == 4 - - assert hasattr(tdmodule, "__delitem__") - assert len(tdmodule) == 4 - del tdmodule[3] - assert len(tdmodule) == 3 - - assert hasattr(tdmodule, "__getitem__") - assert tdmodule[0] is tdmodule1 - assert tdmodule[1] is tdmodule2 - assert tdmodule[2] is prob_module - - td = TensorDict({"in": torch.randn(3, 3)}, [3]) - tdmodule(td) - assert td.shape == torch.Size([3]) - assert td.get("out").shape == torch.Size([3, 4]) - - dist = tdmodule.get_dist(td) - assert dist.rsample().shape[: td.ndimension()] == td.shape - - # test bounds - if not safe and spec_type == "bounded": - assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() - elif safe and spec_type == "bounded": - assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() - - def test_submodule_sequence(self): - td_module_1 = SafeModule( - nn.Linear(3, 2), - in_keys=["in"], - out_keys=["hidden"], - ) - td_module_2 = SafeModule( - nn.Linear(2, 4), - in_keys=["hidden"], - out_keys=["out"], - ) - td_module = SafeSequential(td_module_1, td_module_2) - - td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) - sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) - sub_seq_1(td_1) - assert "hidden" in td_1.keys() - assert "out" not in td_1.keys() - td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) - sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) - sub_seq_2(td_2) - assert "out" in td_2.keys() - assert td_2.get("out").shape == torch.Size([5, 4]) - - 
@pytest.mark.parametrize("stack", [True, False]) - def test_sequential_partial(self, stack): - torch.manual_seed(0) - param_multiplier = 2 - - net1 = nn.Linear(3, 4) - - net2 = nn.Linear(4, 4 * param_multiplier) - net2 = nn.Sequential(net2, NormalParamExtractor()) - net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) - - net3 = nn.Linear(4, 4 * param_multiplier) - net3 = nn.Sequential(net3, NormalParamExtractor()) - net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) - - spec = Bounded(-0.1, 0.1, 4) - - kwargs = {"distribution_class": TanhNormal} - - tdmodule1 = SafeModule( - net1, - in_keys=["a"], - out_keys=["hidden"], - spec=None, - safe=False, - ) - tdmodule2 = SafeProbabilisticTensorDictSequential( - net2, - SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=True, - **kwargs, - ), - ) - tdmodule3 = SafeProbabilisticTensorDictSequential( - net3, - SafeProbabilisticModule( - in_keys=["loc", "scale"], - out_keys=["out"], - spec=spec, - safe=True, - **kwargs, - ), - ) - tdmodule = SafeSequential( - tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True - ) - - if stack: - td = LazyStackedTensorDict.maybe_dense_stack( - [ - TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []), - TensorDict({"a": torch.randn(3), "c": torch.randn(4)}, []), - ], - 0, - ) - tdmodule(td) - assert "loc" in td.keys() - assert "scale" in td.keys() - assert "out" in td.keys() - assert td["out"].shape[0] == 2 - assert td["loc"].shape[0] == 2 - assert td["scale"].shape[0] == 2 - assert "b" not in td.keys() - assert "b" in td[0].keys() - else: - td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) - tdmodule(td) - assert "loc" in td.keys() - assert "scale" in td.keys() - assert "out" in td.keys() - assert "b" in td.keys() - - -def test_is_tensordict_compatible(): - class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2, out_3): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) 
- self.linear_2 = nn.Linear(in_1, out_2) - self.linear_3 = nn.Linear(in_1, out_3) - - def forward(self, x): - return self.linear_1(x), self.linear_2(x), self.linear_3(x) - - td_module = SafeModule( - MultiHeadLinear(5, 4, 3, 2), - in_keys=["in_1", "in_2"], - out_keys=["out_1", "out_2"], + lstm2 = nn.LSTM( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, ) - assert is_tensordict_compatible(td_module) - class MockCompatibleModule(nn.Module): - def __init__(self, in_keys, out_keys): - self.in_keys = in_keys - self.out_keys = out_keys + lstm2.load_state_dict(lstm1.state_dict()) - def forward(self, tensordict): - pass - - compatible_nn_module = MockCompatibleModule( - in_keys=["in_1", "in_2"], - out_keys=["out_1", "out_2"], + # Make sure parameters match + for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" + + if batch_first: + input = torch.randn(B, T, 10, device=device) + else: + input = torch.randn(T, B, 10, device=device) + + h0 = torch.randn(num_layers, 5, 20, device=device) + c0 = torch.randn(num_layers, 5, 20, device=device) + + # Test without hidden states + with torch.no_grad(): + output1, (h1, c1) = lstm1(input) + output2, (h2, c2) = lstm2(input) + + assert h1.shape == h2.shape + assert c1.shape == c2.shape + assert output1.shape == output2.shape + if dropout == 0.0: + torch.testing.assert_close(output1, output2) + torch.testing.assert_close(h1, h2) + torch.testing.assert_close(c1, c2) + + # Test with hidden states + with torch.no_grad(): + output1, (h1, c1) = lstm1(input, (h0, c0)) + output2, (h2, c2) = lstm1(input, (h0, c0)) + + assert h1.shape == h2.shape + assert c1.shape == c2.shape + assert output1.shape == output2.shape + if dropout == 0.0: + 
torch.testing.assert_close(output1, output2) + torch.testing.assert_close(h1, h2) + torch.testing.assert_close(c1, c2) + + +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("batch_first", [True, False]) +@pytest.mark.parametrize("dropout", [0.0, 0.5]) +@pytest.mark.parametrize("num_layers", [1, 2]) +def test_python_gru(device, bias, dropout, batch_first, num_layers): + B = 5 + T = 3 + gru1 = GRU( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, ) - assert is_tensordict_compatible(compatible_nn_module) - - class MockIncompatibleModuleNoKeys(nn.Module): - def forward(self, input): - pass - - incompatible_nn_module_no_keys = MockIncompatibleModuleNoKeys() - assert not is_tensordict_compatible(incompatible_nn_module_no_keys) + gru2 = nn.GRU( + input_size=10, + hidden_size=20, + num_layers=num_layers, + device=device, + bias=bias, + batch_first=batch_first, + ) + gru2.load_state_dict(gru1.state_dict()) - class MockIncompatibleModuleMultipleArgs(nn.Module): - def __init__(self, in_keys, out_keys): - self.in_keys = in_keys - self.out_keys = out_keys + # Make sure parameters match + for (k1, v1), (k2, v2) in zip(gru1.named_parameters(), gru2.named_parameters()): + assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" + torch.testing.assert_close(v1, v2) + assert ( + v1.shape == v2.shape + ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - def forward(self, input_1, input_2): - pass + if batch_first: + input = torch.randn(B, T, 10, device=device) + else: + input = torch.randn(T, B, 10, device=device) - incompatible_nn_module_multi_args = MockIncompatibleModuleMultipleArgs( - in_keys=["in_1", "in_2"], - out_keys=["out_1", "out_2"], - ) - with pytest.raises(TypeError): - is_tensordict_compatible(incompatible_nn_module_multi_args) + h0 = torch.randn(num_layers, 5, 20, device=device) + 
# Test without hidden states + with torch.no_grad(): + output1, h1 = gru1(input) + output2, h2 = gru2(input) -def test_ensure_tensordict_compatible(): - class MultiHeadLinear(nn.Module): - def __init__(self, in_1, out_1, out_2, out_3): - super().__init__() - self.linear_1 = nn.Linear(in_1, out_1) - self.linear_2 = nn.Linear(in_1, out_2) - self.linear_3 = nn.Linear(in_1, out_3) + assert h1.shape == h2.shape + assert output1.shape == output2.shape + if dropout == 0.0: + torch.testing.assert_close(output1, output2) + torch.testing.assert_close(h1, h2) - def forward(self, x): - return self.linear_1(x), self.linear_2(x), self.linear_3(x) + # Test with hidden states + with torch.no_grad(): + output1, h1 = gru1(input, h0) + output2, h2 = gru2(input, h0) - td_module = SafeModule( - MultiHeadLinear(5, 4, 3, 2), - in_keys=["in_1", "in_2"], - out_keys=["out_1", "out_2"], - ) - ensured_module = ensure_tensordict_compatible(td_module) - assert ensured_module is td_module - with pytest.raises(TypeError): - ensure_tensordict_compatible(td_module, in_keys=["input"]) - with pytest.raises(TypeError): - ensure_tensordict_compatible(td_module, out_keys=["output"]) - - class NonNNModule: - def __init__(self): - pass - - def forward(self, x): - pass - - non_nn_module = NonNNModule() - with pytest.raises(TypeError): - ensure_tensordict_compatible(non_nn_module) - - class ErrorNNModule(nn.Module): - def forward(self, in_1, in_2): - pass - - error_nn_module = ErrorNNModule() - with pytest.raises(TypeError): - ensure_tensordict_compatible(error_nn_module, in_keys=["input"]) - - nn_module = MultiHeadLinear(5, 4, 3, 2) - ensured_module = ensure_tensordict_compatible( - nn_module, - in_keys=["x"], - out_keys=["out_1", "out_2", "out_3"], - ) - assert set(unravel_key_list(ensured_module.in_keys)) == {"x"} - assert isinstance(ensured_module, TensorDictModule) + assert h1.shape == h2.shape + assert output1.shape == output2.shape + if dropout == 0.0: + torch.testing.assert_close(output1, output2) + 
torch.testing.assert_close(h1, h2) class TestLSTMModule: @@ -2527,218 +2106,6 @@ def test_gru_module_three_backends_equivalent(self, num_layers): ) -def test_safe_specs(): - - out_key = ("a", "b") - spec = Composite(Composite({out_key: Unbounded()})) - original_spec = spec.clone() - mod = SafeModule( - module=nn.Linear(3, 1), - spec=spec, - out_keys=[out_key, ("other", "key")], - in_keys=[], - ) - assert original_spec == spec - assert original_spec[out_key] == mod.spec[out_key] - - -def test_actor_critic_specs(): - action_key = ("agents", "action") - spec = Composite(Composite({action_key: Unbounded(shape=(3,))})) - policy_module = TensorDictModule( - nn.Linear(3, 1), - in_keys=[("agents", "observation")], - out_keys=[action_key], - ) - original_spec = spec.clone() - module = TensorDictSequential( - policy_module, AdditiveGaussianModule(spec=spec, action_key=action_key) - ) - value_module = ValueOperator( - module=module, - in_keys=[("agents", "observation"), action_key], - out_keys=[("agents", "state_action_value")], - ) - assert original_spec == spec - assert module[1].spec == spec - DDPGLoss(actor_network=module, value_network=value_module) - assert original_spec == spec - assert module[1].spec == spec - - -def test_vmapmodule(): - lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) - sample_in = torch.ones((10, 3, 2)) - sample_in_td = TensorDict({"x": sample_in}, batch_size=[10]) - lam(sample_in) - vm = VmapModule(lam, 0) - vm(sample_in_td) - assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() - - -@pytest.mark.skipif( - not _has_transformers, reason="transformers needed to test DT classes" -) -class TestDecisionTransformerInferenceWrapper: - @pytest.mark.parametrize("online", [True, False]) - def test_dt_inference_wrapper(self, online): - action_key = ("nested", ("action",)) - if online: - dtactor = OnlineDTActor( - state_dim=4, action_dim=2, transformer_config=DTActor.default_config() - ) - in_keys = ["loc", "scale"] - actor_module = 
TensorDictModule( - dtactor, - in_keys=["observation", action_key, "return_to_go"], - out_keys=in_keys, - ) - dist_class = TanhNormal - else: - dtactor = DTActor( - state_dim=4, action_dim=2, transformer_config=DTActor.default_config() - ) - in_keys = ["param"] - actor_module = TensorDictModule( - dtactor, - in_keys=["observation", action_key, "return_to_go"], - out_keys=in_keys, - ) - dist_class = TanhDelta - dist_kwargs = { - "low": -1.0, - "high": 1.0, - } - actor = ProbabilisticActor( - in_keys=in_keys, - out_keys=[action_key], - module=actor_module, - distribution_class=dist_class, - distribution_kwargs=dist_kwargs, - ) - inference_actor = DecisionTransformerInferenceWrapper(actor) - sequence_length = 20 - td = TensorDict( - { - "observation": torch.randn(1, sequence_length, 4), - action_key: torch.randn(1, sequence_length, 2), - "return_to_go": torch.randn(1, sequence_length, 1), - }, - [1], - ) - with pytest.raises( - ValueError, - match="The value of out_action_key", - ): - result = inference_actor(td) - inference_actor.set_tensor_keys(action=action_key, out_action=action_key) - result = inference_actor(td) - # checks that the seq length has disappeared - assert result.get(action_key).shape == torch.Size([1, 2]) - assert inference_actor.out_keys == unravel_key_list( - sorted([action_key, *in_keys, "observation", "return_to_go"], key=str) - ) - assert set(result.keys(True, True)) - set(td.keys(True, True)) == set( - inference_actor.out_keys - ) - set(inference_actor.in_keys) - - -class TestBatchedActor: - def test_batched_actor_exceptions(self): - time_steps = 5 - actor_base = TensorDictModule( - lambda x: torch.ones( - x.shape[0], time_steps, 1, device=x.device, dtype=x.dtype - ), - in_keys=["observation_cat"], - out_keys=["action"], - ) - with pytest.raises(ValueError, match="Only a single init_key can be passed"): - MultiStepActorWrapper(actor_base, n_steps=time_steps, init_key=["init_key"]) - - batch = 2 - - # The second env has frequent resets, the 
first none - base_env = SerialEnv( - batch, - [lambda: CountingEnv(max_steps=5000), lambda: CountingEnv(max_steps=5)], - ) - env = TransformedEnv( - base_env, - CatFrames( - N=time_steps, - in_keys=["observation"], - out_keys=["observation_cat"], - dim=-1, - ), - ) - actor = MultiStepActorWrapper(actor_base, n_steps=time_steps) - with pytest.raises(KeyError, match="No init key was passed"): - env.rollout(2, actor) - - env = TransformedEnv( - base_env, - Compose( - InitTracker(), - CatFrames( - N=time_steps, - in_keys=["observation"], - out_keys=["observation_cat"], - dim=-1, - ), - ), - ) - td = env.rollout(10)[..., -1]["next"] - actor = MultiStepActorWrapper(actor_base, n_steps=time_steps) - with pytest.raises(RuntimeError, match="Cannot initialize the wrapper"): - env.rollout(10, actor, tensordict=td, auto_reset=False) - - actor = MultiStepActorWrapper(actor_base, n_steps=time_steps - 1) - with pytest.raises(RuntimeError, match="The action's time dimension"): - env.rollout(10, actor) - - @pytest.mark.parametrize("time_steps", [3, 5]) - def test_batched_actor_simple(self, time_steps): - - batch = 2 - - # The second env has frequent resets, the first none - base_env = SerialEnv( - batch, - [lambda: CountingEnv(max_steps=5000), lambda: CountingEnv(max_steps=5)], - ) - env = TransformedEnv( - base_env, - Compose( - InitTracker(), - CatFrames( - N=time_steps, - in_keys=["observation"], - out_keys=["observation_cat"], - dim=-1, - ), - ), - ) - - actor_base = TensorDictModule( - lambda x: torch.ones( - x.shape[0], time_steps, 1, device=x.device, dtype=x.dtype - ), - in_keys=["observation_cat"], - out_keys=["action"], - ) - actor = MultiStepActorWrapper(actor_base, n_steps=time_steps) - # rollout = env.rollout(100, break_when_any_done=False) - rollout = env.rollout(50, actor, break_when_any_done=False) - unique = rollout[0]["observation"].unique() - predicted = torch.arange(unique.numel()) - assert (unique == predicted).all() - assert ( - rollout[1]["observation"] - == 
(torch.arange(50) % 6).reshape_as(rollout[1]["observation"]) - ).all() - - def test_get_primers_from_module(): # No primers in the model diff --git a/test/modules/test_td_module.py b/test/modules/test_td_module.py new file mode 100644 index 00000000000..d4be1b8687d --- /dev/null +++ b/test/modules/test_td_module.py @@ -0,0 +1,668 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import argparse + +import pytest +import torch +from tensordict import LazyStackedTensorDict, TensorDict, unravel_key_list +from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential +from torch import nn +from torchrl.data.tensor_specs import Bounded, Composite, Unbounded +from torchrl.envs.utils import set_exploration_type +from torchrl.modules import ( + AdditiveGaussianModule, + NormalParamExtractor, + SafeModule, + TanhNormal, + ValueOperator, +) +from torchrl.modules.tensordict_module.common import ( + ensure_tensordict_compatible, + is_tensordict_compatible, + VmapModule, +) +from torchrl.modules.tensordict_module.probabilistic import ( + SafeProbabilisticModule, + SafeProbabilisticTensorDictSequential, +) +from torchrl.modules.tensordict_module.sequence import SafeSequential +from torchrl.objectives import DDPGLoss + + +class TestTDModule: + def test_multiple_output(self): + class MultiHeadLinear(nn.Module): + def __init__(self, in_1, out_1, out_2, out_3): + super().__init__() + self.linear_1 = nn.Linear(in_1, out_1) + self.linear_2 = nn.Linear(in_1, out_2) + self.linear_3 = nn.Linear(in_1, out_3) + + def forward(self, x): + return self.linear_1(x), self.linear_2(x), self.linear_3(x) + + tensordict_module = SafeModule( + MultiHeadLinear(5, 4, 3, 2), + in_keys=["input"], + out_keys=["out_1", "out_2", "out_3"], + ) + td = TensorDict({"input": torch.randn(3, 5)}, batch_size=[3]) + td = 
tensordict_module(td) + assert td.shape == torch.Size([3]) + assert "input" in td.keys() + assert "out_1" in td.keys() + assert "out_2" in td.keys() + assert "out_3" in td.keys() + assert td.get("out_3").shape == torch.Size([3, 2]) + + # Using "_" key to ignore some output + tensordict_module = SafeModule( + MultiHeadLinear(5, 4, 3, 2), + in_keys=["input"], + out_keys=["_", "_", "out_3"], + ) + td = TensorDict({"input": torch.randn(3, 5)}, batch_size=[3]) + td = tensordict_module(td) + assert td.shape == torch.Size([3]) + assert "input" in td.keys() + assert "out_3" in td.keys() + assert "_" not in td.keys() + assert td.get("out_3").shape == torch.Size([3, 2]) + + def test_spec_key_warning(self): + class MultiHeadLinear(nn.Module): + def __init__(self, in_1, out_1, out_2): + super().__init__() + self.linear_1 = nn.Linear(in_1, out_1) + self.linear_2 = nn.Linear(in_1, out_2) + + def forward(self, x): + return self.linear_1(x), self.linear_2(x) + + spec_dict = { + "_": Unbounded((4,)), + "out_2": Unbounded((3,)), + } + + # warning due to "_" in spec keys + with pytest.warns(UserWarning, match='got a spec with key "_"'): + tensordict_module = SafeModule( + MultiHeadLinear(5, 4, 3), + in_keys=["input"], + out_keys=["_", "out_2"], + spec=Composite(**spec_dict), + ) + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, lazy): + torch.manual_seed(0) + param_multiplier = 1 + if lazy: + net = nn.LazyLinear(4 * param_multiplier) + else: + net = nn.Linear(3, 4 * param_multiplier) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = Bounded(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = Unbounded(4) + + if safe and spec is None: + with pytest.raises( + RuntimeError, + match="is not a valid configuration as the tensor specs are not " + "specified", + ): + tensordict_module = 
SafeModule( + module=net, + spec=spec, + in_keys=["in"], + out_keys=["out"], + safe=safe, + ) + return + else: + tensordict_module = SafeModule( + module=net, + spec=spec, + in_keys=["in"], + out_keys=["out"], + safe=safe, + ) + + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tensordict_module(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any(), td.get("out") + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("out_keys", [["loc", "scale"], ["loc_1", "scale_1"]]) + @pytest.mark.parametrize("lazy", [True, False]) + @pytest.mark.parametrize( + "exp_mode", [InteractionType.DETERMINISTIC, InteractionType.RANDOM, None] + ) + def test_stateful_probabilistic(self, safe, spec_type, lazy, exp_mode, out_keys): + torch.manual_seed(0) + param_multiplier = 2 + if lazy: + net = nn.LazyLinear(4 * param_multiplier) + else: + net = nn.Linear(3, 4 * param_multiplier) + + in_keys = ["in"] + net = SafeModule( + module=nn.Sequential(net, NormalParamExtractor()), + spec=None, + in_keys=in_keys, + out_keys=out_keys, + ) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = Bounded(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = Unbounded(4) + else: + raise NotImplementedError + + kwargs = {"distribution_class": TanhNormal} + if out_keys == ["loc", "scale"]: + dist_in_keys = ["loc", "scale"] + elif out_keys == ["loc_1", "scale_1"]: + dist_in_keys = {"loc": "loc_1", "scale": "scale_1"} + else: + raise NotImplementedError + + if safe and spec is None: + with pytest.raises( + RuntimeError, + match="is not a valid configuration as the tensor specs are not " + "specified", + ): + prob_module = 
SafeProbabilisticModule( + in_keys=dist_in_keys, + out_keys=["out"], + spec=spec, + safe=safe, + **kwargs, + ) + return + else: + prob_module = SafeProbabilisticModule( + in_keys=dist_in_keys, + out_keys=["out"], + spec=spec, + safe=safe, + **kwargs, + ) + + tensordict_module = SafeProbabilisticTensorDictSequential(net, prob_module) + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + with set_exploration_type(exp_mode): + tensordict_module(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + +class TestTDSequence: + # Temporarily disabling this test until 473 is merged in tensordict + # def test_in_key_warning(self): + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_"], out_keys=["out1"] + # ) + # with pytest.warns(UserWarning, match='key "_" is for ignoring output'): + # tensordict_module = SafeModule( + # nn.Linear(3, 4), in_keys=["_", "key2"], out_keys=["out1"] + # ) + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful(self, safe, spec_type, lazy): + torch.manual_seed(0) + param_multiplier = 1 + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = Bounded(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = Unbounded(4) + + kwargs = {} + + if safe and spec is None: + pytest.skip("safe and spec is None is checked 
elsewhere") + else: + tdmodule1 = SafeModule( + net1, + spec=None, + in_keys=["in"], + out_keys=["hidden"], + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + spec=None, + in_keys=["hidden"], + out_keys=["hidden"], + safe=False, + ) + tdmodule2 = SafeModule( + spec=spec, + module=net2, + in_keys=["hidden"], + out_keys=["out"], + safe=False, + **kwargs, + ) + tdmodule = SafeSequential(tdmodule1, dummy_tdmodule, tdmodule2) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 3 + tdmodule[1] = tdmodule2 + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 3 + del tdmodule[2] + assert len(tdmodule) == 2 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 + + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tdmodule(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + @pytest.mark.parametrize("safe", [True, False]) + @pytest.mark.parametrize("spec_type", [None, "bounded", "unbounded"]) + @pytest.mark.parametrize("lazy", [True, False]) + def test_stateful_probabilistic(self, safe, spec_type, lazy): + torch.manual_seed(0) + param_multiplier = 2 + if lazy: + net1 = nn.LazyLinear(4) + dummy_net = nn.LazyLinear(4) + net2 = nn.LazyLinear(4 * param_multiplier) + else: + net1 = nn.Linear(3, 4) + dummy_net = nn.Linear(4, 4) + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = nn.Sequential(net2, NormalParamExtractor()) + + if spec_type is None: + spec = None + elif spec_type == "bounded": + spec = Bounded(-0.1, 0.1, 4) + elif spec_type == "unbounded": + spec = Unbounded(4) + else: + raise NotImplementedError + + kwargs = {"distribution_class": TanhNormal} + + if safe and spec is 
None: + pytest.skip("safe and spec is None is checked elsewhere") + else: + tdmodule1 = SafeModule( + net1, + in_keys=["in"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + dummy_tdmodule = SafeModule( + dummy_net, + in_keys=["hidden"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeModule( + module=net2, + in_keys=["hidden"], + out_keys=["loc", "scale"], + spec=None, + safe=False, + ) + + prob_module = SafeProbabilisticModule( + spec=spec, + in_keys=["loc", "scale"], + out_keys=["out"], + safe=False, + **kwargs, + ) + tdmodule = SafeProbabilisticTensorDictSequential( + tdmodule1, dummy_tdmodule, tdmodule2, prob_module + ) + + assert hasattr(tdmodule, "__setitem__") + assert len(tdmodule) == 4 + tdmodule[1] = tdmodule2 + tdmodule[2] = prob_module + assert len(tdmodule) == 4 + + assert hasattr(tdmodule, "__delitem__") + assert len(tdmodule) == 4 + del tdmodule[3] + assert len(tdmodule) == 3 + + assert hasattr(tdmodule, "__getitem__") + assert tdmodule[0] is tdmodule1 + assert tdmodule[1] is tdmodule2 + assert tdmodule[2] is prob_module + + td = TensorDict({"in": torch.randn(3, 3)}, [3]) + tdmodule(td) + assert td.shape == torch.Size([3]) + assert td.get("out").shape == torch.Size([3, 4]) + + dist = tdmodule.get_dist(td) + assert dist.rsample().shape[: td.ndimension()] == td.shape + + # test bounds + if not safe and spec_type == "bounded": + assert ((td.get("out") > 0.1) | (td.get("out") < -0.1)).any() + elif safe and spec_type == "bounded": + assert ((td.get("out") < 0.1) | (td.get("out") > -0.1)).all() + + def test_submodule_sequence(self): + td_module_1 = SafeModule( + nn.Linear(3, 2), + in_keys=["in"], + out_keys=["hidden"], + ) + td_module_2 = SafeModule( + nn.Linear(2, 4), + in_keys=["hidden"], + out_keys=["out"], + ) + td_module = SafeSequential(td_module_1, td_module_2) + + td_1 = TensorDict({"in": torch.randn(5, 3)}, [5]) + sub_seq_1 = td_module.select_subsequence(out_keys=["hidden"]) + sub_seq_1(td_1) + assert "hidden" in 
td_1.keys() + assert "out" not in td_1.keys() + td_2 = TensorDict({"hidden": torch.randn(5, 2)}, [5]) + sub_seq_2 = td_module.select_subsequence(in_keys=["hidden"]) + sub_seq_2(td_2) + assert "out" in td_2.keys() + assert td_2.get("out").shape == torch.Size([5, 4]) + + @pytest.mark.parametrize("stack", [True, False]) + def test_sequential_partial(self, stack): + torch.manual_seed(0) + param_multiplier = 2 + + net1 = nn.Linear(3, 4) + + net2 = nn.Linear(4, 4 * param_multiplier) + net2 = nn.Sequential(net2, NormalParamExtractor()) + net2 = SafeModule(net2, in_keys=["b"], out_keys=["loc", "scale"]) + + net3 = nn.Linear(4, 4 * param_multiplier) + net3 = nn.Sequential(net3, NormalParamExtractor()) + net3 = SafeModule(net3, in_keys=["c"], out_keys=["loc", "scale"]) + + spec = Bounded(-0.1, 0.1, 4) + + kwargs = {"distribution_class": TanhNormal} + + tdmodule1 = SafeModule( + net1, + in_keys=["a"], + out_keys=["hidden"], + spec=None, + safe=False, + ) + tdmodule2 = SafeProbabilisticTensorDictSequential( + net2, + SafeProbabilisticModule( + in_keys=["loc", "scale"], + out_keys=["out"], + spec=spec, + safe=True, + **kwargs, + ), + ) + tdmodule3 = SafeProbabilisticTensorDictSequential( + net3, + SafeProbabilisticModule( + in_keys=["loc", "scale"], + out_keys=["out"], + spec=spec, + safe=True, + **kwargs, + ), + ) + tdmodule = SafeSequential( + tdmodule1, tdmodule2, tdmodule3, partial_tolerant=True + ) + + if stack: + td = LazyStackedTensorDict.maybe_dense_stack( + [ + TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []), + TensorDict({"a": torch.randn(3), "c": torch.randn(4)}, []), + ], + 0, + ) + tdmodule(td) + assert "loc" in td.keys() + assert "scale" in td.keys() + assert "out" in td.keys() + assert td["out"].shape[0] == 2 + assert td["loc"].shape[0] == 2 + assert td["scale"].shape[0] == 2 + assert "b" not in td.keys() + assert "b" in td[0].keys() + else: + td = TensorDict({"a": torch.randn(3), "b": torch.randn(4)}, []) + tdmodule(td) + assert "loc" in td.keys() + 
assert "scale" in td.keys() + assert "out" in td.keys() + assert "b" in td.keys() + + +def test_is_tensordict_compatible(): + class MultiHeadLinear(nn.Module): + def __init__(self, in_1, out_1, out_2, out_3): + super().__init__() + self.linear_1 = nn.Linear(in_1, out_1) + self.linear_2 = nn.Linear(in_1, out_2) + self.linear_3 = nn.Linear(in_1, out_3) + + def forward(self, x): + return self.linear_1(x), self.linear_2(x), self.linear_3(x) + + td_module = SafeModule( + MultiHeadLinear(5, 4, 3, 2), + in_keys=["in_1", "in_2"], + out_keys=["out_1", "out_2"], + ) + assert is_tensordict_compatible(td_module) + + class MockCompatibleModule(nn.Module): + def __init__(self, in_keys, out_keys): + self.in_keys = in_keys + self.out_keys = out_keys + + def forward(self, tensordict): + pass + + compatible_nn_module = MockCompatibleModule( + in_keys=["in_1", "in_2"], + out_keys=["out_1", "out_2"], + ) + assert is_tensordict_compatible(compatible_nn_module) + + class MockIncompatibleModuleNoKeys(nn.Module): + def forward(self, input): + pass + + incompatible_nn_module_no_keys = MockIncompatibleModuleNoKeys() + assert not is_tensordict_compatible(incompatible_nn_module_no_keys) + + class MockIncompatibleModuleMultipleArgs(nn.Module): + def __init__(self, in_keys, out_keys): + self.in_keys = in_keys + self.out_keys = out_keys + + def forward(self, input_1, input_2): + pass + + incompatible_nn_module_multi_args = MockIncompatibleModuleMultipleArgs( + in_keys=["in_1", "in_2"], + out_keys=["out_1", "out_2"], + ) + with pytest.raises(TypeError): + is_tensordict_compatible(incompatible_nn_module_multi_args) + + +def test_ensure_tensordict_compatible(): + class MultiHeadLinear(nn.Module): + def __init__(self, in_1, out_1, out_2, out_3): + super().__init__() + self.linear_1 = nn.Linear(in_1, out_1) + self.linear_2 = nn.Linear(in_1, out_2) + self.linear_3 = nn.Linear(in_1, out_3) + + def forward(self, x): + return self.linear_1(x), self.linear_2(x), self.linear_3(x) + + td_module = 
SafeModule( + MultiHeadLinear(5, 4, 3, 2), + in_keys=["in_1", "in_2"], + out_keys=["out_1", "out_2"], + ) + ensured_module = ensure_tensordict_compatible(td_module) + assert ensured_module is td_module + with pytest.raises(TypeError): + ensure_tensordict_compatible(td_module, in_keys=["input"]) + with pytest.raises(TypeError): + ensure_tensordict_compatible(td_module, out_keys=["output"]) + + class NonNNModule: + def __init__(self): + pass + + def forward(self, x): + pass + + non_nn_module = NonNNModule() + with pytest.raises(TypeError): + ensure_tensordict_compatible(non_nn_module) + + class ErrorNNModule(nn.Module): + def forward(self, in_1, in_2): + pass + + error_nn_module = ErrorNNModule() + with pytest.raises(TypeError): + ensure_tensordict_compatible(error_nn_module, in_keys=["input"]) + + nn_module = MultiHeadLinear(5, 4, 3, 2) + ensured_module = ensure_tensordict_compatible( + nn_module, + in_keys=["x"], + out_keys=["out_1", "out_2", "out_3"], + ) + assert set(unravel_key_list(ensured_module.in_keys)) == {"x"} + assert isinstance(ensured_module, TensorDictModule) + + +def test_safe_specs(): + + out_key = ("a", "b") + spec = Composite(Composite({out_key: Unbounded()})) + original_spec = spec.clone() + mod = SafeModule( + module=nn.Linear(3, 1), + spec=spec, + out_keys=[out_key, ("other", "key")], + in_keys=[], + ) + assert original_spec == spec + assert original_spec[out_key] == mod.spec[out_key] + + +def test_actor_critic_specs(): + action_key = ("agents", "action") + spec = Composite(Composite({action_key: Unbounded(shape=(3,))})) + policy_module = TensorDictModule( + nn.Linear(3, 1), + in_keys=[("agents", "observation")], + out_keys=[action_key], + ) + original_spec = spec.clone() + module = TensorDictSequential( + policy_module, AdditiveGaussianModule(spec=spec, action_key=action_key) + ) + value_module = ValueOperator( + module=module, + in_keys=[("agents", "observation"), action_key], + out_keys=[("agents", "state_action_value")], + ) + assert 
original_spec == spec + assert module[1].spec == spec + DDPGLoss(actor_network=module, value_network=value_module) + assert original_spec == spec + assert module[1].spec == spec + + +def test_vmapmodule(): + lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) + sample_in = torch.ones((10, 3, 2)) + sample_in_td = TensorDict({"x": sample_in}, batch_size=[10]) + lam(sample_in) + vm = VmapModule(lam, 0) + vm(sample_in_td) + assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/rb/_rb_common.py b/test/rb/_rb_common.py new file mode 100644 index 00000000000..7d4e886671d --- /dev/null +++ b/test/rb/_rb_common.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import functools +import importlib.util +import sys + +import torch +from packaging import version +from packaging.version import parse + +from torchrl.data import ReplayBuffer, TensorDictReplayBuffer + +OLD_TORCH = parse(torch.__version__) < parse("2.0.0") +_has_tv = importlib.util.find_spec("torchvision") is not None +_has_gym = importlib.util.find_spec("gym") is not None +_has_snapshot = importlib.util.find_spec("torchsnapshot") is not None +_os_is_windows = sys.platform == "win32" +_has_transformers = importlib.util.find_spec("transformers") is not None +_has_ray = importlib.util.find_spec("ray") is not None +_has_zstandard = importlib.util.find_spec("zstandard") is not None + +TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) + +torch_2_3 = version.parse( + ".".join([str(s) for s in version.parse(str(torch.__version__)).release]) +) >= version.parse("2.3.0") + +ReplayBufferRNG = functools.partial(ReplayBuffer, generator=torch.Generator()) +TensorDictReplayBufferRNG = functools.partial( + TensorDictReplayBuffer, generator=torch.Generator() +) diff --git a/test/rb/conftest.py b/test/rb/conftest.py new file mode 100644 index 00000000000..f98e32d8b55 --- /dev/null +++ b/test/rb/conftest.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations diff --git a/test/rb/test_composable.py b/test/rb/test_composable.py new file mode 100644 index 00000000000..ef99d2dc6fd --- /dev/null +++ b/test/rb/test_composable.py @@ -0,0 +1,560 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree.
+from __future__ import annotations + +import argparse +import contextlib +import pickle +import sys + +import pytest +import torch +from _rb_common import ( + _os_is_windows, + OLD_TORCH, + ReplayBufferRNG, + TensorDictReplayBufferRNG, + TORCH_VERSION, +) +from packaging import version +from tensordict import is_tensor_collection, TensorDict, TensorDictBase +from torch.utils._pytree import tree_flatten, tree_map + +from torchrl.data import ( + RemoteTensorDictReplayBuffer, + ReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers import samplers, writers +from torchrl.data.replay_buffers.samplers import RandomSampler +from torchrl.data.replay_buffers.storages import ( + LazyMemmapStorage, + LazyTensorStorage, + ListStorage, + TensorStorage, +) +from torchrl.data.replay_buffers.writers import ( + RoundRobinWriter, + TensorDictMaxValueWriter, +) +from torchrl.testing import capture_log_records + + +@pytest.mark.parametrize( + "sampler", + [ + samplers.RandomSampler, + samplers.SamplerWithoutReplacement, + samplers.PrioritizedSampler, + ], +) +@pytest.mark.parametrize( + "writer", [writers.RoundRobinWriter, writers.TensorDictMaxValueWriter] +) +@pytest.mark.parametrize( + "rb_type,storage,datatype", + [ + [ReplayBuffer, ListStorage, None], + [ReplayBufferRNG, ListStorage, None], + [TensorDictReplayBuffer, ListStorage, "tensordict"], + [TensorDictReplayBufferRNG, ListStorage, "tensordict"], + [RemoteTensorDictReplayBuffer, ListStorage, "tensordict"], + [ReplayBuffer, LazyTensorStorage, "tensor"], + [ReplayBuffer, LazyTensorStorage, "tensordict"], + [ReplayBuffer, LazyTensorStorage, "pytree"], + [ReplayBufferRNG, LazyTensorStorage, "tensor"], + [ReplayBufferRNG, LazyTensorStorage, "tensordict"], + [ReplayBufferRNG, LazyTensorStorage, "pytree"], + [TensorDictReplayBuffer, LazyTensorStorage, "tensordict"], + [TensorDictReplayBufferRNG, LazyTensorStorage, "tensordict"], + [RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"], + 
[ReplayBuffer, LazyMemmapStorage, "tensor"], + [ReplayBuffer, LazyMemmapStorage, "tensordict"], + [ReplayBuffer, LazyMemmapStorage, "pytree"], + [ReplayBufferRNG, LazyMemmapStorage, "tensor"], + [ReplayBufferRNG, LazyMemmapStorage, "tensordict"], + [ReplayBufferRNG, LazyMemmapStorage, "pytree"], + [TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], + [TensorDictReplayBufferRNG, LazyMemmapStorage, "tensordict"], + [RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], + ], +) +@pytest.mark.parametrize("size", [3, 5, 100]) +class TestComposableBuffers: + def _get_rb( + self, rb_type, size, sampler, writer, storage, compilable=False, **kwargs + ): + if storage is not None: + storage = storage(size, compilable=compilable) + + sampler_args = {} + if sampler is samplers.PrioritizedSampler: + sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9} + + sampler = sampler(**sampler_args) + writer = writer(compilable=compilable) + rb = rb_type( + storage=storage, + sampler=sampler, + writer=writer, + batch_size=3, + compilable=compilable, + **kwargs, + ) + return rb + + def _get_datum(self, datatype): + if datatype is None: + data = torch.randint(100, (1,)) + elif datatype == "tensor": + data = torch.randint(100, (1,)) + elif datatype == "tensordict": + data = TensorDict( + {"a": torch.randint(100, (1,)), "next": {"reward": torch.randn(1)}}, [] + ) + elif datatype == "pytree": + data = { + "a": torch.randint(100, (1,)), + "b": {"c": [torch.zeros(3), (torch.ones(2),)]}, + 30: torch.zeros(2), + } + else: + raise NotImplementedError(datatype) + return data + + def _get_data(self, datatype, size): + if datatype is None: + data = torch.randint(100, (size, 1)) + elif datatype == "tensor": + data = torch.randint(100, (size, 1)) + elif datatype == "tensordict": + data = TensorDict( + { + "a": torch.randint(100, (size, 1)), + "next": {"reward": torch.randn(size, 1)}, + }, + [size], + ) + elif datatype == "pytree": + data = { + "a": torch.randint(100, (size, 
1)), + "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]}, + 30: torch.zeros(size, 2), + } + else: + raise NotImplementedError(datatype) + return data + + def test_rb_repr(self, rb_type, sampler, writer, storage, size, datatype): + if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: + pytest.skip( + "Distributed package support on Windows is a prototype feature and is subject to changes." + ) + torch.manual_seed(0) + rb = self._get_rb( + rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size + ) + data = self._get_datum(datatype) + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.add(data) + return + rb.add(data) + # we just check that str runs, not its value + assert str(rb) + rb.sample() + assert str(rb) + + def test_add(self, rb_type, sampler, writer, storage, size, datatype): + if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: + pytest.skip( + "Distributed package support on Windows is a prototype feature and is subject to changes." 
+ ) + torch.manual_seed(0) + rb = self._get_rb( + rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size + ) + data = self._get_datum(datatype) + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.add(data) + return + rb.add(data) + s, info = rb.sample(1, return_info=True) + assert len(rb) == 1 + if isinstance(s, (torch.Tensor, TensorDictBase)): + assert s.ndim, s + s = s[0] + else: + + def assert_ndim(tensor): + assert tensor.shape[0] == 1 + + tree_map(assert_ndim, s) + s = tree_map(lambda s: s[0], s) + if isinstance(s, TensorDictBase): + s = s.select(*data.keys(True), strict=False) + data = data.select(*s.keys(True), strict=False) + assert (s == data).all() + assert list(s.keys(True, True)) + else: + flat_s = tree_flatten(s)[0] + flat_data = tree_flatten(data)[0] + assert all((_s == _data).all() for (_s, _data) in zip(flat_s, flat_data)) + + def test_cursor_position(self, rb_type, sampler, writer, storage, size, datatype): + storage = storage(size) + writer = writer() + writer.register_storage(storage) + batch1 = self._get_data(datatype, size=5) + cond = ( + OLD_TORCH + and not isinstance(writer, TensorDictMaxValueWriter) + and size < len(batch1) + and isinstance(storage, TensorStorage) + ) + + if not is_tensor_collection(batch1) and isinstance( + writer, TensorDictMaxValueWriter + ): + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + writer.extend(batch1) + return + + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + writer.extend(batch1) + + # Added less data than storage max size + if size > 5: + assert writer._cursor == 5 + # Added more data than storage max size + elif size < 5: + # if Max writer, we don't necessarily overwrite existing values so + # we just check 
that the cursor is before the threshold + if isinstance(writer, TensorDictMaxValueWriter): + assert writer._cursor <= 5 - size + else: + assert writer._cursor == 5 - size + # Added as data as storage max size + else: + assert writer._cursor == 0 + if not isinstance(writer, TensorDictMaxValueWriter): + batch2 = self._get_data(datatype, size=size - 1) + writer.extend(batch2) + assert writer._cursor == size - 1 + + def test_extend(self, rb_type, sampler, writer, storage, size, datatype): + if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: + pytest.skip( + "Distributed package support on Windows is a prototype feature and is subject to changes." + ) + torch.manual_seed(0) + rb = self._get_rb( + rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size + ) + data_shape = 5 + data = self._get_data(datatype, size=data_shape) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb.storage, TensorStorage) + ) + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return + length = min(rb.storage.max_size, len(rb) + data_shape) + if writer is TensorDictMaxValueWriter: + data["next", "reward"][-length:] = 1_000_000 + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + rb.extend(data) + length = len(rb) + if is_tensor_collection(data): + data_iter = data[-length:] + else: + + def data_iter(): + for t in range(-length, -1): + yield tree_map(lambda x, t=t: x[t], data) + + data_iter = data_iter() + for d in data_iter: + for b in rb.storage: + if isinstance(b, TensorDictBase): + keys = set(d.keys()).intersection(b.keys()) + b = b.exclude("index").select(*keys, strict=False) + keys = set(d.keys()).intersection(b.keys()) + d = d.select(*keys, strict=False) 
+ if isinstance(b, (torch.Tensor, TensorDictBase)): + value = b == d + value = value.all() + else: + d_flat = tree_flatten(d)[0] + b_flat = tree_flatten(b)[0] + value = all((_b == _d).all() for (_b, _d) in zip(b_flat, d_flat)) + if value: + break + else: + raise RuntimeError("did not find match") + + data2 = self._get_data(datatype, size=2 * size + 2) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data2) + and isinstance(rb.storage, TensorStorage) + ) + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + rb.extend(data2) + + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + # Compiling on Windows requires "cl" compiler to be installed. + # + # Our Windows CI jobs do not have "cl", so skip this test. + @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") + @pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="torch.compile is not supported on Python 3.14+", + ) + @pytest.mark.parametrize("avoid_max_size", [False, True]) + def test_extend_sample_recompile( + self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size + ): + if rb_type is not ReplayBuffer: + pytest.skip( + "Only replay buffer of type 'ReplayBuffer' is currently supported." + ) + if sampler is not RandomSampler: + pytest.skip("Only sampler of type 'RandomSampler' is currently supported.") + if storage is not LazyTensorStorage: + pytest.skip( + "Only storage of type 'LazyTensorStorage' is currently supported." + ) + if writer is not RoundRobinWriter: + pytest.skip( + "Only writer of type 'RoundRobinWriter' is currently supported." 
+ ) + if datatype == "tensordict": + pytest.skip("'tensordict' datatype is not currently supported.") + + torch._dynamo.reset_code_caches() + + # Number of times to extend the replay buffer + num_extend = 10 + data_size = size + + # These two cases are separated because when the max storage size is + # reached, the code execution path changes, causing necessary + # recompiles. + if avoid_max_size: + storage_size = (num_extend + 1) * data_size + else: + storage_size = 2 * data_size + + rb = self._get_rb( + rb_type=rb_type, + sampler=sampler, + writer=writer, + storage=storage, + size=storage_size, + compilable=True, + ) + data = self._get_data(datatype, size=data_size) + + @torch.compile + def extend_and_sample(data): + rb.extend(data) + return rb.sample() + + # NOTE: The first three calls to 'extend' and 'sample' can currently + # cause recompilations, so avoid capturing those. + num_extend_before_capture = 3 + + for _ in range(num_extend_before_capture): + extend_and_sample(data) + + try: + torch._logging.set_logs(recompiles=True) + records = [] + capture_log_records(records, "torch._dynamo", "recompiles") + + for _ in range(num_extend - num_extend_before_capture): + extend_and_sample(data) + + finally: + torch._logging.set_logs() + + assert len(rb) == min((num_extend * data_size), storage_size) + assert len(records) == 0 + + def test_sample(self, rb_type, sampler, writer, storage, size, datatype): + if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: + pytest.skip( + "Distributed package support on Windows is a prototype feature and is subject to changes." 
+ ) + torch.manual_seed(0) + rb = self._get_rb( + rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size + ) + data = self._get_data(datatype, size=5) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb.storage, TensorStorage) + ) + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + rb.extend(data) + rb_sample = rb.sample() + # if not isinstance(new_data, (torch.Tensor, TensorDictBase)): + # new_data = new_data[0] + + if is_tensor_collection(data) or isinstance(data, torch.Tensor): + rb_sample_iter = rb_sample + else: + + def data_iter_func(maxval, data=data): + for t in range(maxval): + yield tree_map(lambda x, t=t: x[t], data) + + rb_sample_iter = data_iter_func(rb._batch_size, rb_sample) + + for single_sample in rb_sample_iter: + if is_tensor_collection(data) or isinstance(data, torch.Tensor): + data_iter = data + else: + data_iter = data_iter_func(5, data) + + for data_sample in data_iter: + if isinstance(data_sample, TensorDictBase): + keys = set(single_sample.keys()).intersection(data_sample.keys()) + data_sample = data_sample.exclude("index").select( + *keys, strict=False + ) + keys = set(single_sample.keys()).intersection(data_sample.keys()) + single_sample = single_sample.select(*keys, strict=False) + + if isinstance(data_sample, (torch.Tensor, TensorDictBase)): + value = data_sample == single_sample + value = value.all() + else: + d_flat = tree_flatten(single_sample)[0] + b_flat = tree_flatten(data_sample)[0] + value = all((_b == _d).all() for (_b, _d) in zip(b_flat, d_flat)) + + if value: + break + else: + raise RuntimeError("did not find match") + + def test_index(self, 
rb_type, sampler, writer, storage, size, datatype): + if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: + pytest.skip( + "Distributed package support on Windows is a prototype feature and is subject to changes." + ) + torch.manual_seed(0) + rb = self._get_rb( + rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size + ) + data = self._get_data(datatype, size=5) + cond = ( + OLD_TORCH + and writer is not TensorDictMaxValueWriter + and size < len(data) + and isinstance(rb.storage, TensorStorage) + ) + if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: + with pytest.raises( + RuntimeError, match="expects data to be a tensor collection" + ): + rb.extend(data) + return + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + rb.extend(data) + d1 = rb[2] + d2 = rb.storage[2] + if type(d1) is not type(d2): + d1 = d1[0] + if is_tensor_collection(data) or isinstance(data, torch.Tensor): + b = d1 == d2 + if not isinstance(b, bool): + b = b.all() + else: + d1_flat = tree_flatten(d1)[0] + d2_flat = tree_flatten(d2)[0] + b = all((_d1 == _d2).all() for (_d1, _d2) in zip(d1_flat, d2_flat)) + assert b + + def test_pickable(self, rb_type, sampler, writer, storage, size, datatype): + rb = self._get_rb( + rb_type=rb_type, + sampler=sampler, + writer=writer, + storage=storage, + size=size, + delayed_init=False, + ) + serialized = pickle.dumps(rb) + rb2 = pickle.loads(serialized) + assert rb.__dict__.keys() == rb2.__dict__.keys() + for key in sorted(rb.__dict__.keys()): + assert isinstance(rb.__dict__[key], type(rb2.__dict__[key])) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/rb/test_ensemble.py b/test/rb/test_ensemble.py new file mode 100644 index 00000000000..2da22a87128 --- 
/dev/null +++ b/test/rb/test_ensemble.py @@ -0,0 +1,673 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import argparse +import contextlib +import functools + +import numpy as np +import pytest +import torch +from _rb_common import _has_gym, _has_tv +from tensordict import ( + assert_allclose_td, + is_tensor_collection, + tensorclass, + TensorDict, + TensorDictBase, +) +from torch.utils._pytree import tree_flatten + +from torchrl.collectors import Collector +from torchrl.collectors.utils import split_trajectories +from torchrl.data import ( + FlatStorageCheckpointer, + MultiStep, + NestedStorageCheckpointer, + PrioritizedReplayBuffer, + ReplayBuffer, + ReplayBufferEnsemble, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.checkpointers import H5StorageCheckpointer +from torchrl.data.replay_buffers.samplers import ( + PrioritizedSampler, + PrioritizedSliceSampler, + RandomSampler, + SamplerEnsemble, + SamplerWithoutReplacement, + SliceSampler, + SliceSamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import ( + LazyMemmapStorage, + LazyTensorStorage, + ListStorage, + StorageEnsemble, + TensorStorage, +) +from torchrl.data.replay_buffers.utils import tree_iter +from torchrl.data.replay_buffers.writers import ( + RoundRobinWriter, + TensorDictMaxValueWriter, + TensorDictRoundRobinWriter, + WriterEnsemble, +) +from torchrl.envs import GymEnv, SerialEnv +from torchrl.envs.transforms.transforms import ( + Compose, + RenameTransform, + Resize, + StepCounter, + ToTensorImage, +) +from torchrl.modules import RandomPolicy +from torchrl.testing import CARTPOLE_VERSIONED, get_default_devices +from torchrl.testing.mocking_classes import CountingEnv + + +class TestEnsemble: + def _make_data(self, data_type): + if data_type is 
torch.Tensor: + return torch.ones(90) + if data_type is TensorDict: + return TensorDict( + { + "root": torch.arange(90), + "nested": TensorDict( + {"data": torch.arange(180).view(90, 2)}, batch_size=[90, 2] + ), + }, + batch_size=[90], + ) + raise NotImplementedError + + def _make_sampler(self, sampler_type): + if sampler_type is SamplerWithoutReplacement: + return SamplerWithoutReplacement(drop_last=True) + if sampler_type is RandomSampler: + return RandomSampler() + raise NotImplementedError + + def _make_storage(self, storage_type, data_type): + if storage_type is LazyMemmapStorage: + return LazyMemmapStorage(max_size=100) + if storage_type is TensorStorage: + if data_type is TensorDict: + return TensorStorage(TensorDict(batch_size=[100])) + elif data_type is torch.Tensor: + return TensorStorage(torch.zeros(100)) + else: + raise NotImplementedError + if storage_type is ListStorage: + return ListStorage(max_size=100) + raise NotImplementedError + + def _make_collate(self, storage_type): + if storage_type is ListStorage: + return torch.stack + else: + return self._robust_stack + + @staticmethod + def _robust_stack(tensor_list): + if not isinstance(tensor_list, (tuple, list)): + return tensor_list + if all(tensor.shape == tensor_list[0].shape for tensor in tensor_list[1:]): + return torch.stack(list(tensor_list)) + if is_tensor_collection(tensor_list[0]): + return torch.cat(list(tensor_list)) + return torch.nested.nested_tensor(list(tensor_list)) + + @pytest.mark.parametrize( + "storage_type", [LazyMemmapStorage, TensorStorage, ListStorage] + ) + @pytest.mark.parametrize("data_type", [torch.Tensor, TensorDict]) + @pytest.mark.parametrize("p", [[0.0, 0.9, 0.1], None]) + @pytest.mark.parametrize("num_buffer_sampled", [3, 16, None]) + @pytest.mark.parametrize("batch_size", [48, None]) + @pytest.mark.parametrize("sampler_type", [RandomSampler, SamplerWithoutReplacement]) + def test_rb( + self, storage_type, sampler_type, data_type, p, num_buffer_sampled, batch_size + 
): + storages = [self._make_storage(storage_type, data_type) for _ in range(3)] + collate_fn = self._make_collate(storage_type) + data = [self._make_data(data_type) for _ in range(3)] + samplers = [self._make_sampler(sampler_type) for _ in range(3)] + sub_batch_size = ( + batch_size // 3 + if issubclass(sampler_type, SamplerWithoutReplacement) + and batch_size is not None + else None + ) + error_catcher = ( + pytest.raises( + ValueError, + match="Samplers with drop_last=True must work with a predictable batch-size", + ) + if batch_size is None + and issubclass(sampler_type, SamplerWithoutReplacement) + else contextlib.nullcontext() + ) + rbs = None + with error_catcher: + rbs = (rb0, rb1, rb2) = [ + ReplayBuffer( + storage=storage, + sampler=sampler, + collate_fn=collate_fn, + batch_size=sub_batch_size, + ) + for (storage, sampler) in zip(storages, samplers) + ] + if rbs is None: + return + for datum, rb in zip(data, rbs): + rb.extend(datum) + rb = ReplayBufferEnsemble( + *rbs, p=p, num_buffer_sampled=num_buffer_sampled, batch_size=batch_size + ) + if batch_size is not None: + for batch_iter in rb: + assert isinstance(batch_iter, (torch.Tensor, TensorDictBase)) + break + batch_sample, info = rb.sample(return_info=True) + else: + batch_iter = None + batch_sample, info = rb.sample(48, return_info=True) + assert isinstance(batch_sample, (torch.Tensor, TensorDictBase)) + if isinstance(batch_sample, TensorDictBase): + assert "root" in batch_sample.keys() + assert "nested" in batch_sample.keys() + assert ("nested", "data") in batch_sample.keys(True) + if p is not None: + if batch_iter is not None: + buffer_ids = batch_iter.get(("index", "buffer_ids")) + assert isinstance(buffer_ids, torch.Tensor), batch_iter + assert 0 not in buffer_ids.unique().tolist() + + buffer_ids = batch_sample.get(("index", "buffer_ids")) + assert isinstance(buffer_ids, torch.Tensor), buffer_ids + assert 0 not in buffer_ids.unique().tolist() + if num_buffer_sampled is not None: + if batch_iter is 
not None: + assert batch_iter.shape == torch.Size( + [num_buffer_sampled, 48 // num_buffer_sampled] + ) + assert batch_sample.shape == torch.Size( + [num_buffer_sampled, 48 // num_buffer_sampled] + ) + else: + if batch_iter is not None: + assert batch_iter.shape == torch.Size([3, 16]) + assert batch_sample.shape == torch.Size([3, 16]) + + def _prepare_dual_replay_buffer(self, explicit=False): + torch.manual_seed(0) + rb0 = TensorDictReplayBuffer( + storage=LazyMemmapStorage(10), + transform=Compose( + ToTensorImage(in_keys=["pixels", ("next", "pixels")]), + Resize(32, in_keys=["pixels", ("next", "pixels")]), + RenameTransform([("some", "key")], ["renamed"]), + ), + ) + rb1 = TensorDictReplayBuffer( + storage=LazyMemmapStorage(10), + transform=Compose( + ToTensorImage(in_keys=["pixels", ("next", "pixels")]), + Resize(32, in_keys=["pixels", ("next", "pixels")]), + RenameTransform(["another_key"], ["renamed"]), + ), + ) + if explicit: + storages = StorageEnsemble( + rb0._storage, rb1._storage, transforms=[rb0._transform, rb1._transform] + ) + writers = WriterEnsemble(rb0._writer, rb1._writer) + samplers = SamplerEnsemble(rb0._sampler, rb1._sampler, p=[0.5, 0.5]) + collate_fns = [rb0._collate_fn, rb1._collate_fn] + rb = ReplayBufferEnsemble( + storages=storages, + samplers=samplers, + writers=writers, + collate_fns=collate_fns, + transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]), + ) + else: + rb = ReplayBufferEnsemble( + rb0, + rb1, + p=[0.5, 0.5], + transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]), + ) + data0 = TensorDict( + { + "pixels": torch.randint(255, (10, 244, 244, 3)), + ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)), + ("some", "key"): torch.randn(10), + }, + batch_size=[10], + ) + data1 = TensorDict( + { + "pixels": torch.randint(255, (10, 64, 64, 3)), + ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)), + "another_key": torch.randn(10), + }, + batch_size=[10], + ) + rb0.extend(data0) + rb1.extend(data1) + 
return rb, rb0, rb1 + + @pytest.mark.skipif(not _has_tv, reason="torchvision not found") + def test_rb_transform(self): + rb, rb0, rb1 = self._prepare_dual_replay_buffer() + for _ in range(2): + sample = rb.sample(10) + assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32]) + assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32]) + assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33]) + assert sample["renamed"].shape == torch.Size([2, 5]) + + @pytest.mark.skipif(not _has_tv, reason="torchvision not found") + @pytest.mark.parametrize("explicit", [False, True]) + def test_rb_indexing(self, explicit): + rb, rb0, rb1 = self._prepare_dual_replay_buffer(explicit=explicit) + if explicit: + # indirect checks + assert rb[0]._storage is rb0._storage + assert rb[1]._storage is rb1._storage + else: + assert rb[0] is rb0 + assert rb[1] is rb1 + assert rb[:] is rb + + torch.manual_seed(0) + sample1 = rb.sample(6) + # tensor + torch.manual_seed(0) + sample0 = rb[torch.tensor([0, 1])].sample(6) + assert_allclose_td(sample0, sample1) + # slice + torch.manual_seed(0) + sample0 = rb[:2].sample(6) + assert_allclose_td(sample0, sample1) + # np.ndarray + torch.manual_seed(0) + sample0 = rb[np.array([0, 1])].sample(6) + assert_allclose_td(sample0, sample1) + # list + torch.manual_seed(0) + sample0 = rb[[0, 1]].sample(6) + assert_allclose_td(sample0, sample1) + + # direct indexing + sample1 = rb[:, :3] + # tensor + sample0 = rb[torch.tensor([0, 1]), :3] + assert_allclose_td(sample0, sample1) + # slice + torch.manual_seed(0) + sample0 = rb[:2, :3] + assert_allclose_td(sample0, sample1) + # np.ndarray + torch.manual_seed(0) + sample0 = rb[np.array([0, 1]), :3] + assert_allclose_td(sample0, sample1) + # list + torch.manual_seed(0) + sample0 = rb[[0, 1], :3] + assert_allclose_td(sample0, sample1) + + # check indexing of components + assert isinstance(rb.storage[:], StorageEnsemble) + assert isinstance(rb.storage[:2], StorageEnsemble) + assert 
isinstance(rb.storage[torch.tensor([0, 1])], StorageEnsemble) + assert isinstance(rb.storage[np.array([0, 1])], StorageEnsemble) + assert isinstance(rb.storage[[0, 1]], StorageEnsemble) + assert isinstance(rb.storage[1], LazyMemmapStorage) + + rb.storage[:, :3] + rb.storage[:2, :3] + rb.storage[torch.tensor([0, 1]), :3] + rb.storage[np.array([0, 1]), :3] + rb.storage[[0, 1], :3] + + assert isinstance(rb.sampler[:], SamplerEnsemble) + assert isinstance(rb.sampler[:2], SamplerEnsemble) + assert isinstance(rb.sampler[torch.tensor([0, 1])], SamplerEnsemble) + assert isinstance(rb.sampler[np.array([0, 1])], SamplerEnsemble) + assert isinstance(rb.sampler[[0, 1]], SamplerEnsemble) + assert isinstance(rb.sampler[1], RandomSampler) + + assert isinstance(rb.writer[:], WriterEnsemble) + assert isinstance(rb.writer[:2], WriterEnsemble) + assert isinstance(rb.writer[torch.tensor([0, 1])], WriterEnsemble) + assert isinstance(rb.writer[np.array([0, 1])], WriterEnsemble) + assert isinstance(rb.writer[[0, 1]], WriterEnsemble) + assert isinstance(rb.writer[0], RoundRobinWriter) + + +def _rbtype(datatype): + if datatype in ("pytree", "tensorclass"): + return [ + (ReplayBuffer, RandomSampler), + (PrioritizedReplayBuffer, RandomSampler), + (ReplayBuffer, SamplerWithoutReplacement), + (PrioritizedReplayBuffer, SamplerWithoutReplacement), + ] + return [ + (ReplayBuffer, RandomSampler), + (ReplayBuffer, SamplerWithoutReplacement), + (PrioritizedReplayBuffer, None), + (TensorDictReplayBuffer, RandomSampler), + (TensorDictReplayBuffer, SamplerWithoutReplacement), + (TensorDictPrioritizedReplayBuffer, None), + ] + + +class TestRBMultidim: + @tensorclass + class MyData: + x: torch.Tensor + y: torch.Tensor + z: torch.Tensor + + def _make_data(self, datatype, datadim): + if datadim == 1: + shape = [12] + elif datadim == 2: + shape = [4, 3] + else: + raise NotImplementedError + if datatype == "pytree": + return { + "x": (torch.ones(*shape, 2), (torch.ones(*shape, 3))), + "y": [ + {"z": 
@pytest.mark.skipif(not _has_gym, reason="gym required for this test.")
@pytest.mark.parametrize(
    "writer_cls",
    [TensorDictMaxValueWriter, RoundRobinWriter, TensorDictRoundRobinWriter],
)
@pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize(
    "rbtype",
    [
        functools.partial(ReplayBuffer, batch_size=8),
        functools.partial(TensorDictReplayBuffer, batch_size=8),
    ],
)
@pytest.mark.parametrize(
    "sampler_cls",
    [
        functools.partial(SliceSampler, num_slices=2, strict_length=False),
        RandomSampler,
        functools.partial(
            SliceSamplerWithoutReplacement, num_slices=2, strict_length=False
        ),
        functools.partial(PrioritizedSampler, alpha=1.0, beta=1.0, max_capacity=10),
        functools.partial(
            PrioritizedSliceSampler,
            alpha=1.0,
            beta=1.0,
            max_capacity=10,
            num_slices=2,
            strict_length=False,
        ),
    ],
)
@pytest.mark.parametrize(
    "transform",
    [
        None,
        [
            lambda: split_trajectories,
            functools.partial(MultiStep, gamma=0.9, n_steps=3),
        ],
    ],
)
@pytest.mark.parametrize("env_device", get_default_devices())
def test_rb_multidim_collector(
    self, rbtype, storage_cls, writer_cls, sampler_cls, transform, env_device
):
    """Fill a 2-dim storage from a collector and sample from it.

    Checks, per parametrization:

    * ``TensorDictMaxValueWriter`` is rejected with a ``ValueError`` when the
      storage has ``ndim=2``;
    * transforms are only exercised together with a ``SliceSampler``
      (other combinations are skipped);
    * a ``TensorDictReplayBuffer`` with transforms fails to sample with a
      ``RuntimeError``;
    * otherwise samples are on CPU and the buffer content has a leading
      dimension of 2.
    """
    torch.manual_seed(0)
    env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()), device=env_device)
    env.set_seed(0)
    collector = Collector(
        env,
        RandomPolicy(env.action_spec),
        frames_per_batch=4,
        total_frames=16,
        device=env_device,
    )
    if writer_cls is TensorDictMaxValueWriter:
        with pytest.raises(
            ValueError,
            match="TensorDictMaxValueWriter is not compatible with storages with more than one dimension",
        ):
            # delayed_init=False forces construction now, so the error surfaces here.
            rbtype(
                storage=storage_cls(max_size=10, ndim=2),
                sampler=sampler_cls(),
                writer=writer_cls(),
                delayed_init=False,
            )
        return
    rb = rbtype(
        storage=storage_cls(max_size=10, ndim=2),
        sampler=sampler_cls(),
        writer=writer_cls(),
    )
    if not isinstance(rb.sampler, SliceSampler) and transform is not None:
        pytest.skip("no need to test this combination")
    if transform:
        for t in transform:
            # Each entry is a zero-argument factory returning the transform.
            rb.append_transform(t())
    for data in collector:
        assert data.device == torch.device(env_device)
        rb.extend(data)
        if isinstance(rb, TensorDictReplayBuffer) and transform is not None:
            # this should fail bc we can't set the indices after executing the transform.
            with pytest.raises(RuntimeError, match="Failed to set the metadata"):
                rb.sample()
            return
        s = rb.sample()
        assert s.device == torch.device("cpu")
        rbtot = rb[:]
        assert rbtot.shape[0] == 2
        assert len(rb) == rbtot.numel()
        if transform is not None:
            assert s.ndim == 2
policy=env.rand_step, + total_frames=200, + frames_per_batch=frames_per_batch, + ) + rb = ReplayBuffer(storage=storage_type(100)) + rb_test = ReplayBuffer(storage=storage_type(100)) + if torch.__version__ < "2.4.0.dev" and checkpointer in ( + H5StorageCheckpointer, + NestedStorageCheckpointer, + ): + with pytest.raises(ValueError, match="Unsupported torch version"): + checkpointer() + return + rb.storage.checkpointer = checkpointer() + rb_test.storage.checkpointer = checkpointer() + for data in collector: + rb.extend(data) + rb.dumps(tmpdir) + rb_test.loads(tmpdir) + assert_allclose_td(rb_test[:], rb[:]) + assert rb.writer._cursor == rb_test._writer._cursor + + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + @pytest.mark.parametrize("frames_per_batch", [22, 122]) + @pytest.mark.parametrize( + "checkpointer", + [FlatStorageCheckpointer, NestedStorageCheckpointer, H5StorageCheckpointer], + ) + def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch): + env = SerialEnv( + 3, + lambda: GymEnv(CARTPOLE_VERSIONED(), device=None).append_transform( + StepCounter() + ), + ) + env.set_seed(0) + torch.manual_seed(0) + collector = Collector( + env, + policy=env.rand_step, + total_frames=200, + frames_per_batch=frames_per_batch, + ) + rb = ReplayBuffer(storage=storage_type(100, ndim=2)) + rb_test = ReplayBuffer(storage=storage_type(100, ndim=2)) + if torch.__version__ < "2.4.0.dev" and checkpointer in ( + H5StorageCheckpointer, + NestedStorageCheckpointer, + ): + with pytest.raises(ValueError, match="Unsupported torch version"): + checkpointer() + return + rb.storage.checkpointer = checkpointer() + rb_test.storage.checkpointer = checkpointer() + for data in collector: + rb.extend(data) + assert rb.storage.max_size == 102 + if frames_per_batch > 100: + assert rb.storage._is_full + assert len(rb) == 102 + # Checks that when writing to the buffer with a batch greater than the total + # size, we get the last step written 
if __name__ == "__main__":
    # Anything argparse does not recognise is forwarded untouched to pytest.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])
@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_cuda_segment_tree_parity():
    """Compare the CUDA segment trees against the CPU ones on identical updates.

    Skipped when the compiled extension is missing or does not expose
    ``CudaSumSegmentTreeFp32``.
    """
    ext = pytest.importorskip("torchrl._torchrl")
    if not hasattr(ext, "CudaSumSegmentTreeFp32"):
        pytest.skip("TorchRL was not built with CUDA segment tree support")
    cuda_min_cls = ext.CudaMinSegmentTreeFp32
    cuda_sum_cls = ext.CudaSumSegmentTreeFp32
    cpu_min_cls = ext.MinSegmentTreeFp32
    cpu_sum_cls = ext.SumSegmentTreeFp32

    cuda = torch.device("cuda:0")
    capacity = 16
    positions = torch.tensor([0, 3, 4, 7, 12, 15], device=cuda)
    weights = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0, 32.0], device=cuda)

    # (CPU tree, CUDA tree) pairs: sum trees first, then min trees.
    cpu_sum = cpu_sum_cls(capacity)
    cpu_min = cpu_min_cls(capacity)
    tree_pairs = (
        (cpu_sum, cuda_sum_cls(capacity, cuda)),
        (cpu_min, cuda_min_cls(capacity, cuda)),
    )

    # Write the same values at the same positions in every tree.
    for cpu_tree, _ in tree_pairs:
        cpu_tree[positions.cpu()] = weights.cpu()
    for _, cuda_tree in tree_pairs:
        cuda_tree[positions] = weights

    # Range queries must agree between devices.
    range_start = torch.tensor([0, 3, 4, 7], device=cuda)
    range_stop = torch.tensor([16, 8, 13, 16], device=cuda)
    for cpu_tree, cuda_tree in tree_pairs:
        torch.testing.assert_close(
            cuda_tree.query(range_start, range_stop).cpu(),
            cpu_tree.query(range_start.cpu(), range_stop.cpu()),
        )

    # The lower-bound scan is only compared on the sum trees.
    mass = torch.tensor([0.5, 1.0, 2.9, 7.1, 30.0], device=cuda)
    _, cuda_sum = tree_pairs[0]
    torch.testing.assert_close(
        cuda_sum.scan_lower_bound(mass).cpu(),
        cpu_sum.scan_lower_bound(mass.cpu()),
    )
@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_tensordict_prioritized_replay_buffer_cuda_storage_cpu_sampler():
    """CUDA storage with ``sampler_device="cpu"``: check where each piece lives.

    The sampler is expected to report the CPU device while the sampled
    tensordict, its ``index`` and its ``priority_weight`` are on ``cuda:0``,
    both before and after a priority update.
    """
    cuda = torch.device("cuda:0")
    cpu = torch.device("cpu")

    buffer = TensorDictPrioritizedReplayBuffer(
        alpha=0.7,
        beta=0.5,
        storage=LazyTensorStorage(32, device=cuda),
        sampler_device="cpu",
        batch_size=8,
        priority_key="td_error",
    )
    content = {
        "obs": torch.arange(16, device=cuda).float().unsqueeze(-1),
        "td_error": torch.linspace(0.1, 1.0, 16, device=cuda),
    }
    buffer.extend(TensorDict(content, batch_size=[16], device=cuda))

    batch = buffer.sample()
    assert buffer._sampler.device == cpu
    assert batch.device == cuda
    assert batch["index"].device == cuda
    assert batch["priority_weight"].device == cuda

    # Push new priorities through and sample again.
    batch["td_error"] = torch.ones_like(batch["td_error"]) * 10
    buffer.update_tensordict_priority(batch)
    batch = buffer.sample()
    assert batch.device == cuda
    assert buffer._sampler.device == cpu
@pytest.mark.parametrize("stack", [False, True])
@pytest.mark.parametrize("datatype", ["tc", "tb"])
@pytest.mark.parametrize("reduction", ["min", "max", "median", "mean"])
def test_replay_buffer_trajectories(stack, reduction, datatype):
    """Store ``[3, 4]``-shaped trajectories in a prioritized buffer and sample them.

    ``datatype="tc"`` wraps the data in a tensorclass and uses a plain
    ``ReplayBuffer``; that case only checks that sampling with
    ``return_info=True`` works. ``datatype="tb"`` uses a
    ``TensorDictReplayBuffer`` and additionally updates priorities from a
    random ``td_error`` before checking the re-sampled weights and batch size.
    """
    traj_td = TensorDict(
        {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)},
        batch_size=[3, 4],
    )
    if datatype == "tc":
        c = make_tc(traj_td)
        rbcls = functools.partial(ReplayBuffer, storage=LazyTensorStorage(100))
        traj_td = c(**traj_td, batch_size=traj_td.batch_size)
        assert is_tensorclass(traj_td)
    elif datatype == "tb":
        rbcls = functools.partial(TensorDictReplayBuffer, priority_key="td_error")
    else:
        raise NotImplementedError

    if stack:
        traj_td = torch.stack(list(traj_td), 0)

    rb = rbcls(
        sampler=samplers.PrioritizedSampler(
            5,
            alpha=0.7,
            beta=0.9,
            reduction=reduction,
        ),
        batch_size=3,
    )
    rb.extend(traj_td)

    if datatype == "tc":
        _, info = rb.sample(return_info=True)
        # The lookup itself verifies that the sampler reports an "index" entry.
        assert info["index"] is not None
        # NOTE(review): this asserts on the stored data, not on the sample;
        # possibly the sample was meant - confirm before tightening.
        assert is_tensorclass(traj_td)
        # NOTE(review): the tensorclass case stops here, as it always did; the
        # former `rb.update_priority(index, ...)` branch further down could
        # never run and was dropped.
        return

    sampled_td = rb.sample()
    sampled_td.set("td_error", torch.rand(sampled_td.shape))
    rb.update_tensordict_priority(sampled_td)
    sampled_td = rb.sample(include_info=True)
    assert (sampled_td.get("priority_weight") > 0).all()
    assert sampled_td.batch_size == torch.Size([3, 4])
def test_replay_buffer_read_write_all_in_order():
    """``read_all_in_order`` / ``write_all`` must match full-slice reads and writes.

    Two identical buffers are filled with the same six items; one is updated
    through ``write_all`` and the other through ``rb[:] = ...``, and their
    contents are compared afterwards.
    """
    n_items = 6
    via_api = TensorDictReplayBuffer(storage=LazyTensorStorage(n_items))
    via_slice = TensorDictReplayBuffer(storage=LazyTensorStorage(n_items))
    seed_data = TensorDict(
        {"obs": torch.arange(n_items), "reward": torch.zeros(n_items)}, [n_items]
    )
    via_api.extend(seed_data)
    via_slice.extend(seed_data.clone())

    # Reading everything back must equal a full slice, in insertion order.
    snapshot = via_api.read_all_in_order()
    assert_allclose_td(snapshot, via_api[:])
    assert snapshot["obs"].tolist() == [0, 1, 2, 3, 4, 5]

    # Add a new entry and write it back through both code paths.
    snapshot["value_target"] = snapshot["obs"] + 1
    via_api.write_all(snapshot)
    via_slice[:] = snapshot.clone()

    refreshed = via_api.read_all_in_order()
    assert_allclose_td(refreshed, via_api[:])
    assert_allclose_td(refreshed, via_slice[:])
    assert refreshed["value_target"].tolist() == [1, 2, 3, 4, 5, 6]
def test_prb_rng(self):
    """Resetting the buffer's RNG state must reproduce a prioritized sample."""
    initial_state = torch.random.get_rng_state()
    n_items = 100
    buffer = ReplayBuffer(
        sampler=PrioritizedSampler(n_items, 1.0, 1.0),
        storage=LazyTensorStorage(n_items),
        generator=torch.Generator(),
    )
    buffer.extend(torch.arange(n_items))
    buffer.update_priority(
        index=torch.arange(n_items), priority=torch.arange(1, n_items + 1)
    )

    def sample_from_initial_state():
        # Rewind the buffer's generator, then draw a batch.
        buffer._rng.set_state(initial_state)
        return buffer.sample(32)

    first = sample_from_initial_state()
    replay = sample_from_initial_state()
    assert (first == replay).all()

    # Without rewinding, the next draw is expected to differ somewhere.
    follow_up = buffer.sample(32)
    assert (first != follow_up).any()
def _get_datum(self, rbtype):
    """Create a single random item suited to ``rbtype``.

    Tensor-based buffers get a one-element integer tensor; tensordict-based
    buffers get a batch-less ``TensorDict`` holding such a tensor under
    ``"a"``. Any other ``rbtype`` raises ``NotImplementedError``.
    """
    tensor_buffers = (ReplayBuffer, ReplayBufferRNG, PrioritizedReplayBuffer)
    tensordict_buffers = (
        TensorDictReplayBuffer,
        TensorDictReplayBufferRNG,
        TensorDictPrioritizedReplayBuffer,
    )
    if rbtype in tensor_buffers:
        return torch.randint(100, (1,))
    if rbtype in tensordict_buffers:
        return TensorDict({"a": torch.randint(100, (1,))}, [])
    raise NotImplementedError(rbtype)
def test_cursor_position2(self, rbtype, storage, size, prefetch):
    """Check the writer cursor after extending with five items.

    Three cases are covered: the buffer is larger than the batch (or has the
    default storage), smaller than the batch, or exactly the batch size. In
    the last case a second batch of ``size - 1`` items is written as well.
    """
    torch.manual_seed(0)
    n_first = 5
    buffer = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch)
    first_batch = self._get_data(rbtype, size=n_first)

    expects_warning = (
        OLD_TORCH
        and size < len(first_batch)
        and isinstance(buffer.storage, TensorStorage)
    )
    if expects_warning:
        guard = pytest.warns(
            UserWarning,
            match="A cursor of length superior to the storage capacity was provided",
        )
    else:
        guard = contextlib.nullcontext()
    with guard:
        buffer.extend(first_batch)

    if size > n_first or storage is None:
        # Fewer items than the storage can hold.
        assert buffer.writer._cursor == 5
        return
    if size < n_first:
        # More items than the storage can hold.
        assert buffer.writer._cursor == 5 - size
        return

    # Exactly as many items as the storage can hold.
    assert buffer.writer._cursor == 0
    second_batch = self._get_data(rbtype, size=size - 1)
    buffer.extend(second_batch)
    assert buffer.writer._cursor == size - 1
def test_empty(self, rbtype, storage, size, prefetch): + torch.manual_seed(0) + rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) + data = self._get_datum(rbtype) + for _ in range(2): + rb.add(data) + s = rb.sample(1)[0] + if isinstance(s, TensorDictBase): + s = s.select(*data.keys(True), strict=False) + data = data.select(*s.keys(True), strict=False) + assert (s == data).all() + assert list(s.keys(True, True)) + else: + assert (s == data).all() + rb.empty() + with pytest.raises( + RuntimeError, match="Cannot sample from an empty storage" + ): + rb.sample() + + def test_extend(self, rbtype, storage, size, prefetch): + torch.manual_seed(0) + rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) + data = self._get_data(rbtype, size=5) + cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + rb.extend(data) + length = len(rb) + for d in data[-length:]: + for b in rb.storage: + if isinstance(b, TensorDictBase): + keys = set(d.keys()).intersection(b.keys()) + b = b.exclude("index").select(*keys, strict=False) + keys = set(d.keys()).intersection(b.keys()) + d = d.select(*keys, strict=False) + + value = b == d + if isinstance(value, (torch.Tensor, TensorDictBase)): + value = value.all() + if value: + break + else: + raise RuntimeError("did not find match") + + def test_sample(self, rbtype, storage, size, prefetch): + torch.manual_seed(0) + rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) + data = self._get_data(rbtype, size=5) + cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + rb.extend(data) + new_data = rb.sample() + 
if not isinstance(new_data, (torch.Tensor, TensorDictBase)): + new_data = new_data[0] + + for d in new_data: + for b in data: + if isinstance(b, TensorDictBase): + keys = set(d.keys()).intersection(b.keys()) + b = b.exclude("index").select(*keys, strict=False) + keys = set(d.keys()).intersection(b.keys()) + d = d.select(*keys, strict=False) + + value = b == d + if isinstance(value, (torch.Tensor, TensorDictBase)): + value = value.all() + if value: + break + else: + raise RuntimeError("did not find matching value") + + def test_index(self, rbtype, storage, size, prefetch): + torch.manual_seed(0) + rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) + data = self._get_data(rbtype, size=5) + cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) + with ( + pytest.warns( + UserWarning, + match="A cursor of length superior to the storage capacity was provided", + ) + if cond + else contextlib.nullcontext() + ): + rb.extend(data) + d1 = rb[2] + d2 = rb.storage[2] + if type(d1) is not type(d2): + d1 = d1[0] + b = d1 == d2 + if not isinstance(b, bool): + b = b.all() + assert b + + def test_index_nonfull(self, rbtype, storage, size, prefetch): + # checks that indexing the buffer before it's full gives the accurate view of the data + rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) + data = self._get_data(rbtype, size=size - 1) + rb.extend(data) + assert len(rb[: size - 1]) == size - 1 + assert len(rb[size - 2 :]) == 1 + + +def test_replay_buffer_set_at_(): + """Tests that set_at_ writes through to storage in-place.""" + rb = ReplayBuffer( + storage=LazyTensorStorage(10), + batch_size=5, + ) + data = TensorDict({"a": torch.zeros(10), "b": torch.ones(10)}, batch_size=[10]) + rb.extend(data) + # Modify key "a" at indices [2, 5] + rb.set_at_("a", torch.tensor([99.0, 99.0]), torch.tensor([2, 5])) + assert rb["a"][2] == 99.0 + assert rb["a"][5] == 99.0 + assert rb["a"][0] == 0.0 # unchanged + assert rb["b"][2] 
== 1.0 # other key unchanged + + +def test_replay_buffer_set_(): + """Tests that set_ writes through to storage in-place.""" + rb = ReplayBuffer( + storage=LazyTensorStorage(10), + batch_size=5, + ) + data = TensorDict({"a": torch.zeros(10), "b": torch.ones(10)}, batch_size=[10]) + rb.extend(data) + rb.set_("a", torch.full((10,), 42.0)) + assert (rb["a"] == 42.0).all() + assert (rb["b"] == 1.0).all() # other key unchanged + + +def test_replay_buffer_update_(): + """Tests that update_ writes through to storage in-place.""" + rb = ReplayBuffer( + storage=LazyTensorStorage(10), + batch_size=5, + ) + data = TensorDict({"a": torch.zeros(10), "b": torch.ones(10)}, batch_size=[10]) + rb.extend(data) + update = TensorDict( + {"a": torch.full((10,), 7.0), "b": torch.full((10,), 8.0)}, + batch_size=[10], + ) + rb.update_(update) + assert (rb["a"] == 7.0).all() + assert (rb["b"] == 8.0).all() + + +def test_multi_loops(): + """Tests that one can iterate multiple times over a buffer without rep.""" + rb = ReplayBuffer( + batch_size=5, storage=ListStorage(10), sampler=SamplerWithoutReplacement() + ) + rb.extend(torch.zeros(10)) + for i, d in enumerate(rb): # noqa: B007 + assert (d == 0).all() + assert i == 1 + for i, d in enumerate(rb): # noqa: B007 + assert (d == 0).all() + assert i == 1 + + +def test_batch_errors(): + """Tests error messages related to batch-size""" + rb = ReplayBuffer( + storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=False) + ) + rb.extend(torch.zeros(10)) + rb.sample(3) # that works + with pytest.raises( + RuntimeError, + match="Cannot iterate over the replay buffer. 
Batch_size was not specified", + ): + for _ in rb: + pass + with pytest.raises(RuntimeError, match="batch_size not specified"): + rb.sample() + with pytest.raises(ValueError, match="Samplers with drop_last=True"): + ReplayBuffer( + storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=True) + ) + # that works + ReplayBuffer( + storage=ListStorage(10), + ) + rb = ReplayBuffer( + storage=ListStorage(10), + sampler=SamplerWithoutReplacement(drop_last=False), + batch_size=3, + ) + rb.extend(torch.zeros(10)) + for _ in rb: + pass + rb.sample() + + +@pytest.mark.skipif(not torchrl._utils.RL_WARNINGS, reason="RL_WARNINGS is not set") +def test_add_warning(): + if not rl_warnings(): + return + rb = ReplayBuffer(storage=ListStorage(10), batch_size=3) + with pytest.warns( + UserWarning, + match=r"Using `add\(\)` with a TensorDict that has batch_size", + ): + rb.add(TensorDict(batch_size=[1])) + + +@pytest.mark.parametrize("stack", [False, True]) +@pytest.mark.parametrize("reduction", ["min", "max", "mean", "median"]) +def test_rb_trajectories(stack, reduction): + traj_td = TensorDict( + {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)}, + batch_size=[3, 4], + ) + if stack: + traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0) + + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + priority_key="td_error", + storage=ListStorage(5), + batch_size=3, + ) + rb.extend(traj_td) + sampled_td = rb.sample() + sampled_td.set("td_error", torch.rand(3, 4)) + rb.update_tensordict_priority(sampled_td) + sampled_td = rb.sample(include_info=True) + assert (sampled_td.get("priority_weight") > 0).all() + assert sampled_td.batch_size == torch.Size([3, 4]) + + # set back the trajectory length + sampled_td_filtered = sampled_td.to_tensordict().exclude( + "priority_weight", "index", "td_error" + ) + sampled_td_filtered.batch_size = [3, 4] + + +def test_shared_storage_prioritized_sampler(): + n = 100 + + storage = LazyMemmapStorage(n) + writer 
= RoundRobinWriter() + sampler0 = RandomSampler() + sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1) + + rb0 = ReplayBuffer(storage=storage, writer=writer, sampler=sampler0, batch_size=10) + rb1 = ReplayBuffer(storage=storage, writer=writer, sampler=sampler1, batch_size=10) + + data = TensorDict({"a": torch.arange(50)}, [50]) + + # Extend rb0. rb1 should be aware of changes to storage. + rb0.extend(data) + + assert len(rb0) == 50 + assert len(storage) == 50 + assert len(rb1) == 50 + + rb0.sample() + rb1.sample() + + assert rb1._sampler._sum_tree.query(0, 10) == 10 + assert rb1._sampler._sum_tree.query(0, 50) == 50 + assert rb1._sampler._sum_tree.query(0, 70) == 50 + + +@pytest.mark.parametrize("size", [10, 15, 20]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_replay_buffer_iter(size, drop_last): + torch.manual_seed(0) + storage = ListStorage(size) + sampler = SamplerWithoutReplacement(drop_last=drop_last) + writer = RoundRobinWriter() + + rb = ReplayBuffer(storage=storage, sampler=sampler, writer=writer, batch_size=3) + rb.extend([torch.randint(100, (1,)) for _ in range(size)]) + + for i, _ in enumerate(rb): + if i == 20: + # guard against infinite loop if error is introduced + raise RuntimeError("Iteration didn't terminate") + + if drop_last: + assert i == size // 3 - 1 + else: + assert i == (size - 1) // 3 + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_rb_distributed.py b/test/rb/test_rb_distributed.py similarity index 62% rename from test/test_rb_distributed.py rename to test/rb/test_rb_distributed.py index c3ecf7897e8..62277039de5 100644 --- a/test/test_rb_distributed.py +++ b/test/rb/test_rb_distributed.py @@ -9,16 +9,22 @@ import sys import time +from functools import partial import pytest import torch import torch.distributed.rpc as rpc import torch.multiprocessing as mp +from 
_rb_common import _has_ray from tensordict import TensorDict from torchrl._utils import logger as torchrl_logger +from torchrl.data import RayReplayBuffer from torchrl.data.replay_buffers import RemoteTensorDictReplayBuffer -from torchrl.data.replay_buffers.samplers import RandomSampler -from torchrl.data.replay_buffers.storages import LazyMemmapStorage +from torchrl.data.replay_buffers.samplers import ( + RandomSampler, + SamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage from torchrl.data.replay_buffers.writers import RoundRobinWriter RETRY_COUNT = 3 @@ -139,6 +145,84 @@ def _sample_from_buffer(buffer, batch_size): ) +@pytest.mark.skipif(not _has_ray, reason="ray required for this test.") +class TestRayRB: + @pytest.fixture(autouse=True, scope="module") + def cleanup(self): + import ray + + ray.shutdown() + torchrl_logger.info("Initializing Ray.") + ray.init(num_cpus=1) + yield + torchrl_logger.info("Shutting down Ray.") + ray.shutdown() + + def test_ray_rb(self): + rb = RayReplayBuffer( + storage=partial(LazyTensorStorage, 100), ray_init_config={"num_cpus": 1} + ) + try: + rb.extend( + TensorDict( + {"x": torch.ones(100, 2), "y": torch.ones(100, 2)}, batch_size=100 + ) + ) + assert rb.write_count == 100 + assert len(rb) == 100 + assert rb.sample(2).shape == (2,) + finally: + rb.close() + + def test_ray_rb_iter(self): + rb = RayReplayBuffer( + storage=partial(LazyTensorStorage, 100), + ray_init_config={"num_cpus": 1}, + sampler=SamplerWithoutReplacement, + batch_size=25, + ) + try: + rb.extend( + TensorDict( + { + "x": torch.ones( + 100, + ), + "y": torch.ones( + 100, + ), + }, + batch_size=100, + ) + ) + for _ in range(2): + for d in rb: + torchrl_logger.info(f"d: {d}") + assert d is not None + assert d.shape == (25,) + finally: + rb.close() + + def test_ray_rb_serialization(self): + import ray + + class Worker: + def __init__(self, rb): + self.rb = rb + + def run(self): + 
self.rb.extend(TensorDict({"x": torch.ones(100)}, batch_size=100)) + + rb = RayReplayBuffer( + storage=partial(LazyTensorStorage, 100), ray_init_config={"num_cpus": 1} + ) + try: + remote_worker = ray.remote(Worker).remote(rb) + ray.get(remote_worker.run.remote()) + finally: + rb.close() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/rb/test_rb_transforms.py b/test/rb/test_rb_transforms.py new file mode 100644 index 00000000000..45b0c11ec9f --- /dev/null +++ b/test/rb/test_rb_transforms.py @@ -0,0 +1,516 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import argparse +from functools import partial +from unittest import mock + +import pytest +import torch +from _rb_common import _has_tv +from tensordict import TensorDict +from torch.utils._pytree import tree_map + +from torchrl.data import ReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import RandomSampler, SliceSampler +from torchrl.data.replay_buffers.storages import LazyMemmapStorage, LazyTensorStorage +from torchrl.envs.transforms import NextStateReconstructor +from torchrl.envs.transforms.transforms import ( + BinarizeReward, + CatFrames, + CatTensors, + CenterCrop, + DiscreteActionProjection, + DoubleToFloat, + FiniteTensorDictCheck, + FlattenObservation, + GrayScale, + gSDENoise, + ObservationNorm, + PinMemoryTransform, + Resize, + RewardClipping, + RewardScaling, + SqueezeTransform, + ToTensorImage, + UnsqueezeTransform, + VecNorm, +) + + +class TestTransforms: + def test_append_transform(self): + rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), batch_size=1) + td = TensorDict( + { + "observation": torch.randn(2, 4, 3, 16), + "observation2": torch.randn(2, 4, 3, 
16), + }, + [], + ) + rb.add(td) + flatten = CatTensors( + in_keys=["observation", "observation2"], out_key="observation_cat" + ) + + rb.append_transform(flatten) + + sampled = rb.sample() + assert sampled.get("observation_cat").shape[-1] == 32 + + def test_init_transform(self): + flatten = FlattenObservation( + -2, -1, in_keys=["observation"], out_keys=["flattened"] + ) + + rb = ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 + ) + + td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) + rb.add(td) + sampled = rb.sample() + assert sampled.get("flattened").shape[-1] == 48 + + def test_insert_transform(self): + flatten = FlattenObservation( + -2, -1, in_keys=["observation"], out_keys=["flattened"] + ) + rb = ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 + ) + td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []) + rb.add(td) + + rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"])) + + sampled = rb.sample() + assert sampled.get("flattened").shape[-1] == 48 + + with pytest.raises(ValueError): + rb.insert_transform(10, SqueezeTransform(-1, in_keys=["observation"])) + + transforms = [ + ToTensorImage, + pytest.param( + partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping" + ), + BinarizeReward, + pytest.param( + partial(Resize, w=2, h=2), + id="Resize", + marks=pytest.mark.skipif( + not _has_tv, reason="needs torchvision dependency" + ), + ), + pytest.param( + partial(CenterCrop, w=1), + id="CenterCrop", + marks=pytest.mark.skipif( + not _has_tv, reason="needs torchvision dependency" + ), + ), + pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), + pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), + GrayScale, + pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"), + pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), + 
pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), + DoubleToFloat, + VecNorm, + ] + + @pytest.mark.parametrize("transform", transforms) + def test_smoke_replay_buffer_transform(self, transform): + rb = TensorDictReplayBuffer( + transform=transform(in_keys=["observation"]), batch_size=1 + ) + + # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, []) + td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, []) + rb.add(td) + + m = mock.Mock() + m.side_effect = [td.unsqueeze(0)] + rb._transform.forward = m + # rb._transform.__len__ = lambda *args: 3 + rb.sample() + assert rb._transform.forward.called + + # was_called = [False] + # forward = rb._transform.forward + # def new_forward(*args, **kwargs): + # was_called[0] = True + # return forward(*args, **kwargs) + # rb._transform.forward = new_forward + # rb.sample() + # assert was_called[0] + + transforms2 = [ + partial(DiscreteActionProjection, num_actions_effective=1, max_actions=3), + FiniteTensorDictCheck, + gSDENoise, + PinMemoryTransform, + ] + + @pytest.mark.parametrize("transform", transforms2) + def test_smoke_replay_buffer_transform_no_inkeys(self, transform): + if transform == PinMemoryTransform and not torch.cuda.is_available(): + raise pytest.skip("No CUDA device detected, skipping PinMemory") + rb = ReplayBuffer( + collate_fn=lambda x: torch.stack(x, 0), transform=transform(), batch_size=1 + ) + + action = torch.zeros(3) + action[..., 0] = 1 + td = TensorDict( + {"observation": torch.randn(3, 3, 3, 16, 1), "action": action}, [] + ) + rb.add(td) + rb.sample() + + rb._transform = mock.MagicMock() + rb._transform.__len__ = lambda *args: 3 + rb.sample() + assert rb._transform.called + + @pytest.mark.parametrize("at_init", [True, False]) + def test_transform_nontensor(self, at_init): + def t(x): + return tree_map(lambda y: y * 0, x) + + if at_init: + rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=t) + else: + rb = 
ReplayBuffer(storage=LazyMemmapStorage(100)) + rb.append_transform(t) + data = { + "a": torch.randn(3), + "b": {"c": (torch.zeros(2), [torch.ones(1)])}, + 30: -torch.ones(()), + } + rb.add(data) + + def assert0(x): + assert (x == 0).all() + + s = rb.sample(10) + tree_map(assert0, s) + + def test_transform_inv(self): + rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4) + data = TensorDict({"a": torch.zeros(10)}, [10]) + + def t(data): + data += 1 + return data + + rb.append_transform(t, invert=True) + rb.extend(data) + assert (data == 1).all() + + +class TestNextStateReconstructor: + """Tests for :class:`~torchrl.envs.transforms.NextStateReconstructor`.""" + + _DEFAULT_TRAJ_KEY = ("collector", "traj_ids") + + @classmethod + def _make_data( + cls, + n_traj=3, + traj_len=4, + obs_dim=2, + traj_key: tuple | str | None = None, + ): + if traj_key is None: + traj_key = cls._DEFAULT_TRAJ_KEY + n = n_traj * traj_len + obs = torch.arange(n * obs_dim, dtype=torch.float32).reshape(n, obs_dim) + done = torch.zeros(n, 1, dtype=torch.bool) + done[traj_len - 1 :: traj_len] = True + traj_ids = torch.repeat_interleave(torch.arange(n_traj), traj_len) + return TensorDict( + { + "observation": obs, + ("next", "done"): done, + ("next", "reward"): torch.zeros(n, 1), + traj_key: traj_ids, + }, + batch_size=[n], + ) + + def test_slice_sampler_default(self): + """With ``SliceSampler`` + default ``traj_key``, slices mirror cleanly.""" + data = self._make_data(n_traj=3, traj_len=4) + rb = ReplayBuffer( + storage=LazyTensorStorage(12), + sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), + transform=NextStateReconstructor(), + batch_size=8, + ) + rb.extend(data) + sample = rb.sample() + assert sample.batch_size == torch.Size([8]) + next_obs = sample.get(("next", "observation")) + root_obs = sample.get("observation") + traj = sample.get(self._DEFAULT_TRAJ_KEY) + # Within each slice (4 entries), positions 0..2 mirror to 1..3 of the same traj. 
+ for slice_start in (0, 4): + assert (traj[slice_start : slice_start + 4] == traj[slice_start]).all() + for i in range(3): + torch.testing.assert_close( + next_obs[slice_start + i], root_obs[slice_start + i + 1] + ) + # Last position of each slice belongs to a different trajectory + # in the (i, i+1) pair (or has no i+1 at all) → NaN. + assert torch.isnan(next_obs[slice_start + 3]).all() + + def test_single_trajectory_full_batch(self): + """Whole trajectory as one batch: every transition reconstructed, last NaN.""" + n = 6 + td = TensorDict( + { + "observation": torch.arange(n, dtype=torch.float32).view(n, 1), + self._DEFAULT_TRAJ_KEY: torch.zeros(n, dtype=torch.long), + # No terminal in the middle; explicit final done for completeness. + ("next", "done"): torch.tensor([[False]] * (n - 1) + [[True]]), + }, + batch_size=[n], + ) + out = NextStateReconstructor()(td) + next_obs = out.get(("next", "observation")) + torch.testing.assert_close(next_obs[:-1], td.get("observation")[1:]) + assert torch.isnan(next_obs[-1]).all() + + def test_done_catches_slice_repetition(self): + """SliceSampler can place two slices of the same trajectory in one batch. + + Trajectory ids match across the splice; ``done`` at the slice end of the + first copy disambiguates. Without the done check, the first slice's + last position would silently borrow the *second slice's first frame* + (same trajectory, but not its temporal successor) and the user would + never know. 
+ """ + n = 8 # two identical trajectories of length 4, glued together + obs = torch.tensor([[0.0], [1.0], [2.0], [3.0]] * 2, dtype=torch.float32) + td = TensorDict( + { + "observation": obs, + self._DEFAULT_TRAJ_KEY: torch.tensor([0] * 8), # all same id + ("next", "done"): torch.tensor([[False], [False], [False], [True]] * 2), + }, + batch_size=[n], + ) + out = NextStateReconstructor()(td) + next_obs = out.get(("next", "observation")) + # Position 3: traj id matches position 4, but done[3]=True → NaN + assert torch.isnan(next_obs[3]).all() + # Positions 0..2 mirror to 1..3 + torch.testing.assert_close(next_obs[:3], obs[1:4]) + # Positions 4..6 mirror to 5..7 + torch.testing.assert_close(next_obs[4:7], obs[5:8]) + # Position 7: no i+1 → NaN + assert torch.isnan(next_obs[7]).all() + + def test_random_sampler_is_mostly_nan(self): + """Random sampling yields mismatched traj ids between neighbors → NaN. + + Documents the honest failure mode: when the user picks a sampler that + doesn't preserve trajectory adjacency, the transform refuses to invent + a next observation. + """ + data = self._make_data(n_traj=8, traj_len=4) # 32 entries + rb = ReplayBuffer( + storage=LazyTensorStorage(32), + sampler=RandomSampler(), + transform=NextStateReconstructor(), + batch_size=16, + ) + rb.extend(data) + torch.manual_seed(0) + sample = rb.sample() + next_obs = sample.get(("next", "observation")) + # With 8 trajectories random-sampled into a 16-batch, the chance that + # two adjacent picks share a trajectory id (≈ 1/8) is low. Assert that + # the *vast majority* of positions are NaN — both that the check is + # firing and that we aren't accidentally fabricating next obs. 
+ nan_frac = torch.isnan(next_obs).all(dim=-1).float().mean().item() + assert nan_frac > 0.7, f"expected mostly-NaN, got nan_frac={nan_frac:.2f}" + + def test_nested_keys(self): + n = 8 + td = TensorDict( + { + "agents": TensorDict( + { + "pos": torch.arange(n * 3, dtype=torch.float32).reshape(n, 3), + "vel": torch.arange(n * 2, dtype=torch.float32).reshape(n, 2), + }, + [n], + ), + ("next", "done"): torch.tensor([[False], [False], [False], [True]] * 2), + ("next", "reward"): torch.zeros(n, 1), + self._DEFAULT_TRAJ_KEY: torch.tensor([0] * 4 + [1] * 4), + }, + batch_size=[n], + ) + rb = ReplayBuffer( + storage=LazyTensorStorage(n), + sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), + transform=NextStateReconstructor( + keys=[("agents", "pos"), ("agents", "vel")], + ), + batch_size=4, + ) + rb.extend(td) + sample = rb.sample() + for k in (("agents", "pos"), ("agents", "vel")): + next_k = ("next", *k) + torch.testing.assert_close(sample.get(next_k)[:3], sample.get(k)[1:4]) + assert torch.isnan(sample.get(next_k)[3]).all() + + def test_explicit_fill_value(self): + data = self._make_data(n_traj=2, traj_len=4) + rb = ReplayBuffer( + storage=LazyTensorStorage(8), + sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), + transform=NextStateReconstructor(fill_value=-1.0), + batch_size=8, + ) + rb.extend(data) + sample = rb.sample() + next_obs = sample.get(("next", "observation")) + # The last position of each slice belongs to a different trajectory + # in (i, i+1), so it gets the fill value. 
+ for slice_start in (0, 4): + assert (next_obs[slice_start + 3] == -1.0).all() + + def test_overwrites_existing_next_obs(self): + """If ``("next", k)`` is already in storage, the transform overwrites it.""" + n = 8 + td = TensorDict( + { + "observation": torch.arange(n, dtype=torch.float32).view(n, 1), + ("next", "observation"): torch.full( + (n, 1), -999.0, dtype=torch.float32 + ), + ("next", "done"): torch.tensor([[False], [False], [False], [True]] * 2), + ("next", "reward"): torch.zeros(n, 1), + self._DEFAULT_TRAJ_KEY: torch.tensor([0] * 4 + [1] * 4), + }, + batch_size=[n], + ) + rb = ReplayBuffer( + storage=LazyTensorStorage(n), + sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), + transform=NextStateReconstructor(), + batch_size=8, + ) + rb.extend(td) + sample = rb.sample() + assert not (sample.get(("next", "observation")) == -999.0).any() + + def test_step_count_cross_check(self): + """``step_count_key`` adds a stricter "consecutive in time" check.""" + n = 4 + td = TensorDict( + { + "observation": torch.arange(n, dtype=torch.float32).view(n, 1), + self._DEFAULT_TRAJ_KEY: torch.zeros(n, dtype=torch.long), + ("next", "done"): torch.zeros(n, 1, dtype=torch.bool), + # Same traj id and no done, but step counts disagree at i=1 + # (jumps from 0 to 5, then 5 -> 6 -> 7). + ("collector", "step_count"): torch.tensor([0, 5, 6, 7]), + }, + batch_size=[n], + ) + t = NextStateReconstructor(step_count_key=("collector", "step_count")) + out = t(td) + next_obs = out.get(("next", "observation")) + # Position 0 → step_count[1] - step_count[0] = 5 ≠ 1, so NaN. + assert torch.isnan(next_obs[0]).all() + # Positions 1 and 2 are consecutive (5→6, 6→7) → reconstructed. + torch.testing.assert_close(next_obs[1], td.get("observation")[2]) + torch.testing.assert_close(next_obs[2], td.get("observation")[3]) + # Position 3 has no i+1 → NaN. 
+ assert torch.isnan(next_obs[3]).all() + + def test_strict_missing_traj_key_raises(self): + td = TensorDict( + {"observation": torch.arange(4, dtype=torch.float32).view(4, 1)}, + batch_size=[4], + ) + with pytest.raises(KeyError, match="trajectory key"): + NextStateReconstructor()(td) + + def test_strict_missing_done_key_raises(self): + td = TensorDict( + { + "observation": torch.arange(4, dtype=torch.float32).view(4, 1), + self._DEFAULT_TRAJ_KEY: torch.zeros(4, dtype=torch.long), + }, + batch_size=[4], + ) + with pytest.raises(KeyError, match="done key"): + NextStateReconstructor()(td) + + def test_strict_false_single_traj_fallback(self): + td = TensorDict( + {"observation": torch.arange(4, dtype=torch.float32).view(4, 1)}, + batch_size=[4], + ) + out = NextStateReconstructor(strict=False)(td) + next_obs = out.get(("next", "observation")) + torch.testing.assert_close(next_obs[:-1], td.get("observation")[1:]) + assert torch.isnan(next_obs[-1]).all() + + def test_traj_key_none_disables_check(self): + td = TensorDict( + { + "observation": torch.arange(4, dtype=torch.float32).view(4, 1), + # Different traj ids, but check is disabled → all-shift, no NaN + # except the last position. 
+ self._DEFAULT_TRAJ_KEY: torch.tensor([0, 1, 2, 3]), + }, + batch_size=[4], + ) + out = NextStateReconstructor(traj_key=None, done_key=None)(td) + next_obs = out.get(("next", "observation")) + torch.testing.assert_close(next_obs[:-1], td.get("observation")[1:]) + assert torch.isnan(next_obs[-1]).all() + + def test_int_obs_requires_explicit_fill_value(self): + td = TensorDict( + { + "observation": torch.arange(4, dtype=torch.int64).view(4, 1), + self._DEFAULT_TRAJ_KEY: torch.zeros(4, dtype=torch.long), + ("next", "done"): torch.zeros(4, 1, dtype=torch.bool), + }, + batch_size=[4], + ) + with pytest.raises(TypeError, match="non-floating dtype"): + NextStateReconstructor()(td) + # Explicit integer fill works + out = NextStateReconstructor(fill_value=-1)(td) + next_obs = out.get(("next", "observation")) + assert next_obs[-1].item() == -1 + + def test_bad_batch_dims_errors(self): + td = TensorDict( + { + "observation": torch.arange(8, dtype=torch.float32).view(2, 4, 1), + self._DEFAULT_TRAJ_KEY: torch.zeros(2, 4, dtype=torch.long), + }, + batch_size=[2, 4], + ) + with pytest.raises(ValueError, match="flat"): + NextStateReconstructor()(td) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/rb/test_samplers.py b/test/rb/test_samplers.py new file mode 100644 index 00000000000..166acac618f --- /dev/null +++ b/test/rb/test_samplers.py @@ -0,0 +1,1998 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import argparse +import os +import warnings + +import numpy as np +import pytest +import torch +from _rb_common import _has_snapshot, TORCH_VERSION +from packaging import version +from tensordict import TensorDict + +from torchrl._utils import _replace_last +from torchrl.collectors.utils import split_trajectories +from torchrl.data import TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer +from torchrl.data.replay_buffers import ReplayBuffer +from torchrl.data.replay_buffers.samplers import ( + PrioritizedSampler, + PrioritizedSliceSampler, + RandomSampler, + Sampler, + SamplerWithoutReplacement, + SliceSampler, + SliceSamplerWithoutReplacement, + StalenessAwareSampler, +) +from torchrl.data.replay_buffers.scheduler import ( + LinearScheduler, + SchedulerList, + StepScheduler, +) +from torchrl.data.replay_buffers.storages import ( + LazyMemmapStorage, + LazyTensorStorage, + ListStorage, +) +from torchrl.modules import GRUModule, set_recurrent_mode +from torchrl.testing import get_default_devices + + +@pytest.mark.parametrize("size", [10, 15, 20]) +@pytest.mark.parametrize("samples", [5, 9, 11, 14, 16]) +@pytest.mark.parametrize("drop_last", [True, False]) +def test_samplerwithoutrep(size, samples, drop_last): + torch.manual_seed(0) + storage = ListStorage(size) + storage.set(range(size), range(size)) + assert len(storage) == size + sampler = SamplerWithoutReplacement(drop_last=drop_last) + visited = False + for _ in range(10): + _n_left = ( + sampler._sample_list.numel() if sampler._sample_list is not None else size + ) + if samples > size and drop_last: + with pytest.raises( + ValueError, + match=r"The batch size .* is greater than the storage capacity", + ): + idx, _ = sampler.sample(storage, samples) + break + idx, _ = sampler.sample(storage, samples) + if drop_last or _n_left >= samples: + assert idx.numel() == samples + assert idx.unique().numel() == idx.numel() + else: + assert idx.numel() == _n_left + visited = 
True + if not drop_last and (size % samples > 0): + assert visited + else: + assert not visited + + +class TestSamplers: + @pytest.mark.parametrize( + "backend", ["torch"] + (["torchsnapshot"] if _has_snapshot else []) + ) + def test_sampler_without_rep_state_dict(self, backend): + os.environ["CKPT_BACKEND"] = backend + torch.manual_seed(0) + + n_samples = 3 + buffer_size = 100 + storage_in = LazyTensorStorage(buffer_size, device="cpu") + storage_out = LazyTensorStorage(buffer_size, device="cpu") + + replay_buffer = TensorDictReplayBuffer( + storage=storage_in, + sampler=SamplerWithoutReplacement(), + ) + # fill replay buffer with random data + transition = TensorDict( + { + "observation": torch.ones(1, 4), + "action": torch.ones(1, 2), + "reward": torch.ones(1, 1), + "dones": torch.ones(1, 1), + "next": {"observation": torch.ones(1, 4)}, + }, + batch_size=1, + ) + for _ in range(n_samples): + replay_buffer.extend(transition.clone()) + for _ in range(n_samples): + s = replay_buffer.sample(batch_size=1) + assert (s.exclude("index") == 1).all() + + replay_buffer.extend(torch.zeros_like(transition)) + + state_dict = replay_buffer.state_dict() + + new_replay_buffer = TensorDictReplayBuffer( + storage=storage_out, + batch_size=state_dict["_batch_size"], + sampler=SamplerWithoutReplacement(), + ) + + new_replay_buffer.load_state_dict(state_dict) + s = new_replay_buffer.sample(batch_size=1) + assert (s.exclude("index") == 0).all() + + def test_sampler_without_rep_dumps_loads(self, tmpdir): + d0 = tmpdir + "/save0" + d1 = tmpdir + "/save1" + d2 = tmpdir + "/dump" + replay_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(max_size=100, scratch_dir=d0, device="cpu"), + sampler=SamplerWithoutReplacement(drop_last=True), + batch_size=8, + ) + replay_buffer2 = TensorDictReplayBuffer( + storage=LazyMemmapStorage(max_size=100, scratch_dir=d1, device="cpu"), + sampler=SamplerWithoutReplacement(drop_last=True), + batch_size=8, + ) + td = TensorDict( + {"a": 
torch.arange(0, 27), ("b", "c"): torch.arange(1, 28)}, batch_size=[27] + ) + replay_buffer.extend(td) + for _ in replay_buffer: + break + replay_buffer.dumps(d2) + replay_buffer2.loads(d2) + assert ( + replay_buffer.sampler._sample_list == replay_buffer2.sampler._sample_list + ).all() + s = replay_buffer2.sample(3) + assert (s["a"] == s["b", "c"] - 1).all() + + @pytest.mark.parametrize("drop_last", [False, True]) + def test_sampler_without_replacement_cap_prefetch(self, drop_last): + torch.manual_seed(0) + data = TensorDict({"a": torch.arange(11)}, batch_size=[11]) + rb = ReplayBuffer( + storage=LazyTensorStorage(11), + sampler=SamplerWithoutReplacement(drop_last=drop_last), + batch_size=2, + prefetch=3, + ) + rb.extend(data) + + for _ in range(100): + s = set() + for i, d in enumerate(rb): + assert i <= (4 + int(not drop_last)), i + s = s.union(set(d["a"].tolist())) + assert i == (4 + int(not drop_last)), i + if drop_last: + assert s != set(range(11)) + else: + assert s == set(range(11)) + + @pytest.mark.parametrize( + "batch_size,num_slices,slice_len,prioritized", + [ + [100, 20, None, True], + [100, 20, None, False], + [120, 30, None, False], + [100, None, 5, False], + [120, None, 4, False], + [101, None, 101, False], + ], + ) + @pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")]) + @pytest.mark.parametrize("done_key", ["done", ("some", "done")]) + @pytest.mark.parametrize("match_episode", [True, False]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_slice_sampler( + self, + batch_size, + num_slices, + slice_len, + prioritized, + episode_key, + done_key, + match_episode, + device, + ): + torch.manual_seed(0) + storage = LazyMemmapStorage(100) + episode = torch.zeros(100, dtype=torch.int, device=device) + episode[:30] = 1 + episode[30:55] = 2 + episode[55:70] = 3 + episode[70:] = 4 + steps = torch.cat( + [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0 + ) + + done = torch.zeros(100, 1, 
dtype=torch.bool) + done[torch.tensor([29, 54, 69, 99])] = 1 + + data = TensorDict( + { + # we only use episode_key if we want the sampler to access it + episode_key if match_episode else "whatever_episode": episode, + "another_episode": episode, + "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5), + "act": torch.randn((20,)).expand(100, 20), + "steps": steps, + "count": torch.arange(100), + "other": torch.randn((20, 50)).expand(100, 20, 50), + done_key: done, + _replace_last(done_key, "terminated"): done, + }, + [100], + device=device, + ) + storage.set(range(100), data) + if slice_len is not None and slice_len > 15: + # we may have to sample trajs shorter than slice_len + strict_length = False + else: + strict_length = True + + if prioritized: + num_steps = data.shape[0] + sampler = PrioritizedSliceSampler( + max_capacity=num_steps, + alpha=0.7, + beta=0.9, + num_slices=num_slices, + traj_key=episode_key, + end_key=done_key, + slice_len=slice_len, + strict_length=strict_length, + truncated_key=_replace_last(done_key, "truncated"), + ) + index = torch.arange(0, num_steps, 1) + sampler.extend(index) + sampler.update_priority(index, 1) + else: + sampler = SliceSampler( + num_slices=num_slices, + traj_key=episode_key, + end_key=done_key, + slice_len=slice_len, + strict_length=strict_length, + truncated_key=_replace_last(done_key, "truncated"), + ) + if slice_len is not None: + num_slices = batch_size // slice_len + trajs_unique_id = set() + too_short = False + count_unique = set() + for _ in range(50): + index, info = sampler.sample(storage, batch_size=batch_size) + samples = storage._storage[index] + if strict_length: + # check that trajs are ok + samples = samples.view(num_slices, -1) + + unique_another_episode = ( + samples["another_episode"].unique(dim=1).squeeze() + ) + assert unique_another_episode.shape == torch.Size([num_slices]), ( + num_slices, + samples, + ) + assert ( + samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1] + ).all() + if 
isinstance(index, tuple): + index_numel = index[0].numel() + else: + index_numel = index.numel() + + too_short = too_short or index_numel < batch_size + trajs_unique_id = trajs_unique_id.union( + samples["another_episode"].view(-1).tolist() + ) + count_unique = count_unique.union(samples.get("count").view(-1).tolist()) + + truncated = info[_replace_last(done_key, "truncated")] + terminated = info[_replace_last(done_key, "terminated")] + assert (truncated | terminated).view(num_slices, -1)[:, -1].all() + assert ( + terminated + == samples[_replace_last(done_key, "terminated")].view_as(terminated) + ).all() + done = info[done_key] + assert done.view(num_slices, -1)[:, -1].all() + + if len(count_unique) == 100: + # all items have been sampled + break + else: + raise AssertionError( + f"Not all items can be sampled: {set(range(100)) - count_unique} are missing" + ) + + if strict_length: + assert not too_short + else: + assert too_short + + assert len(trajs_unique_id) == 4 + + @pytest.mark.parametrize("sampler", [SliceSampler, SliceSamplerWithoutReplacement]) + def test_slice_sampler_at_capacity(self, sampler): + torch.manual_seed(0) + + trajectory0 = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) + trajectory1 = torch.arange(2).repeat_interleave(6) + trajectory = torch.stack([trajectory0, trajectory1], 0) + + td = TensorDict( + {"trajectory": trajectory, "steps": torch.arange(12).expand(2, 12)}, [2, 12] + ) + + rb = ReplayBuffer( + sampler=sampler(traj_key="trajectory", num_slices=2), + storage=LazyTensorStorage(20, ndim=2), + batch_size=6, + ) + + rb.extend(td) + + for s in rb: + if (s["steps"] == 9).any(): + break + else: + raise AssertionError + + def test_slice_sampler_errors(self): + device = "cpu" + batch_size, num_slices = 100, 20 + + episode = torch.zeros(100, dtype=torch.int, device=device) + episode[:30] = 1 + episode[30:55] = 2 + episode[55:70] = 3 + episode[70:] = 4 + steps = torch.cat( + [torch.arange(30), torch.arange(25), torch.arange(15), 
torch.arange(30)], 0 + ) + + done = torch.zeros(100, 1, dtype=torch.bool) + done[torch.tensor([29, 54, 69])] = 1 + + data = TensorDict( + { + # we only use episode_key if we want the sampler to access it + "episode": episode, + "another_episode": episode, + "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5), + "act": torch.randn((20,)).expand(100, 20), + "steps": steps, + "other": torch.randn((20, 50)).expand(100, 20, 50), + ("next", "done"): done, + }, + [100], + device=device, + ) + + data_wrong_done = data.clone(False) + data_wrong_done.rename_key_("episode", "_") + data_wrong_done["next", "done"] = done.unsqueeze(1).expand(100, 5, 1) + storage = LazyMemmapStorage(100) + storage.set(range(100), data_wrong_done) + sampler = SliceSampler(num_slices=num_slices) + with pytest.raises( + RuntimeError, + match="Expected the end-of-trajectory signal to be 1-dimensional", + ): + index, _ = sampler.sample(storage, batch_size=batch_size) + + storage = ListStorage(100) + storage.set(range(100), data) + sampler = SliceSampler(num_slices=num_slices) + with pytest.raises( + RuntimeError, + match="Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories.", + ): + index, _ = sampler.sample(storage, batch_size=batch_size) + + @pytest.mark.parametrize("batch_size,num_slices", [[20, 4], [4, 2]]) + @pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")]) + @pytest.mark.parametrize("done_key", ["done", ("some", "done")]) + @pytest.mark.parametrize("match_episode", [True, False]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_slice_sampler_without_replacement( + self, + batch_size, + num_slices, + episode_key, + done_key, + match_episode, + device, + ): + torch.manual_seed(0) + storage = LazyMemmapStorage(100) + episode = torch.zeros(100, dtype=torch.int, device=device) + steps = [] + done = torch.zeros(100, 1, dtype=torch.bool) + for i in range(0, 100, 5): + episode[i : i + 5] = i // 5 + 
steps.append(torch.arange(5)) + done[i + 4] = 1 + steps = torch.cat(steps) + + data = TensorDict( + { + # we only use episode_key if we want the sampler to access it + episode_key if match_episode else "whatever_episode": episode, + "another_episode": episode, + "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5), + "act": torch.randn((20,)).expand(100, 20), + "steps": steps, + "other": torch.randn((20, 50)).expand(100, 20, 50), + done_key: done, + }, + [100], + device=device, + ) + storage.set(range(100), data) + sampler = SliceSamplerWithoutReplacement( + num_slices=num_slices, traj_key=episode_key, end_key=done_key + ) + trajs_unique_id = set() + for i in range(5): + index, info = sampler.sample(storage, batch_size=batch_size) + samples = storage._storage[index] + + # check that trajs are ok + samples = samples.view(num_slices, -1) + assert samples["another_episode"].unique( + dim=1 + ).squeeze().shape == torch.Size([num_slices]) + assert (samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1]).all() + cur_episodes = samples["another_episode"].view(-1).tolist() + for ep in cur_episodes: + assert ep not in trajs_unique_id, i + trajs_unique_id = trajs_unique_id.union( + cur_episodes, + ) + done_recon = info[("next", "truncated")] | info[("next", "terminated")] + assert done_recon.view(num_slices, -1)[:, -1].all() + done = info[("next", "done")] + assert done.view(num_slices, -1)[:, -1].all() + + def test_slice_sampler_left_right(self): + torch.manual_seed(0) + data = TensorDict( + {"obs": torch.arange(1, 11).repeat(10), "eps": torch.arange(100) // 10 + 1}, + [100], + ) + + for N in (2, 4): + rb = TensorDictReplayBuffer( + sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)), + batch_size=50, + storage=LazyMemmapStorage(100), + ) + rb.extend(data) + + for _ in range(10): + sample = rb.sample() + sample = split_trajectories(sample) + assert (sample["next", "truncated"].squeeze(-1).sum(-1) == 1).all() + assert ((sample["obs"] == 0).sum(-1) <= 
N).all(), sample["obs"] + assert ((sample["eps"] == 0).sum(-1) <= N).all() + for i in range(sample.shape[0]): + curr_eps = sample[i]["eps"] + curr_eps = curr_eps[curr_eps != 0] + assert curr_eps.unique().numel() == 1 + + def test_slice_sampler_left_right_ndim(self): + torch.manual_seed(0) + data = TensorDict( + {"obs": torch.arange(1, 11).repeat(12), "eps": torch.arange(120) // 10 + 1}, + [120], + ) + data = data.reshape(4, 30) + + for N in (2, 4): + rb = TensorDictReplayBuffer( + sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)), + batch_size=50, + storage=LazyMemmapStorage(100, ndim=2), + ) + rb.extend(data) + + for _ in range(10): + sample = rb.sample() + sample = split_trajectories(sample) + assert (sample["next", "truncated"].squeeze(-1).sum(-1) <= 1).all() + assert ((sample["obs"] == 0).sum(-1) <= N).all(), sample["obs"] + assert ((sample["eps"] == 0).sum(-1) <= N).all() + for i in range(sample.shape[0]): + curr_eps = sample[i]["eps"] + curr_eps = curr_eps[curr_eps != 0] + assert curr_eps.unique().numel() == 1 + + def test_slice_sampler_strictlength(self): + torch.manual_seed(0) + + data = TensorDict( + { + "traj": torch.cat( + [ + torch.ones(2, dtype=torch.int), + torch.zeros(10, dtype=torch.int), + ], + dim=0, + ), + "x": torch.arange(12), + }, + [12], + ) + + buffer = ReplayBuffer( + storage=LazyTensorStorage(12), + sampler=SliceSampler(num_slices=2, strict_length=True, traj_key="traj"), + batch_size=8, + ) + buffer.extend(data) + + for _ in range(50): + sample = buffer.sample() + assert sample.shape == torch.Size([8]) + assert (sample["traj"] == 0).all() + + buffer = ReplayBuffer( + storage=LazyTensorStorage(12), + sampler=SliceSampler(num_slices=2, strict_length=False, traj_key="traj"), + batch_size=8, + ) + buffer.extend(data) + + for _ in range(50): + sample = buffer.sample() + if sample.shape == torch.Size([6]): + assert (sample["traj"] != 0).any() + else: + assert len(sample["traj"].unique()) == 1 + + # 
------------------------------------------------------------------ + # traj_key auto-detection tests + # ------------------------------------------------------------------ + + def test_slice_sampler_auto_traj_key_collector_ids(self): + """Auto-detection should prefer ("collector", "traj_ids") over "episode".""" + torch.manual_seed(0) + # Build data with both keys present; sampler should pick collector key + # and warn that this changes the pre-0.13 default. + traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) + data = TensorDict( + { + ("collector", "traj_ids"): traj_ids, + "episode": torch.zeros(8, dtype=torch.int), # wrong, should be ignored + "obs": torch.arange(8).float(), + }, + batch_size=[8], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(8), + sampler=SliceSampler(num_slices=2), + batch_size=6, + ) + rb.extend(data) + # Force resolution — with both keys present we must see a FutureWarning. + with pytest.warns(FutureWarning, match="auto-detected"): + sample = rb.sample() + assert rb.sampler.traj_key == ("collector", "traj_ids") + assert rb.sampler._fetch_traj is True + assert rb.sampler._traj_key_auto is False + # Each slice should come from a single trajectory + sample_reshaped = sample.reshape(2, 3) + for i in range(2): + traj_vals = sample_reshaped[i][("collector", "traj_ids")] + assert traj_vals.unique().numel() == 1 + + def test_slice_sampler_auto_traj_key_no_warning_single_key(self): + """No FutureWarning when only one of the two candidate keys is present.""" + torch.manual_seed(0) + traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) + data = TensorDict( + { + ("collector", "traj_ids"): traj_ids, + "obs": torch.arange(8).float(), + }, + batch_size=[8], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(8), + sampler=SliceSampler(num_slices=2), + batch_size=6, + ) + rb.extend(data) + with warnings.catch_warnings(): + warnings.simplefilter("error", FutureWarning) + rb.sample() + assert 
rb.sampler.traj_key == ("collector", "traj_ids") + + def test_slice_sampler_auto_traj_key_episode(self): + """Auto-detection falls back to 'episode' when collector key is absent.""" + torch.manual_seed(0) + traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) + data = TensorDict( + { + "episode": traj_ids, + "obs": torch.arange(8).float(), + }, + batch_size=[8], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(8), + sampler=SliceSampler(num_slices=2), + batch_size=6, + ) + rb.extend(data) + rb.sample() + assert rb.sampler.traj_key == "episode" + assert rb.sampler._fetch_traj is True + + def test_slice_sampler_auto_traj_key_fallback_to_done(self): + """Auto-detection falls back to end_key reconstruction when no traj key.""" + torch.manual_seed(0) + done = torch.zeros(9, 1, dtype=torch.bool) + done[[2, 5, 8]] = True + data = TensorDict( + { + ("next", "done"): done, + ("next", "truncated"): done, + ("next", "terminated"): done, + "obs": torch.arange(9).float(), + }, + batch_size=[9], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(9), + sampler=SliceSampler(num_slices=3), + batch_size=9, + ) + rb.extend(data) + rb.sample() + assert rb.sampler._fetch_traj is False + + def test_slice_sampler_explicit_traj_key_no_auto(self): + """Explicit traj_key should bypass auto-detection entirely.""" + torch.manual_seed(0) + traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) + data = TensorDict( + { + "my_traj": traj_ids, + ("collector", "traj_ids"): torch.zeros(8, dtype=torch.int), + "obs": torch.arange(8).float(), + }, + batch_size=[8], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(8), + sampler=SliceSampler(num_slices=2, traj_key="my_traj"), + batch_size=6, + ) + rb.extend(data) + rb.sample() + assert rb.sampler.traj_key == "my_traj" + assert getattr(rb.sampler, "_traj_key_auto", False) is False + + # ------------------------------------------------------------------ + # mask / lengths tests 
(strict_length=False) + # ------------------------------------------------------------------ + + def _make_rb_with_short_trajs(self, traj_lengths, slice_len, num_slices): + """Helper: build a TensorDictReplayBuffer with trajectories of given lengths.""" + parts = [] + for t_id, length in enumerate(traj_lengths): + is_init = torch.zeros(length, 1, dtype=torch.bool) + is_init[0] = True # episode reset at the first step of each trajectory + parts.append( + TensorDict( + { + "traj": torch.full((length,), t_id, dtype=torch.int), + "obs": torch.arange(length).float(), + "is_init": is_init, + }, + batch_size=[length], + ) + ) + data = torch.cat(parts) + total = sum(traj_lengths) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(total), + sampler=SliceSampler( + slice_len=slice_len, + traj_key="traj", + strict_length=False, + pad_output=True, + ), + batch_size=num_slices * slice_len, + ) + rb.extend(data) + return rb + + def test_slice_sampler_mask_present_when_short_trajs(self): + """mask appears in output when short trajectories force padding.""" + torch.manual_seed(0) + rb = self._make_rb_with_short_trajs( + traj_lengths=[3, 6, 2], slice_len=5, num_slices=3 + ) + sample = rb.sample() + assert ("collector", "mask") in sample.keys(True) + + def test_slice_sampler_mask_shape_dtype(self): + """mask is bool with shape [B*T] (matches batch shape, no trailing 1).""" + torch.manual_seed(0) + B, T = 4, 6 + rb = self._make_rb_with_short_trajs( + traj_lengths=[2, 5, 3, 4], slice_len=T, num_slices=B + ) + sample = rb.sample() + mask = sample[("collector", "mask")] + assert mask.shape == torch.Size([B * T]) + assert mask.dtype == torch.bool + # mask must match the leading batch dim so trainer code can index + # batch[batch.get(("collector", "mask"))] without broadcasting tricks. 
+ assert mask.shape[0] == sample.batch_size[0] + + def test_slice_sampler_mask_correctness(self): + """mask rows are contiguous: True prefix followed by False suffix.""" + torch.manual_seed(0) + B, T = 6, 8 + rb = self._make_rb_with_short_trajs( + traj_lengths=[3, 8, 2, 7, 1, 5], slice_len=T, num_slices=B + ) + for _ in range(20): + sample = rb.sample() + mask = sample[("collector", "mask")].reshape(B, T) + # derive lengths from the mask itself + lengths = mask.sum(-1) # [B] + for i in range(B): + length = lengths[i].item() + assert length >= 1 + assert length <= T + assert mask[ + i, :length + ].all(), f"slice {i}: first {length} steps should be True" + assert not mask[ + i, length: + ].any(), f"slice {i}: steps after {length} should be False" + + def test_slice_sampler_mask_padded_obs_is_valid(self): + """Padded positions repeat the last real index — obs values must be finite.""" + torch.manual_seed(0) + rb = self._make_rb_with_short_trajs( + traj_lengths=[2, 6, 3], slice_len=5, num_slices=3 + ) + sample = rb.sample() + assert torch.isfinite(sample["obs"]).all() + + def test_slice_sampler_strict_length_no_mask(self): + """With pad_output=False, no mask is emitted regardless of strict_length.""" + torch.manual_seed(0) + data = TensorDict( + { + "traj": torch.cat( + [torch.zeros(6, dtype=torch.int), torch.ones(6, dtype=torch.int)] + ), + "obs": torch.arange(12).float(), + }, + batch_size=[12], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(12), + sampler=SliceSampler( + slice_len=4, traj_key="traj", strict_length=True, pad_output=False + ), + batch_size=8, + ) + rb.extend(data) + sample = rb.sample() + assert ("collector", "mask") not in sample.keys(True) + + def test_slice_sampler_pad_output_strict_length_raises(self): + """pad_output=True + strict_length=True is rejected at construction.""" + with pytest.raises(ValueError, match="pad_output=True is incompatible"): + SliceSampler( + slice_len=4, traj_key="traj", strict_length=True, pad_output=True 
+ ) + + def test_slice_sampler_pad_output_marks_slice_starts(self): + """pad_output=True writes is_init=True at every slice start. + + This is what lets a recurrent policy in `set_recurrent_mode("recurrent")` + consume the flat [B*T] sample directly: the RNN splits on `is_init` + and uses each slice's stored hidden state at position 0. + """ + torch.manual_seed(0) + B, T = 4, 8 + rb = self._make_rb_with_short_trajs( + traj_lengths=[3, 8, 2, 7, 1, 5], slice_len=T, num_slices=B + ) + for _ in range(10): + sample = rb.sample() + is_init = sample["is_init"].reshape(B, T) + # Position 0 of every slice must be True regardless of where the + # slice landed within its source trajectory. + assert is_init[:, 0].all(), "every slice must start with is_init=True" + + def test_slice_sampler_marks_slice_starts_no_pad(self): + """Default (no pad_output) flow: is_init=True at every slice start. + + This is the workflow most users will hit: trajectories are written + end-to-end into the buffer, the sampler returns concatenated + variable-length slices, and the RNN splits on `is_init`. No mask, no + padding involved. + """ + torch.manual_seed(0) + traj_lengths = [3, 8, 2, 7, 5] + parts = [] + for t_id, length in enumerate(traj_lengths): + init = torch.zeros(length, 1, dtype=torch.bool) + init[0] = True + parts.append( + TensorDict( + { + "traj": torch.full((length,), t_id, dtype=torch.int), + "is_init": init, + }, + batch_size=[length], + ) + ) + data = torch.cat(parts) + B = 4 + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(data.numel()), + sampler=SliceSampler(num_slices=B, traj_key="traj", strict_length=False), + batch_size=B * 6, + ) + rb.extend(data) + for _ in range(10): + sample = rb.sample() + assert "is_init" in sample.keys(True) + is_init = sample["is_init"].squeeze(-1) + trunc = sample[("next", "truncated")].squeeze(-1) + # Slice 0 always starts at position 0. 
+ assert is_init[0].item(), "first slice must start with is_init=True" + # Every position right after a truncated flag must be is_init=True + # (next slice's start). The last truncated marks the end of the + # batch; nothing follows it. + slice_ends = trunc.nonzero().squeeze(-1).tolist() + for end in slice_ends[:-1]: + assert is_init[ + end + 1 + ].item(), f"slice starting at index {end + 1} missing is_init=True" + + def test_slice_sampler_pad_output_no_is_init_no_marker(self): + """Without is_init in the storage we don't introduce one out of thin air.""" + torch.manual_seed(0) + # Build a buffer *without* is_init. + data = TensorDict( + { + "traj": torch.cat( + [ + torch.full((3,), 0, dtype=torch.int), + torch.full((6,), 1, dtype=torch.int), + torch.full((2,), 2, dtype=torch.int), + ] + ), + "obs": torch.arange(11).float(), + }, + batch_size=[11], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(11), + sampler=SliceSampler( + slice_len=5, traj_key="traj", strict_length=False, pad_output=True + ), + batch_size=15, + ) + rb.extend(data) + sample = rb.sample() + # is_init must not appear if it wasn't in the storage + assert "is_init" not in sample.keys(True) + + def test_slice_sampler_flat_sample_matches_batched_recurrent_module(self): + """A flat padded sample must match an explicit [B, T] recurrent call.""" + torch.manual_seed(0) + B, T = 4, 5 + input_size, hidden_size = 3, 7 + parts = [] + for traj_id, length in enumerate([11, 9, 10, 12]): + is_init = torch.zeros(length, 1, dtype=torch.bool) + is_init[0] = True + parts.append( + TensorDict( + { + "traj": torch.full((length,), traj_id, dtype=torch.int), + "embed": torch.randn(length, input_size), + "recurrent_state": torch.randn(length, 1, hidden_size), + "is_init": is_init, + }, + batch_size=[length], + ) + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(sum(part.shape[0] for part in parts)), + sampler=SliceSampler( + slice_len=T, + traj_key="traj", + strict_length=False, + 
pad_output=True, + ), + batch_size=B * T, + ) + rb.extend(torch.cat(parts)) + sample = rb.sample() + assert sample["is_init"].reshape(B, T)[:, 0].all() + + gru = GRUModule( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + in_keys=["embed", "recurrent_state", "is_init"], + out_keys=["features", ("next", "recurrent_state")], + ) + with set_recurrent_mode("recurrent"): + flat_out = gru(sample.clone()) + batched_out = gru(sample.clone().reshape(B, T)) + + torch.testing.assert_close( + flat_out["features"].reshape(B, T, hidden_size), batched_out["features"] + ) + torch.testing.assert_close( + flat_out[("next", "recurrent_state")].reshape(B, T, 1, hidden_size), + batched_out[("next", "recurrent_state")], + ) + + def test_slice_sampler_mask_all_long_trajs_no_mask(self): + """When all trajs >= slice_len, pad_output=True still emits no mask (nothing to pad).""" + torch.manual_seed(0) + data = TensorDict( + { + "traj": torch.cat( + [torch.zeros(8, dtype=torch.int), torch.ones(8, dtype=torch.int)] + ), + "obs": torch.arange(16).float(), + }, + batch_size=[16], + ) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(16), + sampler=SliceSampler( + slice_len=4, traj_key="traj", strict_length=False, pad_output=True + ), + batch_size=8, + ) + rb.extend(data) + sample = rb.sample() + # No short trajectories → no padding needed → no mask emitted + assert ("collector", "mask") not in sample.keys(True) + + def test_slice_sampler_truncated_marks_last_real_step(self): + """truncated flag should sit at the last *real* timestep, not the padded end.""" + torch.manual_seed(0) + B, T = 4, 6 + rb = self._make_rb_with_short_trajs( + traj_lengths=[2, 5, 3, 4], slice_len=T, num_slices=B + ) + sample = rb.sample() + mask = sample[("collector", "mask")].reshape(B, T) + lengths = mask.sum(-1) # [B] — derived from mask + trunc = sample[("next", "truncated")].reshape(B, T) + for i in range(B): + length = lengths[i].item() + # truncated should be True exactly at position 
length-1 + assert trunc[ + i, length - 1 + ].item(), f"slice {i}: truncated missing at last real step" + # no truncated flag in padded region + if length < T: + assert not trunc[ + i, length: + ].any(), f"slice {i}: spurious truncated in padding" + + @pytest.mark.parametrize("ndim", [1, 2]) + @pytest.mark.parametrize("strict_length", [True, False]) + @pytest.mark.parametrize("circ", [False, True]) + @pytest.mark.parametrize("at_capacity", [False, True]) + def test_slice_sampler_prioritized(self, ndim, strict_length, circ, at_capacity): + torch.manual_seed(0) + out = [] + for t in range(5): + length = (t + 1) * 5 + done = torch.zeros(length, 1, dtype=torch.bool) + done[-1] = 1 + priority = 10 if t == 0 else 1 + traj = TensorDict( + { + "traj": torch.full((length,), t), + "step_count": torch.arange(length), + "done": done, + "priority": torch.full((length,), priority), + }, + batch_size=length, + ) + out.append(traj) + data = torch.cat(out) + if ndim == 2: + data = torch.stack([data, data]) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(data.numel() - at_capacity, ndim=ndim), + sampler=PrioritizedSliceSampler( + max_capacity=data.numel() - at_capacity, + alpha=1.0, + beta=1.0, + end_key="done", + slice_len=10, + strict_length=strict_length, + cache_values=True, + ), + batch_size=50, + ) + if not circ: + # Simplest case: the buffer is full but no overlap + index = rb.extend(data, update_priority=False) + else: + # The buffer is 2/3 -> 1/3 overlapping + rb.extend(data[..., : data.shape[-1] // 3], update_priority=False) + index = rb.extend(data, update_priority=False) + rb.update_priority(index, data["priority"]) + samples = [] + found_shorter_batch = False + for _ in range(100): + samples.append(rb.sample()) + if samples[-1].numel() < 50: + found_shorter_batch = True + samples = torch.cat(samples) + if strict_length: + assert not found_shorter_batch + else: + assert found_shorter_batch + # the first trajectory has a very high priority, but should only 
appear + # if strict_length=False. + if strict_length: + assert (samples["traj"] != 0).all(), samples["traj"].unique() + else: + assert (samples["traj"] == 0).any() + # Check that all samples of the first traj contain all elements (since it's too short to fulfill 10 elts) + sc = samples[samples["traj"] == 0]["step_count"] + assert (sc == 1).sum() == (sc == 2).sum() + assert (sc == 1).sum() == (sc == 4).sum() + assert rb.sampler._cache + rb.extend(data, update_priority=False) + assert not rb.sampler._cache + + @pytest.mark.parametrize("ndim", [1, 2]) + @pytest.mark.parametrize("strict_length", [True, False]) + @pytest.mark.parametrize("circ", [False, True]) + @pytest.mark.parametrize( + "span", [False, [False, False], [False, True], 3, [False, 3]] + ) + def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span): + torch.manual_seed(0) + out = [] + # 5 trajs of length 3, 6, 9, 12 and 15 + for t in range(5): + length = (t + 1) * 3 + done = torch.zeros(length, 1, dtype=torch.bool) + done[-1] = 1 + priority = 1 + traj = TensorDict( + { + "traj": torch.full((length,), t), + "step_count": torch.arange(length), + "done": done, + "priority": torch.full((length,), priority), + }, + batch_size=length, + ) + out.append(traj) + data = torch.cat(out) + if ndim == 2: + data = torch.stack([data, data]) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(data.numel(), ndim=ndim), + sampler=PrioritizedSliceSampler( + max_capacity=data.numel(), + alpha=1.0, + beta=1.0, + end_key="done", + slice_len=5, + strict_length=strict_length, + cache_values=True, + span=span, + ), + batch_size=5, + ) + if not circ: + # Simplest case: the buffer is full but no overlap + index = rb.extend(data) + else: + # The buffer is 2/3 -> 1/3 overlapping + rb.extend(data[..., : data.shape[-1] // 3]) + index = rb.extend(data) + rb.update_priority(index, data["priority"]) + found_traj_0 = False + found_traj_4_truncated_right = False + for i, s in enumerate(rb): + t = 
s["traj"].unique().tolist() + assert len(t) == 1 + t = t[0] + if t == 0: + found_traj_0 = True + if t == 4 and s.numel() < 5: + if s["step_count"][0] > 10: + found_traj_4_truncated_right = True + if s["step_count"][0] == 0: + pass + if i == 1000: + break + assert not rb.sampler.span[0] + # if rb.sampler.span[0]: + # assert found_traj_4_truncated_left + if rb.sampler.span[1]: + assert found_traj_4_truncated_right + else: + assert not found_traj_4_truncated_right + if strict_length and not rb.sampler.span[1]: + assert not found_traj_0 + else: + assert found_traj_0 + + @pytest.mark.parametrize("max_priority_within_buffer", [True, False]) + def test_prb_update_max_priority(self, max_priority_within_buffer): + rb = ReplayBuffer( + storage=LazyTensorStorage(11), + sampler=PrioritizedSampler( + max_capacity=11, + alpha=1.0, + beta=1.0, + max_priority_within_buffer=max_priority_within_buffer, + ), + ) + for data in torch.arange(20): + idx = rb.add(data) + rb.update_priority(idx, 21 - data) + if data <= 10: + # The max is always going to be the first value + assert rb.sampler._max_priority[0] == 21 + assert rb.sampler._max_priority[1] == 0 + elif not max_priority_within_buffer: + # The max is the historical max, which was at idx 0 + assert rb.sampler._max_priority[0] == 21 + assert rb.sampler._max_priority[1] == 0 + else: + # the max is the current max. 
Find it and compare + sumtree = torch.as_tensor( + [rb.sampler._sum_tree[i] for i in range(rb.sampler._max_capacity)] + ) + assert rb.sampler._max_priority[0] == sumtree.max() + assert rb.sampler._max_priority[1] == sumtree.argmax() + idx = rb.extend(torch.arange(10)) + rb.update_priority(idx, 12) + if max_priority_within_buffer: + assert rb.sampler._max_priority[0] == 12 + assert rb.sampler._max_priority[1] == 0 + else: + assert rb.sampler._max_priority[0] == 21 + assert rb.sampler._max_priority[1] == 0 + + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + def test_prb_serialization(self, tmpdir): + rb = ReplayBuffer( + storage=LazyMemmapStorage(max_size=10), + sampler=PrioritizedSampler(max_capacity=10, alpha=0.8, beta=0.6), + ) + + td = TensorDict( + { + "observations": torch.zeros(1, 3), + "actions": torch.zeros(1, 1), + "rewards": torch.zeros(1, 1), + "next_observations": torch.zeros(1, 3), + "terminations": torch.zeros(1, 1, dtype=torch.bool), + }, + batch_size=[1], + ) + rb.extend(td) + + rb.save(tmpdir) + + rb2 = ReplayBuffer( + storage=LazyMemmapStorage(max_size=10), + sampler=PrioritizedSampler(max_capacity=10, alpha=0.5, beta=0.5), + ) + + td = TensorDict( + { + "observations": torch.ones(1, 3), + "actions": torch.ones(1, 1), + "rewards": torch.ones(1, 1), + "next_observations": torch.ones(1, 3), + "terminations": torch.ones(1, 1, dtype=torch.bool), + }, + batch_size=[1], + ) + rb2.extend(td) + rb2.load(tmpdir) + assert len(rb) == 1 + assert rb.sampler._alpha == rb2.sampler._alpha + assert rb.sampler._beta == rb2.sampler._beta + assert rb.sampler._max_priority[0] == rb2.sampler._max_priority[0] + assert rb.sampler._max_priority[1] == rb2.sampler._max_priority[1] + + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + def test_prb_new_sampler_with_loaded_storage(self, tmpdir): + """Test that creating a new PrioritizedSampler with loaded storage works 
correctly. + + This test reproduces the issue from scratch8.py where creating a new + PrioritizedSampler instance with storage that already contains data + would fail with "RuntimeError: non-positive p_sum". + """ + device = torch.device("cpu") + + # Create and populate original buffer + original_rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(10, device=device), + sampler=PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.5), + batch_size=2, + priority_key="td_error", + ) + + data = TensorDict( + { + "state": torch.ones(4, 2, dtype=torch.float32, device=device), + "td_error": torch.ones(4) * 0.5, + }, + batch_size=torch.Size((4,)), + ) + original_rb.extend(data) + + # Update priorities + td = original_rb.sample() + td["td_error"] = torch.arange(2, device=device) + 1.0 + original_rb.update_tensordict_priority(td) + + # Get original priorities for comparison + original_priorities = torch.tensor( + [original_rb._sampler._sum_tree[i] for i in range(len(original_rb))] + ) + + # Save and load normally + original_rb.dumps(tmpdir) + del original_rb + + loaded_rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(10, device=device), + sampler=PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.5), + batch_size=2, + priority_key="td_error", + ) + loaded_rb.loads(tmpdir) + + # Create a new buffer with the loaded storage but NEW sampler + # This was failing before the fix with "RuntimeError: non-positive p_sum" + new_rb_with_loaded_storage = TensorDictReplayBuffer( + storage=loaded_rb.storage, # Use the loaded storage + sampler=PrioritizedSampler( # But create a NEW sampler instance + max_capacity=len(loaded_rb), alpha=0.7, beta=0.5 + ), + batch_size=2, + priority_key="td_error", + ) + + # This should work now thanks to our fix + td = new_rb_with_loaded_storage.sample() + assert td.batch_size == torch.Size([2]) + + # Verify the storage has the expected data + assert len(new_rb_with_loaded_storage) == 4 + + # Verify priorities were properly initialized with 
default values + # When creating a new sampler with existing storage, it should initialize with default priorities + new_priorities = torch.tensor( + [ + new_rb_with_loaded_storage._sampler._sum_tree[i] + for i in range(len(new_rb_with_loaded_storage)) + ] + ) + expected_default_priority = new_rb_with_loaded_storage._sampler.default_priority + expected_priorities = torch.full( + (len(new_rb_with_loaded_storage),), + expected_default_priority, + dtype=torch.float, + ) + + # All priorities should be positive and equal to the default priority + assert (new_priorities > 0).all(), "All priorities should be positive" + torch.testing.assert_close( + new_priorities, + expected_priorities, + msg="New sampler should initialize with default priorities", + ) + + # Also verify that the loaded buffer maintains the original priorities + loaded_priorities = torch.tensor( + [loaded_rb._sampler._sum_tree[i] for i in range(len(loaded_rb))] + ) + torch.testing.assert_close( + loaded_priorities, + original_priorities, + msg="Loaded buffer should maintain original priorities", + ) + + def test_prb_ndim(self): + """This test lists all the possible ways of updating the priority of a PRB with RB, TRB and TPRB. + + All tests are done for 1d and 2d TDs. 
+ + """ + torch.manual_seed(0) + np.random.seed(0) + + # first case: 1d, RB + rb = ReplayBuffer( + sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), + storage=LazyTensorStorage(100), + batch_size=4, + ) + data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) + idx = rb.extend(data) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() + rb.update_priority(idx, 2) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() + s, info = rb.sample(return_info=True) + rb.update_priority(info["index"], 3) + assert ( + torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[info["index"]] + == 3 + ).all() + + # second case: 1d, TRB + rb = TensorDictReplayBuffer( + sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), + storage=LazyTensorStorage(100), + batch_size=4, + ) + data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) + idx = rb.extend(data) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() + rb.update_priority(idx, 2) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() + s = rb.sample() + rb.update_priority(s["index"], 3) + assert ( + torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[s["index"]] == 3 + ).all() + + # third case: 1d TPRB + rb = TensorDictPrioritizedReplayBuffer( + alpha=1.0, + beta=1.0, + storage=LazyTensorStorage(100), + batch_size=4, + priority_key="p", + ) + data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) + idx = rb.extend(data) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 0.5).all() + rb.update_priority(idx, 2) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() + s = rb.sample() + + s["p"] = torch.ones(4) * 10_000 + rb.update_tensordict_priority(s) + assert ( + torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[s["index"]] + == 10_000 + ).all() + + s2 = rb.sample() + # 
All indices in s2 must be from s since we set a very high priority to these items + assert (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).any(0).all() + + # fourth case: 2d RB + rb = ReplayBuffer( + sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), + storage=LazyTensorStorage(100, ndim=2), + batch_size=4, + ) + data = TensorDict( + {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] + ) + idx = rb.extend(data) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() + rb.update_priority(idx, 2) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() + + s, info = rb.sample(return_info=True) + rb.update_priority(info["index"], 3) + priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( + (5, 2) + ) + assert (priorities[info["index"]] == 3).all() + + # fifth case: 2d TRB + # 2d + rb = TensorDictReplayBuffer( + sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), + storage=LazyTensorStorage(100, ndim=2), + batch_size=4, + ) + data = TensorDict( + {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] + ) + idx = rb.extend(data) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() + rb.update_priority(idx, 2) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() + + s = rb.sample() + rb.update_priority(s["index"], 10_000) + priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( + (5, 2) + ) + assert (priorities[s["index"].unbind(-1)] == 10_000).all() + + s2 = rb.sample() + assert ( + (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all() + ) + + # Sixth case: 2d TDPRB + rb = TensorDictPrioritizedReplayBuffer( + alpha=1.0, + beta=1.0, + storage=LazyTensorStorage(100, ndim=2), + batch_size=4, + priority_key="p", + ) + data = TensorDict( + {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] + ) + idx = 
rb.extend(data) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 0.5).all() + rb.update_priority(idx, torch.ones(()) * 2) + assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() + s = rb.sample() + # setting the priorities to a value that is so big that the buffer will resample them + s["p"] = torch.ones(4) * 10_000 + rb.update_tensordict_priority(s) + priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( + (5, 2) + ) + assert (priorities[s["index"].unbind(-1)] == 10_000).all() + + s2 = rb.sample() + assert ( + (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all() + ) + + def test_replacement_kwarg_random(self): + # RandomSampler(replacement=True) is a regular RandomSampler + s = RandomSampler() + assert type(s) is RandomSampler + s = RandomSampler(replacement=True) + assert type(s) is RandomSampler + + # RandomSampler(replacement=False) dispatches to SamplerWithoutReplacement + s = RandomSampler(replacement=False) + assert type(s) is SamplerWithoutReplacement + # default kwargs propagate + assert s.drop_last is False + assert s.shuffle is True + + # Extra kwargs are forwarded to SamplerWithoutReplacement + s = RandomSampler(replacement=False, drop_last=True, shuffle=False) + assert type(s) is SamplerWithoutReplacement + assert s.drop_last is True + assert s.shuffle is False + + # isinstance is preserved + assert isinstance(s, Sampler) + assert isinstance(s, SamplerWithoutReplacement) + + def test_replacement_kwarg_slice(self): + # SliceSampler(replacement=True) is a regular SliceSampler + s = SliceSampler(slice_len=5) + assert type(s) is SliceSampler + s = SliceSampler(replacement=True, slice_len=5) + assert type(s) is SliceSampler + + # SliceSampler(replacement=False) dispatches to SliceSamplerWithoutReplacement + s = SliceSampler(replacement=False, slice_len=5) + assert type(s) is SliceSamplerWithoutReplacement + assert s.slice_len == 5 + assert s.drop_last is False + 
assert s.shuffle is True + + # Extra without-replacement kwargs forward correctly + s = SliceSampler( + replacement=False, + slice_len=5, + drop_last=True, + shuffle=False, + traj_key="episode", + strict_length=False, + ) + assert type(s) is SliceSamplerWithoutReplacement + assert s.slice_len == 5 + assert s.drop_last is True + assert s.shuffle is False + assert s.traj_key == "episode" + assert s.strict_length is False + + # isinstance preserves the SliceSampler hierarchy + assert isinstance(s, SliceSampler) + assert isinstance(s, SamplerWithoutReplacement) + + def test_replacement_kwarg_subclass_unaffected(self): + # PrioritizedSliceSampler inherits from SliceSampler but should NOT dispatch + s = PrioritizedSliceSampler(slice_len=5, max_capacity=10, alpha=0.5, beta=0.5) + assert type(s) is PrioritizedSliceSampler + + # SamplerWithoutReplacement(replacement=...) is a no-op pop + s = SamplerWithoutReplacement(replacement=False, drop_last=True) + assert type(s) is SamplerWithoutReplacement + assert s.drop_last is True + s = SliceSamplerWithoutReplacement(replacement=False, slice_len=5) + assert type(s) is SliceSamplerWithoutReplacement + assert s.slice_len == 5 + + def test_replacement_kwarg_no_variant_errors(self): + # PrioritizedSampler has no without-replacement variant -> TypeError + with pytest.raises(TypeError, match="no without-replacement variant"): + PrioritizedSampler(max_capacity=10, alpha=0.5, beta=0.5, replacement=False) + + def test_replacement_kwarg_in_replay_buffer(self): + # End-to-end: a buffer using RandomSampler(replacement=False) should + # exhaust the storage without duplicate indices (like SamplerWithoutReplacement). 
+ torch.manual_seed(0) + data = TensorDict({"a": torch.arange(11)}, batch_size=[11]) + rb = ReplayBuffer( + storage=LazyTensorStorage(11), + sampler=RandomSampler(replacement=False, drop_last=False), + batch_size=3, + ) + rb.extend(data) + seen = set() + for _ in range(4): + seen.update(rb.sample()["a"].tolist()) + assert seen == set(range(11)) + + def test_replacement_kwarg_slice_in_replay_buffer(self): + # End-to-end: SliceSampler(replacement=False) returns sub-trajectories + torch.manual_seed(0) + episodes = torch.zeros(60, dtype=torch.long) + episodes[:20] = 0 + episodes[20:40] = 1 + episodes[40:] = 2 + data = TensorDict( + {"episode": episodes, "obs": torch.arange(60)}, + batch_size=[60], + ) + rb = ReplayBuffer( + storage=LazyTensorStorage(60), + sampler=SliceSampler( + replacement=False, + slice_len=5, + traj_key="episode", + strict_length=True, + ), + batch_size=10, + ) + rb.extend(data) + sample = rb.sample() + # batch_size=10, slice_len=5 -> 2 slices of 5 contiguous obs each + obs = sample["obs"].view(2, 5) + diffs = obs[:, 1:] - obs[:, :-1] + assert (diffs == 1).all(), obs + + +class TestStalenessAwareSampler: + """Tests for StalenessAwareSampler.""" + + def _make_buffer_with_versions(self, n_entries=100, version_range=(0, 5)): + """Create a replay buffer populated with data containing policy_version.""" + sampler = StalenessAwareSampler(max_staleness=-1) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(n_entries), + sampler=sampler, + batch_size=16, + ) + # Fill with data having varying policy versions + for v in range(version_range[0], version_range[1] + 1): + batch = TensorDict( + { + "observation": torch.randn(20, 4), + "action": torch.randn(20, 2), + "policy_version": torch.full((20,), float(v)), + }, + batch_size=[20], + ) + rb.extend(batch) + return rb, sampler + + def test_basic_sampling(self): + """Test that StalenessAwareSampler can sample from a buffer.""" + rb, sampler = self._make_buffer_with_versions() + sampler.consumer_version = 
5 + batch = rb.sample() + assert batch is not None + assert batch.shape[0] == 16 + + def test_freshness_weighting(self): + """Test that fresher entries are sampled more frequently.""" + sampler = StalenessAwareSampler(max_staleness=-1) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(200), + sampler=sampler, + batch_size=32, + ) + # Add 100 entries at version 0 (stale) and 100 at version 9 (fresh) + stale = TensorDict( + { + "observation": torch.zeros(100, 4), + "policy_version": torch.full((100,), 0.0), + }, + batch_size=[100], + ) + fresh = TensorDict( + { + "observation": torch.ones(100, 4), + "policy_version": torch.full((100,), 9.0), + }, + batch_size=[100], + ) + rb.extend(stale) + rb.extend(fresh) + sampler.consumer_version = 10 + + # Sample many times and count how often fresh vs stale entries appear + fresh_count = 0 + total = 0 + for _ in range(100): + batch = rb.sample() + # Fresh entries have observation == 1, stale have observation == 0 + fresh_count += (batch["observation"][:, 0] > 0.5).sum().item() + total += batch.shape[0] + + fresh_ratio = fresh_count / total + # Fresh entries (staleness=1) should be sampled ~10x more than stale (staleness=10) + # So fresh_ratio should be significantly above 0.5 + assert ( + fresh_ratio > 0.7 + ), f"Expected fresh entries to dominate, got {fresh_ratio:.2f}" + + def test_hard_staleness_gate(self): + """Test that entries beyond max_staleness are never sampled.""" + sampler = StalenessAwareSampler(max_staleness=3) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(200), + sampler=sampler, + batch_size=32, + ) + # Add entries at version 0 (stale) and version 8 (fresh) + stale = TensorDict( + { + "observation": torch.zeros(100, 4), + "policy_version": torch.full((100,), 0.0), + }, + batch_size=[100], + ) + fresh = TensorDict( + { + "observation": torch.ones(100, 4), + "policy_version": torch.full((100,), 8.0), + }, + batch_size=[100], + ) + rb.extend(stale) + rb.extend(fresh) + sampler.consumer_version 
= 10 + + # All sampled entries should be fresh (staleness=2 <= 3) + # Stale entries have staleness=10 > 3, so they're excluded + for _ in range(50): + batch = rb.sample() + assert ( + batch["observation"][:, 0] > 0.5 + ).all(), ( + "Stale entries should never be sampled when max_staleness is exceeded" + ) + + def test_all_stale_raises(self): + """Test that an error is raised when all entries exceed max_staleness.""" + sampler = StalenessAwareSampler(max_staleness=2) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(50), + sampler=sampler, + batch_size=8, + ) + data = TensorDict( + { + "observation": torch.randn(50, 4), + "policy_version": torch.full((50,), 0.0), + }, + batch_size=[50], + ) + rb.extend(data) + sampler.consumer_version = 100 # Everything is very stale + + with pytest.raises(RuntimeError, match="max_staleness"): + rb.sample() + + def test_consumer_version_increment(self): + """Test consumer version tracking.""" + sampler = StalenessAwareSampler() + assert sampler.consumer_version == 0 + sampler.increment_consumer_version() + assert sampler.consumer_version == 1 + sampler.consumer_version = 42 + assert sampler.consumer_version == 42 + + def test_staleness_in_info(self): + """Test that staleness values are returned in sample info.""" + sampler = StalenessAwareSampler(max_staleness=-1) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(50), + sampler=sampler, + batch_size=8, + ) + data = TensorDict( + { + "observation": torch.randn(50, 4), + "policy_version": torch.full((50,), 3.0), + }, + batch_size=[50], + ) + rb.extend(data) + sampler.consumer_version = 5 + + index, info = sampler.sample(rb._storage, 8) + assert "staleness" in info + assert (info["staleness"] == 2.0).all() # consumer=5 - version=3 = 2 + + def test_missing_version_key_raises(self): + """Test that a clear error is raised when version key is missing.""" + sampler = StalenessAwareSampler() + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(50), + sampler=sampler, 
+ batch_size=8, + ) + data = TensorDict( + {"observation": torch.randn(50, 4)}, + batch_size=[50], + ) + rb.extend(data) + + with pytest.raises(KeyError, match="policy_version"): + rb.sample() + + def test_state_dict_roundtrip(self): + """Test that state_dict/load_state_dict preserves sampler state.""" + sampler = StalenessAwareSampler(max_staleness=7) + sampler.consumer_version = 42 + + sd = sampler.state_dict() + assert sd["consumer_version"] == 42 + assert sd["max_staleness"] == 7 + + sampler2 = StalenessAwareSampler() + sampler2.load_state_dict(sd) + assert sampler2.consumer_version == 42 + assert sampler2.max_staleness == 7 + + def test_no_staleness_limit(self): + """Test sampling with max_staleness=-1 (no limit).""" + sampler = StalenessAwareSampler(max_staleness=-1) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(50), + sampler=sampler, + batch_size=8, + ) + data = TensorDict( + { + "observation": torch.randn(50, 4), + "policy_version": torch.full((50,), 0.0), + }, + batch_size=[50], + ) + rb.extend(data) + sampler.consumer_version = 1000 # Very stale, but no limit + + # Should not raise + batch = rb.sample() + assert batch.shape[0] == 8 + + +def test_prioritized_slice_sampler_doc_example(): + sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) + rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6 + ) + data = TensorDict( + { + "observation": torch.randn(9, 16), + "action": torch.randn(9, 1), + "episode": torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long), + "steps": torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=torch.long), + ("next", "observation"): torch.randn(9, 16), + ("next", "reward"): torch.randn(9, 1), + ("next", "done"): torch.tensor( + [0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=torch.bool + ).unsqueeze(1), + }, + batch_size=[9], + ) + rb.extend(data) + sample, info = rb.sample(return_info=True) + # print("episode", sample["episode"].tolist()) + # print("steps", 
sample["steps"].tolist()) + # print("weight", info["priority_weight"].tolist()) + + priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1]) + rb.update_priority(torch.arange(0, 9, 1), priority=priority) + sample, info = rb.sample(return_info=True) + # print("episode", sample["episode"].tolist()) + # print("steps", sample["steps"].tolist()) + # print("weight", info["priority_weight"].tolist()) + + +@pytest.mark.parametrize("device", get_default_devices()) +def test_prioritized_slice_sampler_episodes(device): + num_slices = 10 + batch_size = 20 + + episode = torch.zeros(100, dtype=torch.int, device=device) + episode[:30] = 1 + episode[30:55] = 2 + episode[55:70] = 3 + episode[70:] = 4 + steps = torch.cat( + [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0 + ) + done = torch.zeros(100, 1, dtype=torch.bool) + done[torch.tensor([29, 54, 69])] = 1 + + data = TensorDict( + { + "observation": torch.randn(100, 16), + "action": torch.randn(100, 4), + "episode": episode, + "steps": steps, + ("next", "observation"): torch.randn(100, 16), + ("next", "reward"): torch.randn(100, 1), + ("next", "done"): done, + }, + batch_size=[100], + device=device, + ) + + num_steps = data.shape[0] + sampler = PrioritizedSliceSampler( + max_capacity=num_steps, + alpha=0.7, + beta=0.9, + num_slices=num_slices, + ) + + rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(100), + sampler=sampler, + batch_size=batch_size, + ) + rb.extend(data) + + episodes = [] + for _ in range(10): + sample = rb.sample() + episodes.append(sample["episode"]) + assert {1, 2, 3, 4} == set( + torch.cat(episodes).cpu().tolist() + ), "all episodes are expected to be sampled at least once" + + index = torch.arange(0, num_steps, 1) + new_priorities = torch.cat( + [torch.ones(30), torch.zeros(25), torch.ones(15), torch.zeros(30)], 0 + ) + sampler.update_priority(index, new_priorities) + + episodes = [] + for _ in range(10): + sample = rb.sample() + episodes.append(sample["episode"]) + assert {1, 
3} == set( + torch.cat(episodes).cpu().tolist() + ), "after priority update, only episode 1 and 3 are expected to be sampled" + + +@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)]) +@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)]) +@pytest.mark.parametrize("gamma", [0.1]) +@pytest.mark.parametrize("total_steps", [200]) +@pytest.mark.parametrize("n_annealing_steps", [100]) +@pytest.mark.parametrize("anneal_every_n", [10, 159]) +@pytest.mark.parametrize("alpha_min", [0, 0.2]) +@pytest.mark.parametrize("beta_max", [1, 1.4]) +def test_prioritized_parameter_scheduler( + alpha, + beta, + gamma, + total_steps, + n_annealing_steps, + anneal_every_n, + alpha_min, + beta_max, +): + rb = TensorDictPrioritizedReplayBuffer( + alpha=alpha, beta=beta, storage=ListStorage(max_size=1000) + ) + data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000) + rb.extend(data) + alpha_scheduler = LinearScheduler( + rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps + ) + beta_scheduler = StepScheduler( + rb, + param_name="beta", + gamma=gamma, + n_steps=anneal_every_n, + max_value=beta_max, + mode="additive", + ) + + scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler)) + + alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha) + alpha_min = torch.tensor(alpha_min) + expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1) + expected_alpha_vals = torch.nn.functional.pad( + expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min + ) + + expected_beta_vals = [beta] + annealing_steps = total_steps // anneal_every_n + gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma + expected_beta_vals = ( + (beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max) + ) + for i in range(total_steps): + curr_alpha = rb.sampler.alpha + torch.testing.assert_close( + curr_alpha + if torch.is_tensor(curr_alpha) + else torch.tensor(curr_alpha).float(), + 
expected_alpha_vals[i], + msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}", + ) + curr_beta = rb.sampler.beta + torch.testing.assert_close( + curr_beta + if torch.is_tensor(curr_beta) + else torch.tensor(curr_beta).float(), + expected_beta_vals[i], + msg=f"expected {expected_beta_vals[i]}, got {curr_beta}", + ) + rb.sample(20) + scheduler.step() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_storage_map.py b/test/rb/test_storage_map.py similarity index 100% rename from test/test_storage_map.py rename to test/rb/test_storage_map.py diff --git a/test/rb/test_storages.py b/test/rb/test_storages.py new file mode 100644 index 00000000000..f77dec13030 --- /dev/null +++ b/test/rb/test_storages.py @@ -0,0 +1,1525 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import argparse +import functools +import gc +import os +import shutil +import signal +import subprocess +import sys +import tempfile +import time +from functools import partial +from pathlib import Path +from unittest import mock + +import pytest +import torch +from _rb_common import ( + _has_snapshot, + _has_zstandard, + _os_is_windows, + torch_2_3, + TORCH_VERSION, +) +from packaging import version +from tensordict import ( + assert_allclose_td, + is_tensor_collection, + is_tensorclass, + LazyStackedTensorDict, + tensorclass, + TensorDict, + TensorDictBase, +) +from torch import multiprocessing as mp +from torch.utils._pytree import tree_flatten, tree_map + +from torchrl.data import ( + CompressedListStorage, + ReplayBuffer, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers.samplers import ( + RandomSampler, + SamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import ( + _MEMMAP_STORAGE_REGISTRY, + LazyMemmapStorage, + LazyStackStorage, + LazyTensorStorage, + ListStorage, + TensorStorage, +) +from torchrl.data.replay_buffers.utils import tree_iter +from torchrl.data.replay_buffers.writers import RoundRobinWriter +from torchrl.testing import capture_log_records, get_default_devices, make_tc + + +class TestStorages: + def _get_tensor(self): + return torch.randn(10, 11) + + def _get_tensordict(self): + return TensorDict( + {"data": torch.randn(10, 11), ("nested", "data"): torch.randn(10, 11, 3)}, + [10, 11], + ) + + def _get_pytree(self): + return { + "a": torch.randint(100, (10, 11, 1)), + "b": {"c": [torch.zeros(10, 11), (torch.ones(10, 11),)]}, + 30: torch.zeros(10, 11), + } + + def _get_tensorclass(self): + data = self._get_tensordict() + return make_tc(data)(**data, batch_size=data.shape) + + @pytest.mark.parametrize("storage_type", [TensorStorage]) + def test_errors(self, storage_type): + with pytest.raises(ValueError, match="Expected storage to be 
non-null"): + storage_type(None) + data = torch.randn(3) + with pytest.raises( + ValueError, match="The max-size and the storage shape mismatch" + ): + storage_type(data, max_size=4) + + def test_existsok_lazymemmap(self, tmpdir): + storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storage0) + rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) + + storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storage1) + with pytest.raises(RuntimeError, match="existsok"): + rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) + + storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True) + rb = ReplayBuffer(storage=storage2) + rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) + + @pytest.mark.parametrize( + "data_type", ["tensor", "tensordict", "tensorclass", "pytree"] + ) + @pytest.mark.parametrize("storage_type", [TensorStorage]) + def test_get_set(self, storage_type, data_type): + if data_type == "tensor": + data = self._get_tensor() + elif data_type == "tensorclass": + data = self._get_tensorclass() + elif data_type == "tensordict": + data = self._get_tensordict() + elif data_type == "pytree": + data = self._get_pytree() + else: + raise NotImplementedError + storage = storage_type(data) + if data_type == "pytree": + storage.set(range(10), tree_map(torch.zeros_like, data)) + + def check(x): + assert (x == 0).all() + + tree_map(check, storage.get(range(10))) + else: + storage.set(range(10), torch.zeros_like(data)) + assert (storage.get(range(10)) == 0).all() + + @pytest.mark.parametrize( + "data_type", ["tensor", "tensordict", "tensorclass", "pytree"] + ) + @pytest.mark.parametrize("storage_type", [TensorStorage]) + def test_state_dict(self, storage_type, data_type): + if data_type == "tensor": + data = self._get_tensor() + elif data_type == "tensorclass": + data = self._get_tensorclass() + elif data_type == "tensordict": + data = self._get_tensordict() + elif data_type == "pytree": + data = 
self._get_pytree() + else: + raise NotImplementedError + storage = storage_type(data) + if data_type == "pytree": + with pytest.raises(TypeError, match="are not supported by"): + storage.state_dict() + return + sd = storage.state_dict() + storage2 = storage_type(torch.zeros_like(data)) + storage2.load_state_dict(sd) + assert (storage.get(range(10)) == storage2.get(range(10))).all() + assert type(storage.get(range(10))) is type( # noqa: E721 + storage2.get(range(10)) + ) + + @pytest.mark.gpu + @pytest.mark.skipif( + not torch.cuda.device_count(), + reason="not cuda device found to test rb storage.", + ) + @pytest.mark.parametrize( + "device_data,device_storage", + [ + [torch.device("cuda"), torch.device("cpu")], + [torch.device("cpu"), torch.device("cuda")], + [torch.device("cpu"), "auto"], + [torch.device("cuda"), "auto"], + ], + ) + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + @pytest.mark.parametrize("data_type", ["tensor", "tc", "td"]) + def test_storage_device(self, device_data, device_storage, storage_type, data_type): + @tensorclass + class TC: + a: torch.Tensor + + if data_type == "tensor": + data = torch.randn(3, device=device_data) + elif data_type == "td": + data = TensorDict( + {"a": torch.randn(3, device=device_data)}, [], device=device_data + ) + elif data_type == "tc": + data = TC( + a=torch.randn(3, device=device_data), + batch_size=[], + device=device_data, + ) + else: + raise NotImplementedError + + if ( + storage_type is LazyMemmapStorage + and device_storage != "auto" + and device_storage.type != "cpu" + ): + with pytest.raises(ValueError, match="Memory map device other than CPU"): + storage_type(max_size=10, device=device_storage) + return + storage = storage_type(max_size=10, device=device_storage) + storage.set(0, data) + if device_storage != "auto": + assert storage.get(0).device.type == device_storage.type + else: + assert storage.get(0).device.type == storage.device.type + + 
@pytest.mark.parametrize("storage_in", ["tensor", "memmap"]) + @pytest.mark.parametrize("storage_out", ["tensor", "memmap"]) + @pytest.mark.parametrize("init_out", [True, False]) + @pytest.mark.parametrize( + "backend", ["torch"] + (["torchsnapshot"] if _has_snapshot else []) + ) + def test_storage_state_dict(self, storage_in, storage_out, init_out, backend): + os.environ["CKPT_BACKEND"] = backend + buffer_size = 100 + if storage_in == "memmap": + storage_in = LazyMemmapStorage(buffer_size, device="cpu") + elif storage_in == "tensor": + storage_in = LazyTensorStorage(buffer_size, device="cpu") + if storage_out == "memmap": + storage_out = LazyMemmapStorage(buffer_size, device="cpu") + elif storage_out == "tensor": + storage_out = LazyTensorStorage(buffer_size, device="cpu") + + replay_buffer = TensorDictReplayBuffer( + pin_memory=False, prefetch=3, storage=storage_in, batch_size=3 + ) + # fill replay buffer with random data + transition = TensorDict( + { + "observation": torch.ones(1, 4), + "action": torch.ones(1, 2), + "reward": torch.ones(1, 1), + "dones": torch.ones(1, 1), + "next": {"observation": torch.ones(1, 4)}, + }, + batch_size=1, + ) + for _ in range(3): + replay_buffer.extend(transition) + + state_dict = replay_buffer.state_dict() + + new_replay_buffer = TensorDictReplayBuffer( + pin_memory=False, + prefetch=3, + storage=storage_out, + batch_size=state_dict["_batch_size"], + ) + if init_out: + new_replay_buffer.extend(transition) + + new_replay_buffer.load_state_dict(state_dict) + s = new_replay_buffer.sample() + assert (s.exclude("index") == 1).all() + + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") + @pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="torch.compile is not supported on Python 3.14+", + ) + # This test checks if the `torch._dynamo.disable` wrapper around + # `TensorStorage._rand_given_ndim` 
is still necessary. + def test__rand_given_ndim_recompile(self): + torch._dynamo.reset_code_caches() + + # Number of times to extend the replay buffer + num_extend = 5 + data_size = 50 + storage_size = (num_extend + 1) * data_size + sample_size = 3 + + storage = LazyTensorStorage(storage_size, compilable=True) + sampler = RandomSampler() + + # Override to avoid the `torch._dynamo.disable` wrapper + storage._rand_given_ndim = storage._rand_given_ndim_impl + + @torch.compile + def extend_and_sample(data): + storage.set(torch.arange(data_size) + len(storage), data) + return sampler.sample(storage, sample_size) + + data = torch.randint(100, (data_size, 1)) + + try: + torch._logging.set_logs(recompiles=True) + records = [] + capture_log_records(records, "torch._dynamo", "recompiles") + + for _ in range(num_extend): + extend_and_sample(data) + + finally: + torch._logging.set_logs() + + assert len(storage) == num_extend * data_size + assert len(records) <= 8, ( + "Excessive recompilations detected. Expected 8 or fewer, but got " + f"{len(records)}. This suggests the `torch.compiler.disable` " + "decorators may not be working properly or new recompilation " + "sources have been introduced." + ) + + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + def test_extend_lazystack(self, storage_type): + rb = ReplayBuffer( + storage=storage_type(6), + batch_size=2, + ) + td1 = TensorDict(a=torch.rand(5, 4, 8), batch_size=5) + td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5) + ltd = LazyStackedTensorDict(td1, td2, stack_dim=1) + rb.extend(ltd) + rb.sample(3) + assert len(rb) == 5 + + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + def test_extend_lazystack_direct_write(self, storage_type): + """Test that lazy stacks can be extended to storage correctly. + + This tests that lazy stacks from collectors are properly stored in + replay buffers and that the data integrity is preserved. 
Also verifies + that the update_() optimization is used for tensor indices. + """ + rb = ReplayBuffer( + storage=storage_type(100), + batch_size=10, + ) + # Create a list of tensordicts (like a collector would produce) + tensordicts = [ + TensorDict( + {"obs": torch.rand(4, 8), "action": torch.rand(2)}, batch_size=() + ) + for _ in range(10) + ] + # Create lazy stack with stack_dim=0 (the batch dimension) + lazy_td = LazyStackedTensorDict.lazy_stack(tensordicts, dim=0) + assert isinstance(lazy_td, LazyStackedTensorDict) + + # Track calls to update_at_() - used for tensor indices + update_at_called = [] + original_update_at = TensorDictBase.update_at_ + + def mock_update_at(self, *args, **kwargs): + update_at_called.append(True) + return original_update_at(self, *args, **kwargs) + + # Extend with lazy stack and verify update_at_() is called + # (rb.extend uses tensor indices, so update_at_() path is taken) + with mock.patch.object(TensorDictBase, "update_at_", mock_update_at): + rb.extend(lazy_td) + + # Verify update_at_() was called (optimization was used) + assert len(update_at_called) > 0, "update_at_() should have been called" + + # Verify data integrity + assert len(rb) == 10 + sample = rb.sample(5) + assert sample["obs"].shape == (5, 4, 8) + assert sample["action"].shape == (5, 2) + + # Verify all data is accessible by reading the entire storage + all_data = rb[:] + assert all_data["obs"].shape == (10, 4, 8) + assert all_data["action"].shape == (10, 2) + + # Verify data values are preserved (check against original stacked data) + expected = lazy_td.to_tensordict() + assert torch.allclose(all_data["obs"], expected["obs"]) + assert torch.allclose(all_data["action"], expected["action"]) + + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + def test_extend_lazystack_2d_storage(self, storage_type): + """Test lazy stack optimization for 2D storage (parallel envs). 
+ + When using parallel environments, the storage is 2D [max_size, n_steps] + and the lazy stack has stack_dim=1 (time dimension). This test verifies + the optimization handles this case correctly. + """ + n_envs = 4 + n_steps = 10 + img_shape = (3, 32, 32) + + # Create 2D storage - capacity is 100 * n_steps when ndim=2 + storage = storage_type(100 * n_steps, ndim=2) + + # Pre-initialize storage with correct shape by setting first element + init_td = TensorDict( + {"pixels": torch.zeros(n_steps, *img_shape)}, + batch_size=[n_steps], + ) + storage.set(0, init_td, set_cursor=False) + + # Expand storage to full size + full_init = TensorDict( + {"pixels": torch.zeros(100, n_steps, *img_shape)}, + batch_size=[100, n_steps], + ) + storage.set(slice(0, 100), full_init, set_cursor=False) + + # Create lazy stack simulating parallel env output + # stack_dim=1 means stacked along time dimension + time_tds = [ + TensorDict( + {"pixels": torch.rand(n_envs, *img_shape)}, + batch_size=[n_envs], + ) + for _ in range(n_steps) + ] + lazy_td = LazyStackedTensorDict.lazy_stack(time_tds, dim=1) + assert lazy_td.stack_dim == 1 + assert lazy_td.batch_size == torch.Size([n_envs, n_steps]) + + # Write using tensor indices (simulating circular buffer behavior) + cursor = torch.tensor([0, 1, 2, 3]) + storage.set(cursor, lazy_td) + + # Verify data integrity + for i in range(n_envs): + stored = storage[i] + expected = lazy_td[i].to_tensordict() + assert torch.allclose( + stored["pixels"], expected["pixels"] + ), f"Data mismatch for env {i}" + + @pytest.mark.parametrize("device_data", get_default_devices()) + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + @pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"]) + @pytest.mark.parametrize("isinit", [True, False]) + def test_storage_dumps_loads( + self, device_data, storage_type, data_type, isinit, tmpdir + ): + torch.manual_seed(0) + + dir_rb = tmpdir / "rb" + dir_save = tmpdir / "save" + 
dir_rb.mkdir() + dir_save.mkdir() + torch.manual_seed(0) + + @tensorclass + class TC: + tensor: torch.Tensor + td: TensorDict + text: str + + if data_type == "tensor": + data = torch.randint(10, (3,), device=device_data) + elif data_type == "pytree": + data = { + "a": torch.randint(10, (3,), device=device_data), + "b": {"c": [torch.ones(3), (-torch.ones(3, 2),)]}, + 30: -torch.ones(3, 1), + } + elif data_type == "td": + data = TensorDict( + { + "a": torch.randint(10, (3,), device=device_data), + "b": TensorDict( + {"c": torch.randint(10, (3,), device=device_data)}, + batch_size=[3], + ), + }, + batch_size=[3], + device=device_data, + ) + elif data_type == "tc": + data = TC( + tensor=torch.randint(10, (3,), device=device_data), + td=TensorDict( + {"c": torch.randint(10, (3,), device=device_data)}, batch_size=[3] + ), + text="some text", + batch_size=[3], + device=device_data, + ) + else: + raise NotImplementedError + + if storage_type in (LazyMemmapStorage,): + storage = storage_type(max_size=10, scratch_dir=dir_rb) + else: + storage = storage_type(max_size=10) + + # We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index + if data_type == "pytree": + storage.set(range(3), tree_map(lambda x: x.cpu(), data)) + else: + storage.set(range(3), data.cpu()) + + storage.dumps(dir_save) + # check we can dump twice + storage.dumps(dir_save) + + storage_recover = storage_type(max_size=10) + if isinit: + if data_type == "pytree": + storage_recover.set( + range(3), tree_map(lambda x: x.cpu().clone().zero_(), data) + ) + else: + storage_recover.set(range(3), data.cpu().clone().zero_()) + + if data_type in ("tensor", "pytree") and not isinit: + with pytest.raises( + RuntimeError, + match="Cannot fill a non-initialized pytree-based TensorStorage", + ): + storage_recover.loads(dir_save) + return + storage_recover.loads(dir_save) + # tree_map with more than one pytree is only available in torch >= 2.3 + if torch_2_3: + if data_type in ("tensor", 
"pytree"): + tree_map( + torch.testing.assert_close, + tree_flatten(storage[:])[0], + tree_flatten(storage_recover[:])[0], + ) + else: + assert_allclose_td(storage[:], storage_recover[:]) + if data == "tc": + assert storage._storage.text == storage_recover._storage.text + + def test_add_list_of_tds(self): + rb = ReplayBuffer(storage=LazyTensorStorage(100)) + rb.extend([TensorDict({"a": torch.randn(2, 3)}, [2])]) + assert len(rb) == 1 + assert rb[:].shape == torch.Size([1, 2]) + + @pytest.mark.parametrize( + "storage_type,collate_fn", + [ + (LazyTensorStorage, None), + (LazyMemmapStorage, None), + (ListStorage, torch.stack), + ], + ) + def test_storage_inplace_writing(self, storage_type, collate_fn): + rb = ReplayBuffer(storage=storage_type(102), collate_fn=collate_fn) + data = TensorDict( + {"a": torch.arange(100), ("b", "c"): torch.arange(100)}, [100] + ) + rb.extend(data) + assert len(rb) == 100 + rb[3:4] = TensorDict( + {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] + ) + assert (rb[3:4] == 0).all() + assert len(rb) == 100 + assert rb.writer._cursor == 100 + rb[10:20] = TensorDict( + {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] + ) + assert (rb[10:20] == 0).all() + assert len(rb) == 100 + assert rb.writer._cursor == 100 + rb[torch.arange(30, 40)] = TensorDict( + {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] + ) + assert (rb[30:40] == 0).all() + assert len(rb) == 100 + + @pytest.mark.parametrize( + "storage_type,collate_fn", + [ + (LazyTensorStorage, None), + (LazyMemmapStorage, None), + (ListStorage, torch.stack), + ], + ) + def test_storage_inplace_writing_transform(self, storage_type, collate_fn): + rb = ReplayBuffer(storage=storage_type(102), collate_fn=collate_fn) + rb.append_transform(lambda x: x + 1, invert=True) + rb.append_transform(lambda x: x + 1) + data = TensorDict( + {"a": torch.arange(100), ("b", "c"): torch.arange(100)}, [100] + ) + rb.extend(data) + assert len(rb) == 100 + rb[3:4] 
= TensorDict( + {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] + ) + assert (rb[3:4] == 2).all(), rb[3:4]["a"] + assert len(rb) == 100 + assert rb.writer._cursor == 100 + rb[10:20] = TensorDict( + {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] + ) + assert (rb[10:20] == 2).all() + assert len(rb) == 100 + assert rb.writer._cursor == 100 + rb[torch.arange(30, 40)] = TensorDict( + {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] + ) + assert (rb[30:40] == 2).all() + assert len(rb) == 100 + + @pytest.mark.parametrize( + "storage_type,collate_fn", + [ + (LazyTensorStorage, None), + # (LazyMemmapStorage, None), + (ListStorage, TensorDict.maybe_dense_stack), + ], + ) + def test_storage_inplace_writing_newkey(self, storage_type, collate_fn): + rb = ReplayBuffer(storage=storage_type(102), collate_fn=collate_fn) + data = TensorDict( + {"a": torch.arange(100), ("b", "c"): torch.arange(100)}, [100] + ) + rb.extend(data) + assert len(rb) == 100 + rb[3:4] = TensorDict( + {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0]), "d": torch.ones(1)}, + [1], + ) + assert "d" in rb[3] + assert "d" in rb[3:4] + if storage_type is not ListStorage: + assert "d" in rb[3:5] + else: + # a lazy stack doesn't show exclusive fields + assert "d" not in rb[3:5] + + @pytest.mark.parametrize("storage_type", [LazyTensorStorage, LazyMemmapStorage]) + def test_storage_inplace_writing_ndim(self, storage_type): + rb = ReplayBuffer(storage=storage_type(102, ndim=2)) + data = TensorDict( + { + "a": torch.arange(50).expand(2, 50), + ("b", "c"): torch.arange(50).expand(2, 50), + }, + [2, 50], + ) + rb.extend(data) + assert len(rb) == 100 + rb[0, 3:4] = TensorDict( + {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] + ) + assert (rb[0, 3:4] == 0).all() + assert (rb[1, 3:4] != 0).all() + assert rb.writer._cursor == 50 + rb[1, 5:6] = TensorDict( + {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] + ) + assert (rb[1, 5:6] == 
0).all() + assert rb.writer._cursor == 50 + rb[:, 7:8] = TensorDict( + {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] + ).expand(2, 1) + assert (rb[:, 7:8] == 0).all() + assert rb.writer._cursor == 50 + # test broadcasting + rb[:, 10:20] = TensorDict( + {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] + ) + assert (rb[:, 10:20] == 0).all() + assert len(rb) == 100 + + @pytest.mark.skipif( + TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" + ) + @pytest.mark.parametrize("max_size", [1000, None]) + @pytest.mark.parametrize("stack_dim", [-1, 0]) + def test_lazy_stack_storage(self, max_size, stack_dim): + # Create an instance of LazyStackStorage with given parameters + storage = LazyStackStorage(max_size=max_size, stack_dim=stack_dim) + # Create a ReplayBuffer using the created storage + rb = ReplayBuffer(storage=storage) + # Generate some random data to add to the buffer + torch.manual_seed(0) + data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!") + data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!") + # Add the data to the buffer + rb.add(data0) + rb.add(data1) + # Sample from the buffer + sample = rb.sample(10) + # Check that the sampled data has the correct shape and type + assert isinstance(sample, LazyStackedTensorDict) + assert sample["b"].shape[0] == 10 + assert all(isinstance(item, str) for item in sample["c"]) + # If densify is True, check that the sampled data is dense + sample = sample.densify(layout=torch.jagged) + assert isinstance(sample["a"], torch.Tensor) + assert sample["a"].shape[0] == 10 + + +@pytest.mark.parametrize("max_size", [1000]) +@pytest.mark.parametrize("shape", [[3, 4]]) +@pytest.mark.parametrize("storage", [LazyTensorStorage, LazyMemmapStorage]) +class TestLazyStorages: + def _get_nested_tensorclass(self, shape): + @tensorclass + class NestedTensorClass: + key1: torch.Tensor + key2: torch.Tensor + + @tensorclass + class TensorClass: 
+ key1: torch.Tensor + key2: torch.Tensor + next: NestedTensorClass + + return TensorClass( + key1=torch.ones(*shape), + key2=torch.ones(*shape), + next=NestedTensorClass( + key1=torch.ones(*shape), key2=torch.ones(*shape), batch_size=shape + ), + batch_size=shape, + ) + + def _get_nested_td(self, shape): + nested_td = TensorDict( + { + "key1": torch.ones(*shape), + "key2": torch.ones(*shape), + "next": TensorDict( + { + "key1": torch.ones(*shape), + "key2": torch.ones(*shape), + }, + shape, + ), + }, + shape, + ) + return nested_td + + def test_init(self, max_size, shape, storage): + td = self._get_nested_td(shape) + mystorage = storage(max_size=max_size) + mystorage._init(td) + assert mystorage._storage.shape == (max_size, *shape) + + def test_set(self, max_size, shape, storage): + td = self._get_nested_td(shape) + mystorage = storage(max_size=max_size) + mystorage.set(list(range(td.shape[0])), td) + assert mystorage._storage.shape == (max_size, *shape[1:]) + idx = list(range(1, td.shape[0] - 1)) + tc_sample = mystorage.get(idx) + assert tc_sample.shape == torch.Size([td.shape[0] - 2, *td.shape[1:]]) + + def test_init_tensorclass(self, max_size, shape, storage): + tc = self._get_nested_tensorclass(shape) + mystorage = storage(max_size=max_size) + mystorage._init(tc) + assert is_tensorclass(mystorage._storage) + assert mystorage._storage.shape == (max_size, *shape) + + def test_set_tensorclass(self, max_size, shape, storage): + tc = self._get_nested_tensorclass(shape) + mystorage = storage(max_size=max_size) + mystorage.set(list(range(tc.shape[0])), tc) + assert mystorage._storage.shape == (max_size, *shape[1:]) + idx = list(range(1, tc.shape[0] - 1)) + tc_sample = mystorage.get(idx) + assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]]) + + def test_extend_list_pytree(self, max_size, shape, storage): + memory = ReplayBuffer( + storage=storage(max_size=max_size), + sampler=SamplerWithoutReplacement(), + ) + data = [ + ( + torch.full(shape, i), + 
{"a": torch.full(shape, i), "b": (torch.full(shape, i))}, + [torch.full(shape, i)], + ) + for i in range(10) + ] + memory.extend(data) + assert len(memory) == 10 + assert len(memory._storage) == 10 + sample = memory.sample(10) + for leaf in tree_iter(sample): + assert (leaf.unique(sorted=True) == torch.arange(10)).all() + memory = ReplayBuffer( + storage=storage(max_size=max_size), + sampler=SamplerWithoutReplacement(), + ) + t1x4 = torch.Tensor([0.1, 0.2, 0.3, 0.4]) + t1x1 = torch.Tensor([0.01]) + with pytest.raises( + RuntimeError, match="Stacking the elements of the list resulted in an error" + ): + memory.extend([t1x4, t1x1, t1x4 + 0.4, t1x1 + 0.01]) + + +def test_storage_save_hook(tmpdir): + observed = {} + + class SaveHook: + shift = None + is_full = None + + def __call__(self, data, path=None): + observed["shift"] = self.shift + observed["is_full"] = self.is_full + return data + + hook = SaveHook() + rb = ReplayBuffer(storage=LazyMemmapStorage(10)) + rb.register_save_hook(hook) + rb.extend(torch.arange(5)) + rb.dumps(tmpdir) + + assert hook.shift == 5, f"Expected shift=5, got {hook.shift}" + assert hook.is_full is False, f"Expected is_full=False, got {hook.is_full}" + assert observed["shift"] == 5 + assert observed["is_full"] is False + + +class TestSharedStorageInit: + def worker(self, rb, worker_id, queue): + length = len(rb) + data = TensorDict({"x": torch.full((2,), worker_id)}, batch_size=(2,)) + worker_id * 2 + index = rb.extend(data) + assert len(rb) >= length + 2 + assert (rb[index] == data).all() + queue.put("done") + + @pytest.mark.parametrize( + "storage_cls, use_tmpdir", + [ + (LazyTensorStorage, False), + (LazyMemmapStorage, False), + (LazyMemmapStorage, True), + ], + ) + def test_shared_storage_multiprocess(self, storage_cls, use_tmpdir, tmpdir): + if use_tmpdir: + storage_cls = functools.partial(storage_cls, scratch_dir=tmpdir) + storage = storage_cls(max_size=100, shared_init=True) + rb = ReplayBuffer(storage=storage, 
batch_size=2).share(True) + queue = mp.Queue() + + processes = [] + for i in range(4): + p = mp.Process(target=self.worker, args=(rb, i, queue)) + processes.append(p) + p.start() + + for p in processes: + p.join() + queue.get() + + all_data = storage.get(slice(0, 8)) + values = set(all_data["x"].tolist()) + expected = {0.0, 1.0, 2.0, 3.0} + assert expected.issubset(values) + assert len(storage) >= 8 + + def prioritized_collector_worker(self, rb, worker_id, queue): + data = TensorDict( + { + "obs": torch.full((4, 1), worker_id, dtype=torch.float32), + "td_error": torch.linspace(0.1, 1.0, 4) + worker_id, + }, + batch_size=(4,), + ) + rb.extend(data) + queue.put("done") + + @pytest.mark.gpu + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") + def test_prioritized_memmap_cuda_sampler_after_multiprocess_writes(self, tmpdir): + ext = pytest.importorskip("torchrl._torchrl") + if not hasattr(ext, "CudaSumSegmentTreeFp32"): + pytest.skip("TorchRL was not built with CUDA segment tree support") + + storage = LazyMemmapStorage(max_size=32, scratch_dir=tmpdir, shared_init=True) + writer_rb = TensorDictReplayBuffer(storage=storage, batch_size=4).share(True) + queue = mp.Queue() + + processes = [] + for i in range(2): + p = mp.Process( + target=self.prioritized_collector_worker, + args=(writer_rb, i, queue), + ) + processes.append(p) + p.start() + + for p in processes: + p.join() + assert p.exitcode == 0 + assert queue.get(timeout=5) == "done" + + assert len(storage) == 8 + learner_rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + storage=storage, + sampler_device="cuda:0", + batch_size=4, + priority_key="td_error", + ) + + sample = learner_rb.sample() + assert learner_rb._sampler.device == torch.device("cuda:0") + assert sample["obs"].device.type == "cpu" + assert sample["index"].device.type == "cpu" + assert sample["priority_weight"].device.type == "cpu" + + sample["td_error"] = torch.ones_like(sample["td_error"]) * 10 + 
learner_rb.update_tensordict_priority(sample) + sample = learner_rb.sample() + assert sample["index"].device.type == "cpu" + + +@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") +class TestCompressedListStorage: + """Test cases for CompressedListStorage.""" + + def test_compressed_storage_initialization(self): + """Test that CompressedListStorage initializes correctly.""" + storage = CompressedListStorage(max_size=100, compression_level=3) + assert storage.max_size == 100 + assert storage.compression_level == 3 + assert len(storage) == 0 + + @pytest.mark.parametrize( + "test_tensor", + [ + torch.rand(1), # 0D scalar + torch.randn(84, dtype=torch.float32), # 1D tensor + torch.randn(84, 84, dtype=torch.float32), # 2D tensor + torch.randn(1, 84, 84, dtype=torch.float32), # 3D tensor + torch.randn(32, 84, 84, dtype=torch.float32), # 3D tensor + ], + ) + def test_compressed_storage_tensor(self, test_tensor): + """Test compression and decompression of tensor data of various shapes.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Store tensor + storage.set(0, test_tensor) + + # Retrieve tensor + retrieved_tensor = storage.get(0) + + # Verify data integrity + assert ( + test_tensor.shape == retrieved_tensor.shape + ), f"Expected shape {test_tensor.shape}, got {retrieved_tensor.shape}" + assert ( + test_tensor.dtype == retrieved_tensor.dtype + ), f"Expected dtype {test_tensor.dtype}, got {retrieved_tensor.dtype}" + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_tensordict(self): + """Test compression and decompression of TensorDict data.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Create test TensorDict + test_td = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "reward": torch.randn(3), + "done": torch.tensor([False, True, False]), + }, + batch_size=[3], + ) + + # Store 
TensorDict + storage.set(0, test_td) + + # Retrieve TensorDict + retrieved_td = storage.get(0) + + # Verify data integrity + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6) + assert torch.allclose(test_td["done"], retrieved_td["done"]) + + def test_compressed_storage_multiple_indices(self): + """Test storing and retrieving multiple items.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Store multiple tensors + tensors = [ + torch.randn(2, 2, dtype=torch.float32), + torch.randn(3, 3, dtype=torch.float32), + torch.randn(4, 4, dtype=torch.float32), + ] + + for i, tensor in enumerate(tensors): + storage.set(i, tensor) + + # Retrieve multiple tensors + retrieved = storage.get([0, 1, 2]) + + # Verify data integrity + for original, retrieved_tensor in zip(tensors, retrieved): + assert torch.allclose(original, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_with_replay_buffer(self): + """Test CompressedListStorage with ReplayBuffer.""" + storage = CompressedListStorage(max_size=100, compression_level=3) + rb = ReplayBuffer(storage=storage, batch_size=5) + + # Create test data + data = TensorDict( + { + "obs": torch.randn(10, 3, 84, 84, dtype=torch.float32), + "action": torch.randint(0, 4, (10,)), + "reward": torch.randn(10), + }, + batch_size=[10], + ) + + # Add data to replay buffer + rb.extend(data) + + # Sample from replay buffer + sample = rb.sample(5) + + # Verify sample has correct shape + assert is_tensor_collection(sample), sample + assert sample["obs"].shape[0] == 5 + assert sample["obs"].shape[1:] == (3, 84, 84) + assert sample["action"].shape[0] == 5 + assert sample["reward"].shape[0] == 5 + + def test_compressed_storage_state_dict(self): + """Test saving and loading state dict.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # 
Add some data + test_tensor = torch.randn(3, 3, dtype=torch.float32) + storage.set(0, test_tensor) + + # Save state dict + state_dict = storage.state_dict() + + # Create new storage and load state dict + new_storage = CompressedListStorage(max_size=10, compression_level=3) + new_storage.load_state_dict(state_dict) + + # Verify data integrity + retrieved_tensor = new_storage.get(0) + assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) + + def test_compressed_storage_checkpointing(self): + """Test checkpointing functionality.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Add some data + test_td = TensorDict( + { + "obs": torch.randn(3, 84, 84, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + }, + batch_size=[3], + ) + storage.set(0, test_td) + + # second batch, different shape + test_td2 = TensorDict( + { + "obs": torch.randn(3, 85, 83, dtype=torch.float32), + "action": torch.tensor([1, 2, 3]), + "meta": torch.randn(3), + "astring": "a string!", + }, + batch_size=[3], + ) + storage.set(1, test_td) + + # Create temporary directory for checkpointing + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_path = Path(tmpdir) / "checkpoint" + + # Save checkpoint + storage.dumps(checkpoint_path) + + # Create new storage and load checkpoint + new_storage = CompressedListStorage(max_size=10, compression_level=3) + new_storage.loads(checkpoint_path) + + # Verify data integrity + retrieved_td = new_storage.get(0) + assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) + assert torch.allclose(test_td["action"], retrieved_td["action"]) + + def test_compressed_storage_length(self): + """Test that length is calculated correctly.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Initially empty + assert len(storage) == 0 + + # Add some data + storage.set(0, torch.randn(2, 2)) + assert len(storage) == 1 + + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 2 + + 
storage.set(2, torch.randn(2, 2)) + assert len(storage) == 3 + + def test_compressed_storage_contains(self): + """Test the contains method.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Initially empty + assert not storage.contains(0) + + # Add data + storage.set(0, torch.randn(2, 2)) + assert storage.contains(0) + assert not storage.contains(1) + + def test_compressed_storage_empty(self): + """Test emptying the storage.""" + storage = CompressedListStorage(max_size=10, compression_level=3) + + # Add some data + storage.set(0, torch.randn(2, 2)) + storage.set(1, torch.randn(2, 2)) + assert len(storage) == 2 + + # Empty storage + storage._empty() + assert len(storage) == 0 + + def test_compressed_storage_custom_compression(self): + """Test custom compression functions.""" + + def custom_compress(tensor): + # Simple compression: just convert to uint8 + return tensor.to(torch.uint8) + + def custom_decompress(compressed_tensor, metadata): + # Simple decompression: convert back to original dtype + return compressed_tensor.to(metadata["dtype"]) + + storage = CompressedListStorage( + max_size=10, + compression_fn=custom_compress, + decompression_fn=custom_decompress, + ) + + # Test with tensor + test_tensor = torch.randn(2, 2, dtype=torch.float32) + storage.set(0, test_tensor) + retrieved_tensor = storage.get(0) + + # Note: This will lose precision due to uint8 conversion + # but should still work + assert retrieved_tensor.shape == test_tensor.shape + + def test_compressed_storage_error_handling(self): + """Test error handling for invalid operations.""" + storage = CompressedListStorage(max_size=5, compression_level=3) + + # Test setting data beyond max_size + with pytest.raises(RuntimeError): + storage.set(10, torch.randn(2, 2)) + + # Test getting non-existent data + with pytest.raises(IndexError): + storage.get(0) + + def test_compressed_storage_memory_efficiency(self): + """Test that compression actually reduces memory usage.""" + storage 
= CompressedListStorage(max_size=100, compression_level=3) + + # Create large tensor data + large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64) + large_tensor.copy_( + torch.arange(large_tensor.numel(), dtype=torch.int32).view_as(large_tensor) + // (3 * 84 * 84) + ) + original_size = large_tensor.numel() * large_tensor.element_size() + + # Store in compressed storage + storage.set(0, large_tensor) + + # Estimate compressed size + compressed_data = storage._storage[0] + compressed_size = compressed_data.numel() # uint8 bytes + + # Verify compression ratio is reasonable (at least 2x for random data) + compression_ratio = original_size / compressed_size + assert ( + compression_ratio > 1.5 + ), f"Compression ratio {compression_ratio} is too low" + + +class TestRBLazyInit: + def test_lazy_init(self): + def transform(td): + return td + + rb = ReplayBuffer( + storage=partial(ListStorage), + writer=partial(RoundRobinWriter), + sampler=partial(RandomSampler), + transform_factory=lambda: transform, + ) + assert not rb.initialized + assert not hasattr(rb, "_storage") + assert rb._init_storage is not None + assert not hasattr(rb, "_sampler") + assert rb._init_sampler is not None + assert not hasattr(rb, "_writer") + assert rb._init_writer is not None + rb.extend(TensorDict(batch_size=[2])) + assert rb.initialized + assert rb._storage is not None + assert rb._init_storage is None + assert rb._sampler is not None + assert rb._init_sampler is None + assert rb._writer is not None + assert rb._init_writer is None + + rb = ReplayBuffer( + storage=partial(ListStorage), + writer=partial(RoundRobinWriter), + sampler=partial(RandomSampler), + ) + assert rb.initialized + assert rb._storage is not None + assert rb._init_storage is None + assert rb._sampler is not None + assert rb._init_sampler is None + assert rb._writer is not None + assert rb._init_writer is None + + rb = ReplayBuffer( + storage=partial(ListStorage), + writer=partial(RoundRobinWriter), + 
sampler=partial(RandomSampler), + delayed_init=False, + ) + assert rb.initialized + assert rb._storage is not None + assert rb._init_storage is None + assert rb._sampler is not None + assert rb._init_sampler is None + assert rb._writer is not None + assert rb._init_writer is None + + +@pytest.mark.skipif( + _os_is_windows, reason="Windows file locking prevents cleanup tests" +) +class TestLazyMemmapStorageCleanup: + """Tests for LazyMemmapStorage automatic cleanup functionality.""" + + def test_cleanup_explicit_scratch_dir(self, tmpdir): + """Test that cleanup removes files when scratch_dir is specified.""" + scratch_dir = str(tmpdir / "memmap_storage") + os.makedirs(scratch_dir, exist_ok=True) + + storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) + rb = ReplayBuffer(storage=storage) + rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) + + # Verify files were created + assert os.path.isdir(scratch_dir) + assert len(os.listdir(scratch_dir)) > 0 + + # Cleanup should remove the directory + result = storage.cleanup() + assert result is True + assert not os.path.exists(scratch_dir) + + # Second cleanup should be a no-op + result = storage.cleanup() + assert result is False + + def test_cleanup_temp_dir(self): + """Test cleanup when using default temp directory.""" + storage = LazyMemmapStorage(100, auto_cleanup=True) + rb = ReplayBuffer(storage=storage) + rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) + + # Get the temp directory paths before cleanup + temp_paths = set() + for tensor in storage._storage.values(include_nested=True, leaves_only=True): + try: + if hasattr(tensor, "filename") and tensor.filename: + temp_paths.add(os.path.dirname(tensor.filename)) + except (AttributeError, RuntimeError): + continue + + # Cleanup should remove the files if any were created on disk + result = storage.cleanup() + if len(temp_paths) > 0: + assert result is True + # Paths should no longer exist + for path in temp_paths: + assert not 
os.path.exists(path) + else: + # If no files were created (e.g. anonymous memmap), result should be False + assert result is False + + def test_auto_cleanup_default_behavior(self, tmpdir): + """Test that auto_cleanup defaults correctly based on scratch_dir.""" + # When scratch_dir is None, auto_cleanup should default to True + storage1 = LazyMemmapStorage(100) + assert storage1._auto_cleanup is True + assert storage1._scratch_dir_is_temp is True + + # When scratch_dir is provided, auto_cleanup should default to False + scratch_dir = str(tmpdir / "user_storage") + storage2 = LazyMemmapStorage(100, scratch_dir=scratch_dir) + assert storage2._auto_cleanup is False + assert storage2._scratch_dir_is_temp is False + + # User can override + storage3 = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) + assert storage3._auto_cleanup is True + + storage4 = LazyMemmapStorage(100, auto_cleanup=False) + assert storage4._auto_cleanup is False + + def test_cleanup_idempotent(self, tmpdir): + """Test that cleanup can be called multiple times safely.""" + scratch_dir = str(tmpdir / "memmap_storage") + os.makedirs(scratch_dir, exist_ok=True) + + storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) + rb = ReplayBuffer(storage=storage) + rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) + + # Multiple cleanups should not raise + storage.cleanup() + storage.cleanup() + storage.cleanup() + assert storage._cleaned_up is True + + def test_cleanup_nonexistent_dir(self, tmpdir): + """Test cleanup when directory was already deleted.""" + scratch_dir = str(tmpdir / "memmap_storage") + os.makedirs(scratch_dir, exist_ok=True) + + storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) + rb = ReplayBuffer(storage=storage) + rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) + + # Delete the directory externally + shutil.rmtree(scratch_dir) + assert not os.path.exists(scratch_dir) + + # Cleanup should handle missing 
directory gracefully + result = storage.cleanup() + assert result is False # No cleanup needed since dir is gone + + def test_cleanup_uninitialized_storage(self): + """Test cleanup on storage that was never used.""" + storage = LazyMemmapStorage(100, auto_cleanup=True) + # Storage is not initialized - cleanup should be safe + result = storage.cleanup() + assert result is False + + def test_cleanup_registry(self): + """Test that storages are registered for cleanup.""" + storage = LazyMemmapStorage(100, auto_cleanup=True) + # Check storage is in the registry (avoids race with GC on WeakSet) + assert storage in _MEMMAP_STORAGE_REGISTRY + + # Storage with auto_cleanup=False should not be registered + storage2 = LazyMemmapStorage(100, auto_cleanup=False) + assert storage2 not in _MEMMAP_STORAGE_REGISTRY + # Original storage should still be in the registry + assert storage in _MEMMAP_STORAGE_REGISTRY + + # Cleanup should still work + storage.cleanup() + + def test_cleanup_subprocess(self, tmpdir): + """Test that cleanup works correctly in subprocess scenarios.""" + scratch_dir = str(tmpdir / "subprocess_storage") + + # Create a script that creates a storage and exits normally + script = f""" +import torch +from tensordict import TensorDict +from torchrl.data import ReplayBuffer, LazyMemmapStorage + +storage = LazyMemmapStorage(100, scratch_dir="{scratch_dir}", auto_cleanup=True) +rb = ReplayBuffer(storage=storage) +rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) +print("Storage created") +# Normal exit - atexit handler should clean up +""" + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=30, + ) + + # Script should have succeeded + assert result.returncode == 0, f"Script failed: {result.stderr}" + + # Directory should have been cleaned up on exit + assert not os.path.exists( + scratch_dir + ), f"Directory {scratch_dir} should have been cleaned up" + + def test_cleanup_signal_interrupt(self, tmpdir): + 
"""Test that cleanup happens on SIGINT (Ctrl+C).""" + scratch_dir = str(tmpdir / "signal_storage") + + # Create a script that sleeps and can be interrupted + script = f""" +import signal +import time +import torch +from tensordict import TensorDict +from torchrl.data import ReplayBuffer, LazyMemmapStorage + +storage = LazyMemmapStorage(100, scratch_dir="{scratch_dir}", auto_cleanup=True) +rb = ReplayBuffer(storage=storage) +rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) +print("READY", flush=True) +time.sleep(60) # Will be interrupted +""" + proc = subprocess.Popen( + [sys.executable, "-c", script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Wait for the script to be ready + try: + # Read until we see READY + start = time.time() + while time.time() - start < 10: + line = proc.stdout.readline() + if "READY" in line: + break + else: + proc.kill() + pytest.skip("Script did not start in time") + + # Give it a moment to set up signal handlers + time.sleep(0.5) + + # Verify directory exists + assert os.path.isdir(scratch_dir) + + # Send SIGINT (Ctrl+C) + proc.send_signal(signal.SIGINT) + proc.wait(timeout=5) + + # Directory should have been cleaned up + assert not os.path.exists( + scratch_dir + ), f"Directory {scratch_dir} should have been cleaned up on SIGINT" + finally: + if proc.poll() is None: + proc.kill() + proc.wait() + + def test_cleanup_with_del(self, tmpdir): + """Test that __del__ triggers cleanup.""" + scratch_dir = str(tmpdir / "del_storage") + os.makedirs(scratch_dir, exist_ok=True) + + def create_and_delete(): + storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) + rb = ReplayBuffer(storage=storage) + rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) + # Storage goes out of scope here + + create_and_delete() + + # Force garbage collection + gc.collect() + + # Note: __del__ is not guaranteed to run immediately, but the cleanup + # infrastructure should still work via atexit + + def 
test_cleanup_preserves_user_data_by_default(self, tmpdir): + """Test that user-specified directories are NOT cleaned by default.""" + scratch_dir = str(tmpdir / "user_data") + os.makedirs(scratch_dir, exist_ok=True) + + storage = LazyMemmapStorage(100, scratch_dir=scratch_dir) + rb = ReplayBuffer(storage=storage) + rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) + + # auto_cleanup should be False by default + assert storage._auto_cleanup is False + + # Directory should exist + assert os.path.isdir(scratch_dir) + + # Explicit cleanup should still work + storage.cleanup() + assert not os.path.exists(scratch_dir) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/rb/test_writers.py b/test/rb/test_writers.py new file mode 100644 index 00000000000..aa9cd432302 --- /dev/null +++ b/test/rb/test_writers.py @@ -0,0 +1,375 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import argparse + +import pytest +import torch +from tensordict import TensorDict +from torch import multiprocessing as mp + +from torchrl.data import ( + PrioritizedReplayBuffer, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.replay_buffers import samplers +from torchrl.data.replay_buffers.samplers import ( + PrioritizedSampler, + RandomSampler, + SamplerWithoutReplacement, +) +from torchrl.data.replay_buffers.storages import ( + LazyMemmapStorage, + LazyTensorStorage, + ListStorage, +) +from torchrl.data.replay_buffers.writers import ( + TensorDictMaxValueWriter, + TensorDictRoundRobinWriter, +) +from torchrl.testing import get_default_devices + + +class TestMaxValueWriter: + @pytest.mark.parametrize("size", [20, 25, 30]) + @pytest.mark.parametrize("batch_size", [1, 10, 15]) + @pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_max_value_writer(self, size, batch_size, reward_ranges, device): + torch.manual_seed(0) + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(size, device=device), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key"), + ) + + max_reward1, max_reward2, max_reward3 = reward_ranges + + td = TensorDict( + { + "key": torch.clamp_max(torch.rand(size), max=max_reward1), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward1).all() + assert (0 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) + + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") <= max_reward2).all() + assert (max_reward1 <= 
sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) + + td = TensorDict( + { + "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + + for sample in td: + rb.add(sample) + + sample = rb.sample() + assert (sample.get("key") <= max_reward3).all() + assert (max_reward2 <= sample.get("key")).all() + assert len(sample.get("index").unique()) == len(sample.get("index")) + + # Finally, test the case when no obs should be added + td = TensorDict( + { + "key": torch.zeros(size), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + sample = rb.sample() + assert (sample.get("key") != 0).all() + + @pytest.mark.parametrize("size", [20, 25, 30]) + @pytest.mark.parametrize("batch_size", [1, 10, 15]) + @pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_max_value_writer_serialize( + self, size, batch_size, reward_ranges, device, tmpdir + ): + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(size, device=device), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key"), + ) + + max_reward1, max_reward2, max_reward3 = reward_ranges + + td = TensorDict( + { + "key": torch.clamp_max(torch.rand(size), max=max_reward1), + "obs": torch.rand(size), + }, + batch_size=size, + device=device, + ) + rb.extend(td) + rb.writer.dumps(tmpdir) + # check we can dump twice + rb.writer.dumps(tmpdir) + other = TensorDictMaxValueWriter(rank_key="key") + other.loads(tmpdir) + assert len(rb.writer._current_top_values) == len(other._current_top_values) + torch.testing.assert_close( + torch.tensor(rb.writer._current_top_values), + torch.tensor(other._current_top_values), + ) + + @pytest.mark.parametrize("size", [[], [1], [2, 3]]) + @pytest.mark.parametrize("device", get_default_devices()) + 
@pytest.mark.parametrize("reduction", ["max", "min", "mean", "median", "sum"]) + def test_max_value_writer_reduce(self, size, device, reduction): + torch.manual_seed(0) + batch_size = 4 + rb = TensorDictReplayBuffer( + storage=LazyTensorStorage(1, device=device), + sampler=SamplerWithoutReplacement(), + batch_size=batch_size, + writer=TensorDictMaxValueWriter(rank_key="key", reduction=reduction), + ) + + key = torch.rand(batch_size, *size, device=device) + obs = torch.rand(batch_size, *size, device=device) + td = TensorDict( + {"key": key, "obs": obs}, + batch_size=batch_size, + device=device, + ) + rb.extend(td) + sample = rb.sample() + if reduction == "max": + rank_key = torch.stack([k.max() for k in key.unbind(0)]) + elif reduction == "min": + rank_key = torch.stack([k.min() for k in key.unbind(0)]) + elif reduction == "mean": + rank_key = torch.stack([k.mean() for k in key.unbind(0)]) + elif reduction == "median": + rank_key = torch.stack([k.median() for k in key.unbind(0)]) + elif reduction == "sum": + rank_key = torch.stack([k.sum() for k in key.unbind(0)]) + + top_rank = torch.argmax(rank_key) + assert (sample.get("obs") == obs[top_rank]).all() + + +class TestMultiProc: + @staticmethod + def worker(rb, q0, q1): + td = TensorDict({"a": torch.ones(10), "next": {"reward": torch.ones(10)}}, [10]) + rb.extend(td) + q0.put("extended") + extended = q1.get(timeout=5) + assert extended == "extended" + assert len(rb) == 21, len(rb) + assert (rb["a"][:9] == 2).all() + q0.put("finish") + + @staticmethod + def async_prb_worker(rb, worker_id, q): + td = TensorDict( + { + "obs": torch.full((4, 1), worker_id, dtype=torch.float32), + "prio": {"td_error": torch.linspace(0.1, 1.0, 4) + worker_id}, + }, + [4], + ) + rb.extend(td) + q.put("finish") + + @staticmethod + def async_generic_prb_worker(rb, worker_id, q): + data = TensorDict( + {"obs": torch.full((4, 1), worker_id, dtype=torch.float32)}, + [4], + ) + rb.extend(data) + q.put("finish") + + def exec_multiproc_rb( + self, 
+ storage_type=LazyMemmapStorage, + init=True, + writer_type=TensorDictRoundRobinWriter, + sampler_type=RandomSampler, + device=None, + ): + rb = TensorDictReplayBuffer( + storage=storage_type(21), writer=writer_type(), sampler=sampler_type() + ) + if init: + td = TensorDict( + {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, + [10], + device=device, + ) + rb.extend(td) + q0 = mp.Queue(1) + q1 = mp.Queue(1) + proc = mp.Process(target=self.worker, args=(rb, q0, q1)) + proc.start() + try: + extended = q0.get(timeout=100) + assert extended == "extended" + assert len(rb) == 20 + assert (rb["a"][10:20] == 1).all() + td = TensorDict({"a": torch.zeros(10) + 2}, [10]) + rb.extend(td) + q1.put("extended") + finish = q0.get(timeout=5) + assert finish == "finish" + finally: + proc.join() + + def test_multiproc_rb(self): + return self.exec_multiproc_rb() + + def test_error_list(self): + # list storage cannot be shared + with pytest.raises(RuntimeError, match="Cannot share a storage of type"): + self.exec_multiproc_rb(storage_type=ListStorage) + + def test_error_maxwriter(self): + # TensorDictMaxValueWriter cannot be shared + with pytest.raises(RuntimeError, match="cannot be shared between processes"): + self.exec_multiproc_rb(writer_type=TensorDictMaxValueWriter) + + def test_error_prb(self): + # PrioritizedSampler cannot be shared + if samplers.SumSegmentTreeFp32 is None: + pytest.skip("PrioritizedSampler extension is unavailable.") + with pytest.raises( + RuntimeError, + match="cannot be shared between processes.*sync=False", + ): + self.exec_multiproc_rb( + sampler_type=lambda: PrioritizedSampler(21, alpha=1.1, beta=0.5) + ) + + def test_prioritized_sampler_shared_error_mentions_sync_false(self, monkeypatch): + sampler = PrioritizedSampler.__new__(PrioritizedSampler) + monkeypatch.setattr(samplers, "get_spawning_popen", lambda: object()) + with pytest.raises(RuntimeError, match="sync=False"): + sampler.__getstate__() + + def 
test_shared_prefetch_error_mentions_fix(self): + with pytest.raises( + ValueError, + match="Cannot share prefetched replay buffers.*prefetch=0.*shared=False", + ): + TensorDictReplayBuffer( + storage=LazyTensorStorage(10), + batch_size=2, + prefetch=1, + shared=True, + ) + + def test_async_prioritized_rb_multiproc_writes(self): + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + priority_key=("prio", "td_error"), + storage=LazyMemmapStorage(32, shared_init=True), + batch_size=4, + shared=True, + sync=False, + ) + q = mp.Queue() + processes = [] + for worker_id in range(2): + proc = mp.Process( + target=self.async_prb_worker, + args=(rb, worker_id, q), + ) + processes.append(proc) + proc.start() + + for proc in processes: + proc.join() + assert proc.exitcode == 0 + assert q.get(timeout=5) == "finish" + + assert rb.write_count == 8 + sample = rb.sample() + assert rb._prioritized_sampler_write_count == 8 + assert sample["obs"].shape == (4, 1) + assert "priority_weight" in sample.keys() + assert "index" in sample.keys() + + sample["prio", "td_error"] = torch.ones(sample.shape) * 10 + rb.update_tensordict_priority(sample) + assert rb.prioritized_sampler._max_priority[0] is not None + + def test_async_generic_prioritized_rb_multiproc_writes(self): + rb = PrioritizedReplayBuffer( + alpha=0.7, + beta=0.5, + storage=LazyMemmapStorage(32), + batch_size=4, + sync=False, + ) + rb.extend(TensorDict({"obs": torch.zeros((1, 1))}, [1])) + rb.empty() + rb.share(True) + q = mp.Queue() + processes = [] + for worker_id in range(2): + proc = mp.Process( + target=self.async_generic_prb_worker, + args=(rb, worker_id, q), + ) + processes.append(proc) + proc.start() + + for proc in processes: + proc.join() + assert proc.exitcode == 0 + assert q.get(timeout=5) == "finish" + + assert rb.write_count == 8 + sample, info = rb.sample(return_info=True) + assert rb._prioritized_sampler_write_count == 8 + assert sample["obs"].shape == (4, 1) + assert "priority_weight" in info + 
assert "index" in info + + rb.update_priority(info["index"], torch.ones(4) * 10) + assert rb.prioritized_sampler._max_priority[0] is not None + + def test_error_noninit(self): + # a storage that has not been initialized cannot be shared + with pytest.raises(RuntimeError, match="it has not been initialized yet"): + self.exec_multiproc_rb(init=False) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_modules.py b/test/test_modules.py deleted file mode 100644 index d047221e245..00000000000 --- a/test/test_modules.py +++ /dev/null @@ -1,1753 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import argparse -import re - -from numbers import Number - -import numpy as np -import pytest -import torch -from packaging import version -from tensordict import TensorDict -from torch import nn -from torchrl.data.tensor_specs import Bounded, Composite -from torchrl.modules import ( - CEMPlanner, - DiffusionActor, - DTActor, - GRU, - GRUCell, - LSTM, - LSTMCell, - MultiAgentConvNet, - MultiAgentMLP, - OnlineDTActor, - QMixer, - SafeModule, - TanhModule, - ValueOperator, - VDNMixer, -) -from torchrl.modules.distributions.utils import safeatanh, safetanh -from torchrl.modules.models import ( - BatchRenorm1d, - Conv3dNet, - ConvNet, - MLP, - NoisyLazyLinear, - NoisyLinear, -) -from torchrl.modules.models.decision_transformer import ( - _has_transformers, - DecisionTransformer, -) -from torchrl.modules.models.model_based import ( - DreamerActor, - ObsDecoder, - ObsEncoder, - RSSMPosterior, - RSSMPrior, - RSSMRollout, -) -from torchrl.modules.models.multiagent import MultiAgentNetBase -from torchrl.modules.models.recipes.impala import _ConvNetBlock -from torchrl.modules.models.utils import SquashDims
-from torchrl.modules.planners.mppi import MPPIPlanner -from torchrl.objectives.value import TDLambdaEstimator - -from torchrl.testing import get_default_devices, retry - -from torchrl.testing.mocking_classes import MockBatchedUnLockedEnv - - -@pytest.fixture -def double_prec_fixture(): - dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.double) - yield - torch.set_default_dtype(dtype) - - -class TestMLP: - @pytest.mark.parametrize("in_features", [3, 10, None]) - @pytest.mark.parametrize("out_features", [3, (3, 10)]) - @pytest.mark.parametrize("depth, num_cells", [(3, 32), (None, (32, 32, 32))]) - @pytest.mark.parametrize( - "activation_class, activation_kwargs", - [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], - ) - @pytest.mark.parametrize( - "norm_class, norm_kwargs", - [ - (nn.LazyBatchNorm1d, {}), - (nn.BatchNorm1d, {"num_features": 32}), - (nn.LayerNorm, {"normalized_shape": 32}), - ], - ) - @pytest.mark.parametrize("dropout", [0.0, 0.5]) - @pytest.mark.parametrize("bias_last_layer", [True, False]) - @pytest.mark.parametrize("single_bias_last_layer", [True, False]) - @pytest.mark.parametrize("layer_class", [nn.Linear, NoisyLinear]) - @pytest.mark.parametrize("device", get_default_devices()) - def test_mlp( - self, - in_features, - out_features, - depth, - num_cells, - activation_class, - activation_kwargs, - dropout, - bias_last_layer, - norm_class, - norm_kwargs, - single_bias_last_layer, - layer_class, - device, - seed=0, - ): - torch.manual_seed(seed) - batch = 2 - mlp = MLP( - in_features=in_features, - out_features=out_features, - depth=depth, - num_cells=num_cells, - activation_class=activation_class, - activation_kwargs=activation_kwargs, - norm_class=norm_class, - norm_kwargs=norm_kwargs, - dropout=dropout, - bias_last_layer=bias_last_layer, - single_bias_last_layer=False, - layer_class=layer_class, - device=device, - ) - if in_features is None: - in_features = 5 - x = torch.randn(batch, in_features, device=device) - y 
= mlp(x) - out_features = ( - [out_features] if isinstance(out_features, Number) else out_features - ) - assert y.shape == torch.Size([batch, *out_features]) - - def test_kwargs(self): - def make_activation(shift): - return lambda x: x + shift - - def layer(*args, **kwargs): - linear = nn.Linear(*args, **kwargs) - linear.weight.data.copy_(torch.eye(4)) - return linear - - in_features = 4 - out_features = 4 - num_cells = [4, 4, 4] - mlp = MLP( - in_features=in_features, - out_features=out_features, - num_cells=num_cells, - activation_class=make_activation, - activation_kwargs=[{"shift": 0}, {"shift": 1}, {"shift": 2}], - layer_class=layer, - layer_kwargs=[{"bias": False}] * 4, - bias_last_layer=False, - ) - x = torch.zeros(4) - y = mlp(x) - for i, module in enumerate(mlp.modules()): - if isinstance(module, nn.Linear): - assert (module.weight == torch.eye(4)).all(), i - assert module.bias is None, i - assert (y == 3).all() - - -@pytest.mark.parametrize("in_features", [3, 10, None]) -@pytest.mark.parametrize( - "input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features", - [(100, None, None, 3, 1, 0, 32 * 94 * 94), (100, 3, 32, 3, 1, 1, 32 * 100 * 100)], -) -@pytest.mark.parametrize( - "activation_class, activation_kwargs", - [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], -) -@pytest.mark.parametrize( - "norm_class, norm_kwargs", - [(None, None), (nn.LazyBatchNorm2d, {}), (nn.BatchNorm2d, {"num_features": 32})], -) -@pytest.mark.parametrize("bias_last_layer", [True, False]) -@pytest.mark.parametrize( - "aggregator_class, aggregator_kwargs", - [(SquashDims, {})], -) -@pytest.mark.parametrize("squeeze_output", [False]) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("batch", [(2,), (2, 2)]) -def test_convnet( - batch, - in_features, - depth, - num_cells, - kernel_sizes, - strides, - paddings, - activation_class, - activation_kwargs, - norm_class, - norm_kwargs, - bias_last_layer, - 
aggregator_class, - aggregator_kwargs, - squeeze_output, - device, - input_size, - expected_features, - seed=0, -): - torch.manual_seed(seed) - convnet = ConvNet( - in_features=in_features, - depth=depth, - num_cells=num_cells, - kernel_sizes=kernel_sizes, - strides=strides, - paddings=paddings, - activation_class=activation_class, - activation_kwargs=activation_kwargs, - norm_class=norm_class, - norm_kwargs=norm_kwargs, - bias_last_layer=bias_last_layer, - aggregator_class=aggregator_class, - aggregator_kwargs=aggregator_kwargs, - squeeze_output=squeeze_output, - device=device, - ) - if in_features is None: - in_features = 5 - x = torch.randn(*batch, in_features, input_size, input_size, device=device) - y = convnet(x) - assert y.shape == torch.Size([*batch, expected_features]) - - -class TestConv3d: - @pytest.mark.parametrize("in_features", [3, 10, None]) - @pytest.mark.parametrize( - "input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features", - [ - (10, None, None, 3, 1, 0, 32 * 4 * 4 * 4), - (10, 3, 32, 3, 1, 1, 32 * 10 * 10 * 10), - ], - ) - @pytest.mark.parametrize( - "activation_class, activation_kwargs", - [(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})], - ) - @pytest.mark.parametrize( - "norm_class, norm_kwargs", - [ - (None, None), - (nn.LazyBatchNorm3d, {}), - (nn.BatchNorm3d, {"num_features": 32}), - ], - ) - @pytest.mark.parametrize("bias_last_layer", [True, False]) - @pytest.mark.parametrize( - "aggregator_class, aggregator_kwargs", - [(SquashDims, None)], - ) - @pytest.mark.parametrize("squeeze_output", [False]) - @pytest.mark.parametrize("device", get_default_devices()) - @pytest.mark.parametrize("batch", [(2,), (2, 2)]) - def test_conv3dnet( - self, - batch, - in_features, - depth, - num_cells, - kernel_sizes, - strides, - paddings, - activation_class, - activation_kwargs, - norm_class, - norm_kwargs, - bias_last_layer, - aggregator_class, - aggregator_kwargs, - squeeze_output, - device, - input_size, - 
expected_features, - seed=0, - ): - torch.manual_seed(seed) - conv3dnet = Conv3dNet( - in_features=in_features, - depth=depth, - num_cells=num_cells, - kernel_sizes=kernel_sizes, - strides=strides, - paddings=paddings, - activation_class=activation_class, - activation_kwargs=activation_kwargs, - norm_class=norm_class, - norm_kwargs=norm_kwargs, - bias_last_layer=bias_last_layer, - aggregator_class=aggregator_class, - aggregator_kwargs=aggregator_kwargs, - squeeze_output=squeeze_output, - device=device, - ) - if in_features is None: - in_features = 5 - x = torch.randn( - *batch, in_features, input_size, input_size, input_size, device=device - ) - y = conv3dnet(x) - assert y.shape == torch.Size([*batch, expected_features]) - with pytest.raises(ValueError, match="must have at least 4 dimensions"): - conv3dnet(torch.randn(3, 16, 16)) - - def test_errors(self): - with pytest.raises( - ValueError, match="Null depth is not permitted with Conv3dNet" - ): - conv3dnet = Conv3dNet( - in_features=5, - num_cells=32, - depth=0, - ) - with pytest.raises( - ValueError, match="depth=None requires one of the input args" - ): - conv3dnet = Conv3dNet( - in_features=5, - num_cells=32, - depth=None, - ) - with pytest.raises( - ValueError, match="consider matching or specifying a constant num_cells" - ): - conv3dnet = Conv3dNet( - in_features=5, - num_cells=[32], - depth=None, - kernel_sizes=[3, 3], - ) - - -@pytest.mark.parametrize( - "layer_class", - [ - NoisyLinear, - NoisyLazyLinear, - ], -) -@pytest.mark.parametrize("device", get_default_devices()) -def test_noisy(layer_class, device, seed=0): - torch.manual_seed(seed) - layer = layer_class(3, 4, device=device) - x = torch.randn(10, 3, device=device) - y1 = layer(x) - layer.reset_noise() - y2 = layer(x) - y3 = layer(x) - torch.testing.assert_close(y2, y3) - with pytest.raises(AssertionError): - torch.testing.assert_close(y1, y2) - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("batch_size", 
[3, 5]) -class TestPlanner: - def test_CEM_model_free_env(self, device, batch_size, seed=1): - env = MockBatchedUnLockedEnv(device=device) - torch.manual_seed(seed) - planner = CEMPlanner( - env, - planning_horizon=10, - optim_steps=2, - num_candidates=100, - top_k=2, - ) - td = env.reset(TensorDict(batch_size=batch_size).to(device)) - td_copy = td.clone() - td = planner(td) - assert ( - td.get("action").shape[-len(env.action_spec.shape) :] - == env.action_spec.shape - ) - assert env.action_spec.is_in(td.get("action")) - - for key in td.keys(): - if key != "action": - assert torch.allclose(td[key], td_copy[key]) - - def test_MPPI(self, device, batch_size, seed=1): - torch.manual_seed(seed) - env = MockBatchedUnLockedEnv(device=device) - value_net = nn.LazyLinear(1, device=device) - value_net = ValueOperator(value_net, in_keys=["observation"]) - advantage_module = TDLambdaEstimator( - gamma=0.99, - lmbda=0.95, - value_network=value_net, - ) - value_net(env.reset()) - planner = MPPIPlanner( - env, - advantage_module, - temperature=1.0, - planning_horizon=10, - optim_steps=2, - num_candidates=100, - top_k=2, - ) - td = env.reset(TensorDict(batch_size=batch_size).to(device)) - td_copy = td.clone() - td = planner(td) - assert ( - td.get("action").shape[-len(env.action_spec.shape) :] - == env.action_spec.shape - ) - assert env.action_spec.is_in(td.get("action")) - - for key in td.keys(): - if key != "action": - assert torch.allclose(td[key], td_copy[key]) - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("batch_size", [[], [3], [5]]) -@pytest.mark.skipif( - version.parse(torch.__version__) < version.parse("1.11.0"), - reason="""Dreamer works with batches of null to 2 dimensions. Torch < 1.11 -requires one-dimensional batches (for RNN and Conv nets for instance). 
If you'd like -to see torch < 1.11 supported for dreamer, please submit an issue.""", -) -class TestDreamerComponents: - @pytest.mark.parametrize("out_features", [3, 5]) - @pytest.mark.parametrize("temporal_size", [[], [2], [4]]) - def test_dreamer_actor(self, device, batch_size, temporal_size, out_features): - actor = DreamerActor( - out_features, - ).to(device) - emb = torch.randn(*batch_size, *temporal_size, 15, device=device) - state = torch.randn(*batch_size, *temporal_size, 2, device=device) - loc, scale = actor(emb, state) - assert loc.shape == (*batch_size, *temporal_size, out_features) - assert scale.shape == (*batch_size, *temporal_size, out_features) - assert torch.all(scale > 0) - - @pytest.mark.parametrize("depth", [32, 64]) - @pytest.mark.parametrize("temporal_size", [[], [2], [4]]) - def test_dreamer_encoder(self, device, temporal_size, batch_size, depth): - encoder = ObsEncoder(channels=depth).to(device) - obs = torch.randn(*batch_size, *temporal_size, 3, 64, 64, device=device) - emb = encoder(obs) - assert emb.shape == (*batch_size, *temporal_size, depth * 8 * 4) - - @pytest.mark.parametrize("depth", [32, 64]) - @pytest.mark.parametrize("stoch_size", [10, 20]) - @pytest.mark.parametrize("deter_size", [20, 30]) - @pytest.mark.parametrize("temporal_size", [[], [2], [4]]) - def test_dreamer_decoder( - self, device, batch_size, temporal_size, depth, stoch_size, deter_size - ): - decoder = ObsDecoder(channels=depth).to(device) - stoch_state = torch.randn( - *batch_size, *temporal_size, stoch_size, device=device - ) - det_state = torch.randn(*batch_size, *temporal_size, deter_size, device=device) - obs = decoder(stoch_state, det_state) - assert obs.shape == (*batch_size, *temporal_size, 3, 64, 64) - - @pytest.mark.parametrize("depth", [32, 64]) - @pytest.mark.parametrize("out_channels", [1, 3]) - @pytest.mark.parametrize("stoch_size", [10]) - @pytest.mark.parametrize("deter_size", [20]) - def test_dreamer_decoder_out_channels( - self, device, batch_size, 
depth, out_channels, stoch_size, deter_size - ): - decoder = ObsDecoder(channels=depth, out_channels=out_channels).to(device) - stoch_state = torch.randn(*batch_size, stoch_size, device=device) - det_state = torch.randn(*batch_size, deter_size, device=device) - obs = decoder(stoch_state, det_state) - assert obs.shape == (*batch_size, out_channels, 64, 64) - - @pytest.mark.parametrize("stoch_size", [10, 20]) - @pytest.mark.parametrize("deter_size", [20, 30]) - @pytest.mark.parametrize("action_size", [3, 6]) - def test_rssm_prior(self, device, batch_size, stoch_size, deter_size, action_size): - action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) - rssm_prior = RSSMPrior( - action_spec, - hidden_dim=stoch_size, - rnn_hidden_dim=stoch_size, - state_dim=deter_size, - ).to(device) - state = torch.randn(*batch_size, deter_size, device=device) - action = torch.randn(*batch_size, action_size, device=device) - belief = torch.randn(*batch_size, stoch_size, device=device) - prior_mean, prior_std, next_state, belief = rssm_prior(state, belief, action) - assert prior_mean.shape == (*batch_size, deter_size) - assert prior_std.shape == (*batch_size, deter_size) - assert next_state.shape == (*batch_size, deter_size) - assert belief.shape == (*batch_size, stoch_size) - assert torch.all(prior_std > 0) - - @pytest.mark.parametrize("stoch_size", [10, 20]) - @pytest.mark.parametrize("deter_size", [20, 30]) - def test_rssm_posterior(self, device, batch_size, stoch_size, deter_size): - rssm_posterior = RSSMPosterior( - hidden_dim=stoch_size, - state_dim=deter_size, - ).to(device) - belief = torch.randn(*batch_size, stoch_size, device=device) - obs_emb = torch.randn(*batch_size, 1024, device=device) - # Init of lazy linears - _ = rssm_posterior(belief.clone(), obs_emb.clone()) - - torch.manual_seed(0) - posterior_mean, posterior_std, next_state = rssm_posterior( - belief.clone(), obs_emb.clone() - ) - assert posterior_mean.shape == (*batch_size, deter_size) - 
assert posterior_std.shape == (*batch_size, deter_size) - assert next_state.shape == (*batch_size, deter_size) - assert torch.all(posterior_std > 0) - - torch.manual_seed(0) - posterior_mean_bis, posterior_std_bis, next_state_bis = rssm_posterior( - belief.clone(), obs_emb.clone() - ) - assert torch.allclose(posterior_mean, posterior_mean_bis) - assert torch.allclose(posterior_std, posterior_std_bis) - assert torch.allclose(next_state, next_state_bis) - - @pytest.mark.parametrize("stoch_size", [10, 20]) - @pytest.mark.parametrize("deter_size", [20, 30]) - @pytest.mark.parametrize("temporal_size", [2, 4]) - @pytest.mark.parametrize("action_size", [3, 6]) - def test_rssm_rollout( - self, device, batch_size, temporal_size, stoch_size, deter_size, action_size - ): - action_spec = Bounded(shape=(action_size,), dtype=torch.float32, low=-1, high=1) - rssm_prior = RSSMPrior( - action_spec, - hidden_dim=stoch_size, - rnn_hidden_dim=stoch_size, - state_dim=deter_size, - ).to(device) - rssm_posterior = RSSMPosterior( - hidden_dim=stoch_size, - state_dim=deter_size, - ).to(device) - - rssm_rollout = RSSMRollout( - SafeModule( - rssm_prior, - in_keys=["state", "belief", "action"], - out_keys=[ - ("next", "prior_mean"), - ("next", "prior_std"), - "_", - ("next", "belief"), - ], - ), - SafeModule( - rssm_posterior, - in_keys=[("next", "belief"), ("next", "encoded_latents")], - out_keys=[ - ("next", "posterior_mean"), - ("next", "posterior_std"), - ("next", "state"), - ], - ), - ) - - state = torch.randn(*batch_size, temporal_size, deter_size, device=device) - belief = torch.randn(*batch_size, temporal_size, stoch_size, device=device) - action = torch.randn(*batch_size, temporal_size, action_size, device=device) - obs_emb = torch.randn(*batch_size, temporal_size, 1024, device=device) - - tensordict = TensorDict( - { - "state": state.clone(), - "action": action.clone(), - "next": { - "encoded_latents": obs_emb.clone(), - "belief": belief.clone(), - }, - }, - device=device, - 
batch_size=torch.Size([*batch_size, temporal_size]), - ) - ## Init of lazy linears - _ = rssm_rollout(tensordict.clone()) - torch.manual_seed(0) - rollout = rssm_rollout(tensordict) - assert rollout["next", "prior_mean"].shape == ( - *batch_size, - temporal_size, - deter_size, - ) - assert rollout["next", "prior_std"].shape == ( - *batch_size, - temporal_size, - deter_size, - ) - assert rollout["next", "state"].shape == ( - *batch_size, - temporal_size, - deter_size, - ) - assert rollout["next", "belief"].shape == ( - *batch_size, - temporal_size, - stoch_size, - ) - assert rollout["next", "posterior_mean"].shape == ( - *batch_size, - temporal_size, - deter_size, - ) - assert rollout["next", "posterior_std"].shape == ( - *batch_size, - temporal_size, - deter_size, - ) - assert torch.all(rollout["next", "prior_std"] > 0) - assert torch.all(rollout["next", "posterior_std"] > 0) - - state[..., 1:, :] = 0 - belief[..., 1:, :] = 0 - # Only the first state is used for the prior. The rest are recomputed - - tensordict_bis = TensorDict( - { - "state": state.clone(), - "action": action.clone(), - "next": {"encoded_latents": obs_emb.clone(), "belief": belief.clone()}, - }, - device=device, - batch_size=torch.Size([*batch_size, temporal_size]), - ) - torch.manual_seed(0) - rollout_bis = rssm_rollout(tensordict_bis) - - assert torch.allclose( - rollout["next", "prior_mean"], rollout_bis["next", "prior_mean"] - ), (rollout["next", "prior_mean"] - rollout_bis["next", "prior_mean"]).norm() - assert torch.allclose( - rollout["next", "prior_std"], rollout_bis["next", "prior_std"] - ) - assert torch.allclose(rollout["next", "state"], rollout_bis["next", "state"]) - assert torch.allclose(rollout["next", "belief"], rollout_bis["next", "belief"]) - assert torch.allclose( - rollout["next", "posterior_mean"], rollout_bis["next", "posterior_mean"] - ) - assert torch.allclose( - rollout["next", "posterior_std"], rollout_bis["next", "posterior_std"] - ) - - -class TestTanh: - def 
test_errors(self): - with pytest.raises( - ValueError, match="in_keys and out_keys should have the same length" - ): - TanhModule(in_keys=["a", "b"], out_keys=["a"]) - with pytest.raises(ValueError, match=r"The minimum value \(-2\) provided"): - spec = Bounded(-1, 1, shape=()) - TanhModule(in_keys=["act"], low=-2, spec=spec) - with pytest.raises(ValueError, match=r"The maximum value \(-2\) provided to"): - spec = Bounded(-1, 1, shape=()) - TanhModule(in_keys=["act"], high=-2, spec=spec) - with pytest.raises(ValueError, match="Got high < low"): - TanhModule(in_keys=["act"], high=-2, low=-1) - - def test_minmax(self): - mod = TanhModule( - in_keys=["act"], - high=2, - ) - assert isinstance(mod.act_high, torch.Tensor) - mod = TanhModule( - in_keys=["act"], - low=-2, - ) - assert isinstance(mod.act_low, torch.Tensor) - mod = TanhModule( - in_keys=["act"], - high=np.ones((1,)), - ) - assert isinstance(mod.act_high, torch.Tensor) - mod = TanhModule( - in_keys=["act"], - low=-np.ones((1,)), - ) - assert isinstance(mod.act_low, torch.Tensor) - - @pytest.mark.parametrize("clamp", [True, False]) - def test_boundaries(self, clamp): - torch.manual_seed(0) - eps = torch.finfo(torch.float).resolution - for _ in range(10): - min, max = (5 * torch.randn(2)).sort()[0] - mod = TanhModule(in_keys=["act"], low=min, high=max, clamp=clamp) - assert mod.non_trivial - td = TensorDict({"act": (2 * torch.rand(100) - 1) * 10}, []) - mod(td) - # we should have a good proportion of samples close to the boundaries - assert torch.isclose(td["act"], max).any() - assert torch.isclose(td["act"], min).any() - if not clamp: - assert (td["act"] <= max + eps).all() - assert (td["act"] >= min - eps).all() - else: - assert (td["act"] < max + eps).all() - assert (td["act"] > min - eps).all() - - @pytest.mark.parametrize("out_keys", [[("a", "c"), "b"], None]) - @pytest.mark.parametrize("has_spec", [[True, True], [True, False], [False, False]]) - def test_multi_inputs(self, out_keys, has_spec): - in_keys = 
[("x", "z"), "y"] - real_out_keys = out_keys if out_keys is not None else in_keys - - if any(has_spec): - spec = {} - if has_spec[0]: - spec.update({real_out_keys[0]: Bounded(-2.0, 2.0, shape=())}) - low, high = -2.0, 2.0 - if has_spec[1]: - spec.update({real_out_keys[1]: Bounded(-3.0, 3.0, shape=())}) - low, high = None, None - spec = Composite(spec) - else: - spec = None - low, high = -2.0, 2.0 - - mod = TanhModule( - in_keys=in_keys, - out_keys=out_keys, - low=low, - high=high, - spec=spec, - clamp=False, - ) - data = TensorDict({in_key: torch.randn(100) * 100 for in_key in in_keys}, []) - mod(data) - assert all(out_key in data.keys(True, True) for out_key in real_out_keys) - eps = torch.finfo(torch.float).resolution - - for out_key in real_out_keys: - key = out_key if isinstance(out_key, str) else "_".join(out_key) - low_key = f"{key}_low" - high_key = f"{key}_high" - min, max = getattr(mod, low_key), getattr(mod, high_key) - assert torch.isclose(data[out_key], max).any() - assert torch.isclose(data[out_key], min).any() - assert (data[out_key] <= max + eps).all() - assert (data[out_key] >= min - eps).all() - - -class TestMultiAgent: - def _get_mock_input_td( - self, n_agents, n_agents_inputs, state_shape=(64, 64, 3), T=None, batch=(2,) - ): - if T is not None: - batch = batch + (T,) - obs = torch.randn(*batch, n_agents, n_agents_inputs) - state = torch.randn(*batch, *state_shape) - - td = TensorDict( - { - "agents": TensorDict( - {"observation": obs}, - [*batch, n_agents], - ), - "state": state, - }, - batch_size=batch, - ) - return td - - @retry(AssertionError, 5) - @pytest.mark.parametrize("n_agents", [1, 3]) - @pytest.mark.parametrize("share_params", [True, False]) - @pytest.mark.parametrize("centralized", [True, False]) - @pytest.mark.parametrize("n_agent_inputs", [6, None]) - @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) - def test_multiagent_mlp( - self, - n_agents, - centralized, - share_params, - batch, - n_agent_inputs, - n_agent_outputs=2, - 
): - torch.manual_seed(1) - mlp = MultiAgentMLP( - n_agent_inputs=n_agent_inputs, - n_agent_outputs=n_agent_outputs, - n_agents=n_agents, - centralized=centralized, - share_params=share_params, - depth=2, - ) - if n_agent_inputs is None: - n_agent_inputs = 6 - td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) - obs = td.get(("agents", "observation")) - - out = mlp(obs) - assert out.shape == (*batch, n_agents, n_agent_outputs) - for i in range(n_agents): - if centralized and share_params: - assert torch.allclose(out[..., i, :], out[..., 0, :]) - else: - for j in range(i + 1, n_agents): - assert not torch.allclose(out[..., i, :], out[..., j, :]) - - obs[..., 0, 0] += 1 - out2 = mlp(obs) - for i in range(n_agents): - if centralized: - # a modification to the input of agent 0 will impact all agents - assert not torch.allclose(out[..., i, :], out2[..., i, :]) - elif i > 0: - assert torch.allclose(out[..., i, :], out2[..., i, :]) - - obs = ( - torch.randn(*batch, 1, n_agent_inputs) - .expand(*batch, n_agents, n_agent_inputs) - .clone() - ) - out = mlp(obs) - for i in range(n_agents): - if share_params: - # same input same output - assert torch.allclose(out[..., i, :], out[..., 0, :]) - else: - for j in range(i + 1, n_agents): - # same input different output - assert not torch.allclose(out[..., i, :], out[..., j, :]) - pattern = rf"""MultiAgentMLP\( - MLP\( - \(0\): Linear\(in_features=\d+, out_features=32, bias=True\) - \(1\): Tanh\(\) - \(2\): Linear\(in_features=32, out_features=32, bias=True\) - \(3\): Tanh\(\) - \(4\): Linear\(in_features=32, out_features=2, bias=True\) - \), - n_agents={n_agents}, - share_params={share_params}, - centralized={centralized}, - agent_dim={-2}\)""" - assert re.match(pattern, str(mlp), re.DOTALL) - - @retry(AssertionError, 5) - @pytest.mark.parametrize("n_agents", [1, 3]) - @pytest.mark.parametrize("share_params", [True, False]) - @pytest.mark.parametrize("centralized", [True, False]) - 
@pytest.mark.parametrize("n_agent_inputs", [6, None]) - @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) - def test_multiagent_mlp_init( - self, - n_agents, - centralized, - share_params, - batch, - n_agent_inputs, - n_agent_outputs=2, - ): - torch.manual_seed(1) - mlp = MultiAgentMLP( - n_agent_inputs=n_agent_inputs, - n_agent_outputs=n_agent_outputs, - n_agents=n_agents, - centralized=centralized, - share_params=share_params, - depth=2, - ) - for m in mlp.modules(): - if isinstance(m, nn.Linear): - assert not isinstance(m.weight, nn.Parameter) - assert m.weight.device == torch.device("meta") - break - else: - raise RuntimeError("could not find a Linear module") - if n_agent_inputs is None: - n_agent_inputs = 6 - td = self._get_mock_input_td(n_agents, n_agent_inputs, batch=batch) - obs = td.get(("agents", "observation")) - mlp(obs) - snet = mlp.get_stateful_net() - assert snet is not mlp._empty_net - - def zero_inplace(mod): - if hasattr(mod, "weight"): - mod.weight.data *= 0 - if hasattr(mod, "bias"): - mod.bias.data *= 0 - - snet.apply(zero_inplace) - assert (mlp.params == 0).all() - - def one_outofplace(mod): - if hasattr(mod, "weight"): - mod.weight = nn.Parameter(torch.ones_like(mod.weight.data)) - if hasattr(mod, "bias"): - mod.bias = nn.Parameter(torch.ones_like(mod.bias.data)) - - snet.apply(one_outofplace) - assert (mlp.params == 0).all() - mlp.from_stateful_net(snet) - assert (mlp.params == 1).all() - - @retry(AssertionError, 5) - @pytest.mark.parametrize("n_agents", [3]) - @pytest.mark.parametrize("share_params", [True]) - @pytest.mark.parametrize("centralized", [True]) - @pytest.mark.parametrize("n_agent_inputs", [6]) - @pytest.mark.parametrize("batch", [(4,)]) - @pytest.mark.parametrize("tdparams", [True, False]) - def test_multiagent_mlp_tdparams( - self, - n_agents, - centralized, - share_params, - batch, - n_agent_inputs, - tdparams, - n_agent_outputs=2, - ): - torch.manual_seed(1) - mlp = MultiAgentMLP( - n_agent_inputs=n_agent_inputs, - 
n_agent_outputs=n_agent_outputs, - n_agents=n_agents, - centralized=centralized, - share_params=share_params, - depth=2, - use_td_params=tdparams, - ) - if tdparams: - assert list(mlp._empty_net.parameters()) == [] - assert list(mlp.params.parameters()) == list(mlp.parameters()) - else: - assert list(mlp._empty_net.parameters()) == list(mlp.parameters()) - assert not hasattr(mlp.params, "parameters") - if torch.backends.mps.is_available(): - device = torch.device("mps") - elif torch.cuda.is_available(): - device = torch.device("cuda") - else: - return - mlp = nn.Sequential(mlp) - mlp.to(device) - param_set = set(mlp.parameters()) - for p in mlp[0].params.values(True, True): - assert p in param_set - - def test_multiagent_mlp_lazy(self): - torch.manual_seed(0) - mlp = MultiAgentMLP( - n_agent_inputs=None, - n_agent_outputs=6, - n_agents=3, - centralized=True, - share_params=False, - depth=2, - ) - optim = torch.optim.SGD(mlp.parameters(), lr=1e-3) - for p in mlp.parameters(): - if isinstance(p, torch.nn.parameter.UninitializedParameter): - break - else: - raise AssertionError("No UninitializedParameter found") - for p in optim.param_groups[0]["params"]: - if isinstance(p, torch.nn.parameter.UninitializedParameter): - break - else: - raise AssertionError("No UninitializedParameter found") - for _ in range(2): - td = self._get_mock_input_td(3, 4, batch=(10,)) - obs = td.get(("agents", "observation")) - out = mlp(obs) - assert ( - not mlp.params[0] - .apply(lambda x, y: torch.isclose(x, y), mlp.params[1]) - .any() - ) - out.mean().backward() - optim.step() - for p in mlp.parameters(): - if isinstance(p, torch.nn.parameter.UninitializedParameter): - raise AssertionError("UninitializedParameter found") - for p in optim.param_groups[0]["params"]: - if isinstance(p, torch.nn.parameter.UninitializedParameter): - raise AssertionError("UninitializedParameter found") - - @pytest.mark.parametrize("n_agents", [1, 3]) - @pytest.mark.parametrize("share_params", [True, False]) - 
@pytest.mark.parametrize("centralized", [True, False]) - def test_multiagent_reset_mlp( - self, - n_agents, - centralized, - share_params, - ): - actor_net = MultiAgentMLP( - n_agent_inputs=4, - n_agent_outputs=6, - num_cells=(4, 4), - n_agents=n_agents, - centralized=centralized, - share_params=share_params, - ) - params_before = actor_net.params.clone() - actor_net.reset_parameters() - params_after = actor_net.params - assert not params_before.apply( - lambda x, y: torch.isclose(x, y), params_after, batch_size=[] - ).any() - if params_after.numel() > 1: - assert ( - not params_after[0] - .apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[]) - .any() - ) - - @pytest.mark.parametrize("share_params", [True, False]) - @pytest.mark.parametrize("agent_dim", [1, -3]) - def test_multiagent_custom_agent_dim(self, share_params, agent_dim): - """Test that custom agent_dim values work correctly. - - Regression test for https://github.com/pytorch/rl/issues/3288 - """ - n_agents = 3 - obs_dim = 5 - seq_len = 6 - output_dim = 4 - - class SingleAgentMLP(nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - self.net = nn.Sequential( - nn.Linear(in_dim, 32), - nn.Tanh(), - nn.Linear(32, out_dim), - ) - - def forward(self, x): - return self.net(x) - - class MultiAgentPolicyNet(MultiAgentNetBase): - def __init__( - self, - obs_dim, - output_dim, - n_agents, - share_params, - agent_dim, - device=None, - ): - self.obs_dim = obs_dim - self.output_dim = output_dim - self._agent_dim = agent_dim - - super().__init__( - n_agents=n_agents, - centralized=False, - share_params=share_params, - agent_dim=agent_dim, - device=device, - ) - - def _build_single_net(self, *, device, **kwargs): - net = SingleAgentMLP(self.obs_dim, self.output_dim) - return net.to(device) if device is not None else net - - def _pre_forward_check(self, inputs): - if inputs.shape[self._agent_dim] != self.n_agents: - raise ValueError( - f"Multi-agent network expected input with 
shape[{self._agent_dim}]={self.n_agents}," - f" but got {inputs.shape}" - ) - return inputs - - policy_net = MultiAgentPolicyNet( - obs_dim=obs_dim, - output_dim=output_dim, - n_agents=n_agents, - share_params=share_params, - agent_dim=agent_dim, - ) - - # Input shape: (batch, n_agents, seq_len, obs_dim) with agents at dim 1 - batch_size = 4 - obs = torch.randn(batch_size, n_agents, seq_len, obs_dim) - out = policy_net(obs) - - # Output should preserve agent dimension position - expected_shape = (batch_size, n_agents, seq_len, output_dim) - assert ( - out.shape == expected_shape - ), f"Expected {expected_shape}, got {out.shape}" - - # Verify different agents produce different outputs (unless share_params with same input) - if not share_params: - for i in range(n_agents): - for j in range(i + 1, n_agents): - assert not torch.allclose(out[:, i], out[:, j]) - - @pytest.mark.parametrize("n_agents", [1, 3]) - @pytest.mark.parametrize("share_params", [True, False]) - @pytest.mark.parametrize("centralized", [True, False]) - @pytest.mark.parametrize("channels", [3, None]) - @pytest.mark.parametrize("batch", [(4,), (4, 3), ()]) - def test_multiagent_cnn( - self, - n_agents, - centralized, - share_params, - batch, - channels, - x=15, - y=15, - ): - torch.manual_seed(0) - cnn = MultiAgentConvNet( - n_agents=n_agents, - centralized=centralized, - share_params=share_params, - in_features=channels, - kernel_sizes=3, - ) - if channels is None: - channels = 3 - td = TensorDict( - { - "agents": TensorDict( - {"observation": torch.randn(*batch, n_agents, channels, x, y)}, - [*batch, n_agents], - ) - }, - batch_size=batch, - ) - obs = td[("agents", "observation")] - out = cnn(obs) - assert out.shape[:-1] == (*batch, n_agents) - if centralized and share_params: - torch.testing.assert_close(out, out[..., :1, :].expand_as(out)) - else: - for i in range(n_agents): - for j in range(i + 1, n_agents): - assert not torch.allclose(out[..., i, :], out[..., j, :]) - obs[..., 0, 0, 0, 0] += 1 - 
out2 = cnn(obs) - if centralized: - # a modification to the input of agent 0 will impact all agents - assert not torch.isclose(out, out2).all() - elif n_agents > 1: - assert not torch.isclose(out[..., 0, :], out2[..., 0, :]).all() - torch.testing.assert_close(out[..., 1:, :], out2[..., 1:, :]) - - obs = torch.randn(*batch, 1, channels, x, y).expand( - *batch, n_agents, channels, x, y - ) - out = cnn(obs) - for i in range(n_agents): - if share_params: - # same input same output - assert torch.allclose(out[..., i, :], out[..., 0, :]) - else: - for j in range(i + 1, n_agents): - # same input different output - assert not torch.allclose(out[..., i, :], out[..., j, :]) - - def test_multiagent_cnn_lazy(self): - torch.manual_seed(42) - n_agents = 5 - n_channels = 3 - cnn = MultiAgentConvNet( - n_agents=n_agents, - centralized=False, - share_params=False, - in_features=None, - kernel_sizes=3, - ) - optim = torch.optim.SGD(cnn.parameters(), lr=1e-3) - for p in cnn.parameters(): - if isinstance(p, torch.nn.parameter.UninitializedParameter): - break - else: - raise AssertionError("No UninitializedParameter found") - for p in optim.param_groups[0]["params"]: - if isinstance(p, torch.nn.parameter.UninitializedParameter): - break - else: - raise AssertionError("No UninitializedParameter found") - for _ in range(2): - td = TensorDict( - { - "agents": TensorDict( - {"observation": torch.randn(4, n_agents, n_channels, 15, 15)}, - [4, 5], - ) - }, - batch_size=[4], - ) - obs = td[("agents", "observation")] - out = cnn(obs) - assert ( - not cnn.params[0] - .apply(lambda x, y: torch.isclose(x, y), cnn.params[1]) - .any() - ) - out.mean().backward() - optim.step() - for p in cnn.parameters(): - if isinstance(p, torch.nn.parameter.UninitializedParameter): - raise AssertionError("UninitializedParameter found") - for p in optim.param_groups[0]["params"]: - if isinstance(p, torch.nn.parameter.UninitializedParameter): - raise AssertionError("UninitializedParameter found") - - 
@pytest.mark.parametrize("n_agents", [1, 3]) - @pytest.mark.parametrize("share_params", [True, False]) - @pytest.mark.parametrize("centralized", [True, False]) - def test_multiagent_reset_cnn( - self, - n_agents, - centralized, - share_params, - ): - torch.manual_seed(42) - actor_net = MultiAgentConvNet( - in_features=4, - num_cells=[5, 5], - n_agents=n_agents, - centralized=centralized, - share_params=share_params, - ) - params_before = actor_net.params.clone() - actor_net.reset_parameters() - params_after = actor_net.params - assert not params_before.apply( - lambda x, y: torch.isclose(x, y), params_after, batch_size=[] - ).any() - if params_after.numel() > 1: - assert ( - not params_after[0] - .apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[]) - .any() - ) - - @pytest.mark.parametrize("n_agents", [1, 3]) - @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) - def test_vdn(self, n_agents, batch): - torch.manual_seed(0) - mixer = VDNMixer(n_agents=n_agents, device="cpu") - - td = self._get_mock_input_td(n_agents, batch=batch, n_agents_inputs=1) - obs = td.get(("agents", "observation")) - assert obs.shape == (*batch, n_agents, 1) - out = mixer(obs) - assert out.shape == (*batch, 1) - assert torch.equal(obs.sum(-2), out) - - @pytest.mark.parametrize("n_agents", [1, 3]) - @pytest.mark.parametrize("batch", [(10,), (10, 3), ()]) - @pytest.mark.parametrize("state_shape", [(64, 64, 3), (10,)]) - def test_qmix(self, n_agents, batch, state_shape): - torch.manual_seed(0) - mixer = QMixer( - n_agents=n_agents, - state_shape=state_shape, - mixing_embed_dim=32, - device="cpu", - ) - - td = self._get_mock_input_td( - n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape - ) - obs = td.get(("agents", "observation")) - state = td.get("state") - assert obs.shape == (*batch, n_agents, 1) - assert state.shape == (*batch, *state_shape) - out = mixer(obs, state) - assert out.shape == (*batch, 1) - - @pytest.mark.parametrize("mixer", ["qmix", 
"vdn"]) - def test_mixer_malformed_input( - self, mixer, n_agents=3, batch=(32,), state_shape=(64, 64, 3) - ): - td = self._get_mock_input_td( - n_agents, batch=batch, n_agents_inputs=3, state_shape=state_shape - ) - if mixer == "qmix": - mixer = QMixer( - n_agents=n_agents, - state_shape=state_shape, - mixing_embed_dim=32, - device="cpu", - ) - else: - mixer = VDNMixer(n_agents=n_agents, device="cpu") - obs = td.get(("agents", "observation")) - state = td.get("state") - - if mixer.needs_state: - with pytest.raises( - ValueError, - match="Mixer that needs state was passed more than 2 inputs", - ): - mixer(obs) - else: - with pytest.raises( - ValueError, - match="Mixer that doesn't need state was passed more than 1 input", - ): - mixer(obs, state) - - in_put = [obs, state] if mixer.needs_state else [obs] - with pytest.raises( - ValueError, - match="Mixer network expected chosen_action_value with last 2 dimensions", - ): - mixer(*in_put) - if mixer.needs_state: - state_diff = state.unsqueeze(-1) - with pytest.raises( - ValueError, - match="Mixer network expected state with ending shape", - ): - mixer(obs, state_diff) - - td = self._get_mock_input_td( - n_agents, batch=batch, n_agents_inputs=1, state_shape=state_shape - ) - obs = td.get(("agents", "observation")) - state = td.get("state") - obs = obs.sum(-2) - in_put = [obs, state] if mixer.needs_state else [obs] - with pytest.raises( - ValueError, - match="Mixer network expected chosen_action_value with last 2 dimensions", - ): - mixer(*in_put) - - obs = td.get(("agents", "observation")) - state = td.get("state") - in_put = [obs, state] if mixer.needs_state else [obs] - mixer(*in_put) - - -@pytest.mark.skipif(torch.__version__ < "2.0", reason="torch 2.0 is required") -@pytest.mark.parametrize("use_vmap", [False, True]) -@pytest.mark.parametrize("scale", range(10)) -def test_tanh_atanh(use_vmap, scale): - if use_vmap: - try: - from torch import vmap - except ImportError: - try: - from functorch import vmap - except 
ImportError: - raise pytest.skip("functorch not found") - - torch.manual_seed(0) - x = (torch.randn(10, dtype=torch.double) * scale).requires_grad_(True) - if not use_vmap: - y = safetanh(x, 1e-6) - else: - y = vmap(safetanh, (0, None))(x, 1e-6) - - if not use_vmap: - xp = safeatanh(y, 1e-6) - else: - xp = vmap(safeatanh, (0, None))(y, 1e-6) - - xp.sum().backward() - torch.testing.assert_close(x.grad, torch.ones_like(x)) - - -@pytest.mark.skipif( - not _has_transformers, reason="transformers needed for TestDecisionTransformer" -) -class TestDecisionTransformer: - def test_init(self): - DecisionTransformer( - 3, - 4, - ) - with pytest.raises(TypeError): - DecisionTransformer(3, 4, config="some_str") - DecisionTransformer( - 3, - 4, - config=DecisionTransformer.DTConfig( - n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 - ), - ) - - @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) - def test_exec(self, batch_dims, T=5): - observations = torch.randn(*batch_dims, T, 3) - actions = torch.randn(*batch_dims, T, 4) - r2go = torch.randn(*batch_dims, T, 1) - model = DecisionTransformer( - 3, - 4, - config=DecisionTransformer.DTConfig( - n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 - ), - ) - out = model(observations, actions, r2go) - assert out.shape == torch.Size([*batch_dims, T, 16]) - - @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) - def test_dtactor(self, batch_dims, T=5): - dtactor = DTActor( - 3, - 4, - transformer_config=DecisionTransformer.DTConfig( - n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 - ), - ) - observations = torch.randn(*batch_dims, T, 3) - actions = torch.randn(*batch_dims, T, 4) - r2go = torch.randn(*batch_dims, T, 1) - out = dtactor(observations, actions, r2go) - assert out.shape == torch.Size([*batch_dims, T, 4]) - - @pytest.mark.parametrize("batch_dims", [[], [3], [3, 4]]) - def test_onlinedtactor(self, batch_dims, T=5): - dtactor = OnlineDTActor( - 3, - 4, - 
transformer_config=DecisionTransformer.DTConfig( - n_layer=2, n_embd=16, n_positions=16, n_inner=16, n_head=2 - ), - ) - observations = torch.randn(*batch_dims, T, 3) - actions = torch.randn(*batch_dims, T, 4) - r2go = torch.randn(*batch_dims, T, 1) - mu, sig = dtactor(observations, actions, r2go) - assert mu.shape == torch.Size([*batch_dims, T, 4]) - assert sig.shape == torch.Size([*batch_dims, T, 4]) - assert (dtactor.log_std_min < sig.log()).all() - assert (dtactor.log_std_max > sig.log()).all() - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("bias", [True, False]) -def test_python_lstm_cell(device, bias): - lstm_cell1 = LSTMCell(10, 20, device=device, bias=bias) - lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias) - - lstm_cell1.load_state_dict(lstm_cell2.state_dict()) - - # Make sure parameters match - for (k1, v1), (k2, v2) in zip( - lstm_cell1.named_parameters(), lstm_cell2.named_parameters() - ): - assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - - # Run loop - input = torch.randn(2, 3, 10, device=device) - h0 = torch.randn(3, 20, device=device) - c0 = torch.randn(3, 20, device=device) - with torch.no_grad(): - for i in range(input.size()[0]): - h1, c1 = lstm_cell1(input[i], (h0, c0)) - h2, c2 = lstm_cell2(input[i], (h0, c0)) - - # Make sure the final hidden states have the same shape - assert h1.shape == h2.shape - assert c1.shape == c2.shape - torch.testing.assert_close(h1, h2) - torch.testing.assert_close(c1, c2) - h0 = h1 - c0 = c1 - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("bias", [True, False]) -def test_python_gru_cell(device, bias): - gru_cell1 = GRUCell(10, 20, device=device, bias=bias) - gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias) - - gru_cell2.load_state_dict(gru_cell1.state_dict()) - - # Make sure 
parameters match - for (k1, v1), (k2, v2) in zip( - gru_cell1.named_parameters(), gru_cell2.named_parameters() - ): - assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - assert (v1 == v2).all() - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - - # Run loop - input = torch.randn(2, 3, 10, device=device) - h0 = torch.zeros(3, 20, device=device) - with torch.no_grad(): - for i in range(input.size()[0]): - h1 = gru_cell1(input[i], h0) - h2 = gru_cell2(input[i], h0) - - # Make sure the final hidden states have the same shape - assert h1.shape == h2.shape - torch.testing.assert_close(h1, h2) - h0 = h1 - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("batch_first", [True, False]) -@pytest.mark.parametrize("dropout", [0.0, 0.5]) -@pytest.mark.parametrize("num_layers", [1, 2]) -def test_python_lstm(device, bias, dropout, batch_first, num_layers): - B = 5 - T = 3 - lstm1 = LSTM( - input_size=10, - hidden_size=20, - num_layers=num_layers, - device=device, - bias=bias, - batch_first=batch_first, - ) - lstm2 = nn.LSTM( - input_size=10, - hidden_size=20, - num_layers=num_layers, - device=device, - bias=bias, - batch_first=batch_first, - ) - - lstm2.load_state_dict(lstm1.state_dict()) - - # Make sure parameters match - for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()): - assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - - if batch_first: - input = torch.randn(B, T, 10, device=device) - else: - input = torch.randn(T, B, 10, device=device) - - h0 = torch.randn(num_layers, 5, 20, device=device) - c0 = torch.randn(num_layers, 5, 20, device=device) - - # Test without hidden states - with torch.no_grad(): - output1, (h1, c1) = lstm1(input) - output2, 
(h2, c2) = lstm2(input) - - assert h1.shape == h2.shape - assert c1.shape == c2.shape - assert output1.shape == output2.shape - if dropout == 0.0: - torch.testing.assert_close(output1, output2) - torch.testing.assert_close(h1, h2) - torch.testing.assert_close(c1, c2) - - # Test with hidden states - with torch.no_grad(): - output1, (h1, c1) = lstm1(input, (h0, c0)) - output2, (h2, c2) = lstm1(input, (h0, c0)) - - assert h1.shape == h2.shape - assert c1.shape == c2.shape - assert output1.shape == output2.shape - if dropout == 0.0: - torch.testing.assert_close(output1, output2) - torch.testing.assert_close(h1, h2) - torch.testing.assert_close(c1, c2) - - -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("batch_first", [True, False]) -@pytest.mark.parametrize("dropout", [0.0, 0.5]) -@pytest.mark.parametrize("num_layers", [1, 2]) -def test_python_gru(device, bias, dropout, batch_first, num_layers): - B = 5 - T = 3 - gru1 = GRU( - input_size=10, - hidden_size=20, - num_layers=num_layers, - device=device, - bias=bias, - batch_first=batch_first, - ) - gru2 = nn.GRU( - input_size=10, - hidden_size=20, - num_layers=num_layers, - device=device, - bias=bias, - batch_first=batch_first, - ) - gru2.load_state_dict(gru1.state_dict()) - - # Make sure parameters match - for (k1, v1), (k2, v2) in zip(gru1.named_parameters(), gru2.named_parameters()): - assert k1 == k2, f"Parameter names do not match: {k1} != {k2}" - torch.testing.assert_close(v1, v2) - assert ( - v1.shape == v2.shape - ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}" - - if batch_first: - input = torch.randn(B, T, 10, device=device) - else: - input = torch.randn(T, B, 10, device=device) - - h0 = torch.randn(num_layers, 5, 20, device=device) - - # Test without hidden states - with torch.no_grad(): - output1, h1 = gru1(input) - output2, h2 = gru2(input) - - assert h1.shape == h2.shape - assert 
output1.shape == output2.shape - if dropout == 0.0: - torch.testing.assert_close(output1, output2) - torch.testing.assert_close(h1, h2) - - # Test with hidden states - with torch.no_grad(): - output1, h1 = gru1(input, h0) - output2, h2 = gru2(input, h0) - - assert h1.shape == h2.shape - assert output1.shape == output2.shape - if dropout == 0.0: - torch.testing.assert_close(output1, output2) - torch.testing.assert_close(h1, h2) - - -class TestBatchRenorm: - @pytest.mark.parametrize("num_steps", [0, 5]) - @pytest.mark.parametrize("smooth", [False, True]) - def test_batchrenorm(self, num_steps, smooth): - torch.manual_seed(0) - bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5) - brn = BatchRenorm1d( - 5, - momentum=0.1, - eps=1e-5, - warmup_steps=num_steps, - max_d=10000, - max_r=10000, - smooth=smooth, - ) - bn.train() - brn.train() - data_train = torch.randn(100, 5).split(25) - data_test = torch.randn(100, 5) - for i, d in enumerate(data_train): - b = bn(d) - a = brn(d) - if num_steps > 0 and ( - (i < num_steps and not smooth) or (i == 0 and smooth) - ): - torch.testing.assert_close(a, b) - else: - assert not torch.isclose(a, b).all(), i - - bn.eval() - brn.eval() - torch.testing.assert_close(bn(data_test), brn(data_test)) - - -def test_convnetblock_uses_both_resnets(): - """Regression test for https://github.com/pytorch/rl/issues/3519.""" - block = _ConvNetBlock(num_ch=16) - x = torch.randn(2, 3, 8, 8) - out = block(x).mean() - out.backward() - - resnet1_grad = sum(p.grad.abs().sum() for p in block.resnet1.parameters()) - resnet2_grad = sum(p.grad.abs().sum() for p in block.resnet2.parameters()) - assert resnet1_grad > 0, "resnet1 parameters received no gradients" - assert resnet2_grad > 0, "resnet2 parameters received no gradients" - - -class TestDiffusionActor: - def test_output_shape(self): - actor = DiffusionActor(action_dim=2, obs_dim=3, num_steps=5) - td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4]) - td = actor(td) - assert 
td["action"].shape == torch.Size([4, 2]) - - def test_unbatched(self): - actor = DiffusionActor(action_dim=4, obs_dim=6, num_steps=3) - td = TensorDict({"observation": torch.randn(6)}, batch_size=[]) - td = actor(td) - assert td["action"].shape == torch.Size([4]) - - def test_custom_in_out_keys(self): - actor = DiffusionActor( - action_dim=2, - obs_dim=3, - num_steps=3, - in_keys=["obs"], - out_keys=["act"], - ) - assert actor.in_keys == ["obs"] - assert actor.out_keys == ["act"] - td = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4]) - td = actor(td) - assert td["act"].shape == torch.Size([4, 2]) - - def test_custom_score_network(self): - score_net = nn.Linear(2 + 3 + 1, 2) - actor = DiffusionActor( - action_dim=2, obs_dim=3, score_network=score_net, num_steps=3 - ) - td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4]) - td = actor(td) - assert td["action"].shape == torch.Size([4, 2]) - - def test_spec_wrapping(self): - spec = Bounded(low=-1.0, high=1.0, shape=(2,)) - actor = DiffusionActor(action_dim=2, obs_dim=3, num_steps=3, spec=spec) - assert actor.spec is not None - - def test_gradients_flow(self): - actor = DiffusionActor(action_dim=2, obs_dim=3, num_steps=3) - obs = torch.randn(4, 3) - td = TensorDict({"observation": obs}, batch_size=[4]) - td = actor(td) - td["action"].sum().backward() - for p in actor.parameters(): - assert p.grad is not None - - -if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_rb.py b/test/test_rb.py deleted file mode 100644 index ccaa553b84b..00000000000 --- a/test/test_rb.py +++ /dev/null @@ -1,6580 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
-from __future__ import annotations - -import argparse -import contextlib -import functools -import gc -import importlib -import os -import pickle -import shutil -import signal -import subprocess -import sys -import tempfile -import time -import warnings -from functools import partial -from pathlib import Path -from unittest import mock - -import numpy as np -import pytest -import torch -import torchrl -from packaging import version -from packaging.version import parse -from tensordict import ( - assert_allclose_td, - is_tensor_collection, - is_tensorclass, - LazyStackedTensorDict, - tensorclass, - TensorDict, - TensorDictBase, -) -from torch import multiprocessing as mp -from torch.utils._pytree import tree_flatten, tree_map - -from torchrl._utils import _replace_last, logger as torchrl_logger, rl_warnings -from torchrl.collectors import Collector -from torchrl.collectors.utils import split_trajectories -from torchrl.data import ( - CompressedListStorage, - FlatStorageCheckpointer, - MultiStep, - NestedStorageCheckpointer, - PrioritizedReplayBuffer, - RayReplayBuffer, - RemoteTensorDictReplayBuffer, - ReplayBuffer, - ReplayBufferEnsemble, - TensorDictPrioritizedReplayBuffer, - TensorDictReplayBuffer, -) -from torchrl.data.replay_buffers import samplers, writers -from torchrl.data.replay_buffers.checkpointers import H5StorageCheckpointer -from torchrl.data.replay_buffers.samplers import ( - PrioritizedSampler, - PrioritizedSliceSampler, - RandomSampler, - Sampler, - SamplerEnsemble, - SamplerWithoutReplacement, - SliceSampler, - SliceSamplerWithoutReplacement, - StalenessAwareSampler, -) -from torchrl.data.replay_buffers.scheduler import ( - LinearScheduler, - SchedulerList, - StepScheduler, -) - -from torchrl.data.replay_buffers.storages import ( - _MEMMAP_STORAGE_REGISTRY, - LazyMemmapStorage, - LazyStackStorage, - LazyTensorStorage, - ListStorage, - StorageEnsemble, - TensorStorage, -) -from torchrl.data.replay_buffers.utils import tree_iter -from 
torchrl.data.replay_buffers.writers import ( - RoundRobinWriter, - TensorDictMaxValueWriter, - TensorDictRoundRobinWriter, - WriterEnsemble, -) -from torchrl.envs import GymEnv, SerialEnv -from torchrl.envs.transforms.transforms import ( - BinarizeReward, - CatFrames, - CatTensors, - CenterCrop, - Compose, - DiscreteActionProjection, - DoubleToFloat, - FiniteTensorDictCheck, - FlattenObservation, - GrayScale, - gSDENoise, - ObservationNorm, - PinMemoryTransform, - RenameTransform, - Resize, - RewardClipping, - RewardScaling, - SqueezeTransform, - StepCounter, - ToTensorImage, - UnsqueezeTransform, - VecNorm, -) -from torchrl.envs.transforms import NextStateReconstructor -from torchrl.modules import GRUModule, RandomPolicy, set_recurrent_mode - -from torchrl.testing import ( - capture_log_records, - CARTPOLE_VERSIONED, - get_default_devices, - make_tc, -) -from torchrl.testing.mocking_classes import CountingEnv - -OLD_TORCH = parse(torch.__version__) < parse("2.0.0") -_has_tv = importlib.util.find_spec("torchvision") is not None -_has_gym = importlib.util.find_spec("gym") is not None -_has_snapshot = importlib.util.find_spec("torchsnapshot") is not None -_os_is_windows = sys.platform == "win32" -_has_transformers = importlib.util.find_spec("transformers") is not None -_has_ray = importlib.util.find_spec("ray") is not None -_has_zstandard = importlib.util.find_spec("zstandard") is not None - -TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) - -torch_2_3 = version.parse( - ".".join([str(s) for s in version.parse(str(torch.__version__)).release]) -) >= version.parse("2.3.0") - -ReplayBufferRNG = functools.partial(ReplayBuffer, generator=torch.Generator()) -TensorDictReplayBufferRNG = functools.partial( - TensorDictReplayBuffer, generator=torch.Generator() -) - - -def test_replay_buffer_read_write_all_in_order(): - rb = TensorDictReplayBuffer(storage=LazyTensorStorage(6)) - rb_slice = TensorDictReplayBuffer(storage=LazyTensorStorage(6)) - 
data = TensorDict({"obs": torch.arange(6), "reward": torch.zeros(6)}, [6]) - rb.extend(data) - rb_slice.extend(data.clone()) - - all_data = rb.read_all_in_order() - assert_allclose_td(all_data, rb[:]) - assert all_data["obs"].tolist() == list(range(6)) - all_data["value_target"] = all_data["obs"] + 1 - rb.write_all(all_data) - rb_slice[:] = all_data.clone() - - updated = rb.read_all_in_order() - assert_allclose_td(updated, rb[:]) - assert_allclose_td(updated, rb_slice[:]) - assert updated["value_target"].tolist() == list(range(1, 7)) - - -def test_replay_buffer_read_write_all_in_order_with_end(): - rb = TensorDictReplayBuffer(storage=LazyTensorStorage(10)) - rb_slice = TensorDictReplayBuffer(storage=LazyTensorStorage(10)) - rb.extend(TensorDict({"obs": torch.arange(6)}, [6])) - rb_slice.extend(TensorDict({"obs": torch.arange(6)}, [6])) - - partial = rb.read_all_in_order(end=3) - assert_allclose_td(partial, rb[:3]) - partial["obs"] = partial["obs"] + 10 - rb.write_all(partial, end=3) - rb_slice[:3] = partial.clone() - - updated = rb.read_all_in_order() - assert_allclose_td(updated, rb_slice[:]) - assert updated["obs"].tolist() == [10, 11, 12, 3, 4, 5] - - -def test_replay_buffer_read_write_all_in_order_matches_full_slice_ndim2(): - rb = TensorDictReplayBuffer(storage=LazyTensorStorage(6, ndim=2)) - rb_slice = TensorDictReplayBuffer(storage=LazyTensorStorage(6, ndim=2)) - data = TensorDict( - {"obs": torch.arange(6).reshape(2, 3), "reward": torch.zeros(2, 3)}, - [2, 3], - ) - rb.extend(data) - rb_slice.extend(data.clone()) - - all_data = rb.read_all_in_order() - assert_allclose_td(all_data, rb[:]) - all_data["value_target"] = all_data["obs"] + 1 - rb.write_all(all_data) - rb_slice[:] = all_data.clone() - - assert_allclose_td(rb.read_all_in_order(), rb[:]) - assert_allclose_td(rb.read_all_in_order(), rb_slice[:]) - - -@pytest.mark.parametrize( - "sampler", - [ - samplers.RandomSampler, - samplers.SamplerWithoutReplacement, - samplers.PrioritizedSampler, - ], -) 
-@pytest.mark.parametrize( - "writer", [writers.RoundRobinWriter, writers.TensorDictMaxValueWriter] -) -@pytest.mark.parametrize( - "rb_type,storage,datatype", - [ - [ReplayBuffer, ListStorage, None], - [ReplayBufferRNG, ListStorage, None], - [TensorDictReplayBuffer, ListStorage, "tensordict"], - [TensorDictReplayBufferRNG, ListStorage, "tensordict"], - [RemoteTensorDictReplayBuffer, ListStorage, "tensordict"], - [ReplayBuffer, LazyTensorStorage, "tensor"], - [ReplayBuffer, LazyTensorStorage, "tensordict"], - [ReplayBuffer, LazyTensorStorage, "pytree"], - [ReplayBufferRNG, LazyTensorStorage, "tensor"], - [ReplayBufferRNG, LazyTensorStorage, "tensordict"], - [ReplayBufferRNG, LazyTensorStorage, "pytree"], - [TensorDictReplayBuffer, LazyTensorStorage, "tensordict"], - [TensorDictReplayBufferRNG, LazyTensorStorage, "tensordict"], - [RemoteTensorDictReplayBuffer, LazyTensorStorage, "tensordict"], - [ReplayBuffer, LazyMemmapStorage, "tensor"], - [ReplayBuffer, LazyMemmapStorage, "tensordict"], - [ReplayBuffer, LazyMemmapStorage, "pytree"], - [ReplayBufferRNG, LazyMemmapStorage, "tensor"], - [ReplayBufferRNG, LazyMemmapStorage, "tensordict"], - [ReplayBufferRNG, LazyMemmapStorage, "pytree"], - [TensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], - [TensorDictReplayBufferRNG, LazyMemmapStorage, "tensordict"], - [RemoteTensorDictReplayBuffer, LazyMemmapStorage, "tensordict"], - ], -) -@pytest.mark.parametrize("size", [3, 5, 100]) -class TestComposableBuffers: - def _get_rb( - self, rb_type, size, sampler, writer, storage, compilable=False, **kwargs - ): - if storage is not None: - storage = storage(size, compilable=compilable) - - sampler_args = {} - if sampler is samplers.PrioritizedSampler: - sampler_args = {"max_capacity": size, "alpha": 0.8, "beta": 0.9} - - sampler = sampler(**sampler_args) - writer = writer(compilable=compilable) - rb = rb_type( - storage=storage, - sampler=sampler, - writer=writer, - batch_size=3, - compilable=compilable, - **kwargs, - ) - 
return rb - - def _get_datum(self, datatype): - if datatype is None: - data = torch.randint(100, (1,)) - elif datatype == "tensor": - data = torch.randint(100, (1,)) - elif datatype == "tensordict": - data = TensorDict( - {"a": torch.randint(100, (1,)), "next": {"reward": torch.randn(1)}}, [] - ) - elif datatype == "pytree": - data = { - "a": torch.randint(100, (1,)), - "b": {"c": [torch.zeros(3), (torch.ones(2),)]}, - 30: torch.zeros(2), - } - else: - raise NotImplementedError(datatype) - return data - - def _get_data(self, datatype, size): - if datatype is None: - data = torch.randint(100, (size, 1)) - elif datatype == "tensor": - data = torch.randint(100, (size, 1)) - elif datatype == "tensordict": - data = TensorDict( - { - "a": torch.randint(100, (size, 1)), - "next": {"reward": torch.randn(size, 1)}, - }, - [size], - ) - elif datatype == "pytree": - data = { - "a": torch.randint(100, (size, 1)), - "b": {"c": [torch.zeros(size, 3), (torch.ones(size, 2),)]}, - 30: torch.zeros(size, 2), - } - else: - raise NotImplementedError(datatype) - return data - - def test_rb_repr(self, rb_type, sampler, writer, storage, size, datatype): - if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: - pytest.skip( - "Distributed package support on Windows is a prototype feature and is subject to changes." 
- ) - torch.manual_seed(0) - rb = self._get_rb( - rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size - ) - data = self._get_datum(datatype) - if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: - with pytest.raises( - RuntimeError, match="expects data to be a tensor collection" - ): - rb.add(data) - return - rb.add(data) - # we just check that str runs, not its value - assert str(rb) - rb.sample() - assert str(rb) - - def test_add(self, rb_type, sampler, writer, storage, size, datatype): - if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: - pytest.skip( - "Distributed package support on Windows is a prototype feature and is subject to changes." - ) - torch.manual_seed(0) - rb = self._get_rb( - rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size - ) - data = self._get_datum(datatype) - if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: - with pytest.raises( - RuntimeError, match="expects data to be a tensor collection" - ): - rb.add(data) - return - rb.add(data) - s, info = rb.sample(1, return_info=True) - assert len(rb) == 1 - if isinstance(s, (torch.Tensor, TensorDictBase)): - assert s.ndim, s - s = s[0] - else: - - def assert_ndim(tensor): - assert tensor.shape[0] == 1 - - tree_map(assert_ndim, s) - s = tree_map(lambda s: s[0], s) - if isinstance(s, TensorDictBase): - s = s.select(*data.keys(True), strict=False) - data = data.select(*s.keys(True), strict=False) - assert (s == data).all() - assert list(s.keys(True, True)) - else: - flat_s = tree_flatten(s)[0] - flat_data = tree_flatten(data)[0] - assert all((_s == _data).all() for (_s, _data) in zip(flat_s, flat_data)) - - def test_cursor_position(self, rb_type, sampler, writer, storage, size, datatype): - storage = storage(size) - writer = writer() - writer.register_storage(storage) - batch1 = self._get_data(datatype, size=5) - cond = ( - OLD_TORCH - and not isinstance(writer, TensorDictMaxValueWriter) - and 
size < len(batch1) - and isinstance(storage, TensorStorage) - ) - - if not is_tensor_collection(batch1) and isinstance( - writer, TensorDictMaxValueWriter - ): - with pytest.raises( - RuntimeError, match="expects data to be a tensor collection" - ): - writer.extend(batch1) - return - - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - writer.extend(batch1) - - # Added less data than storage max size - if size > 5: - assert writer._cursor == 5 - # Added more data than storage max size - elif size < 5: - # if Max writer, we don't necessarily overwrite existing values so - # we just check that the cursor is before the threshold - if isinstance(writer, TensorDictMaxValueWriter): - assert writer._cursor <= 5 - size - else: - assert writer._cursor == 5 - size - # Added as data as storage max size - else: - assert writer._cursor == 0 - if not isinstance(writer, TensorDictMaxValueWriter): - batch2 = self._get_data(datatype, size=size - 1) - writer.extend(batch2) - assert writer._cursor == size - 1 - - def test_extend(self, rb_type, sampler, writer, storage, size, datatype): - if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: - pytest.skip( - "Distributed package support on Windows is a prototype feature and is subject to changes." 
- ) - torch.manual_seed(0) - rb = self._get_rb( - rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size - ) - data_shape = 5 - data = self._get_data(datatype, size=data_shape) - cond = ( - OLD_TORCH - and writer is not TensorDictMaxValueWriter - and size < len(data) - and isinstance(rb.storage, TensorStorage) - ) - if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: - with pytest.raises( - RuntimeError, match="expects data to be a tensor collection" - ): - rb.extend(data) - return - length = min(rb.storage.max_size, len(rb) + data_shape) - if writer is TensorDictMaxValueWriter: - data["next", "reward"][-length:] = 1_000_000 - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - rb.extend(data) - length = len(rb) - if is_tensor_collection(data): - data_iter = data[-length:] - else: - - def data_iter(): - for t in range(-length, -1): - yield tree_map(lambda x, t=t: x[t], data) - - data_iter = data_iter() - for d in data_iter: - for b in rb.storage: - if isinstance(b, TensorDictBase): - keys = set(d.keys()).intersection(b.keys()) - b = b.exclude("index").select(*keys, strict=False) - keys = set(d.keys()).intersection(b.keys()) - d = d.select(*keys, strict=False) - if isinstance(b, (torch.Tensor, TensorDictBase)): - value = b == d - value = value.all() - else: - d_flat = tree_flatten(d)[0] - b_flat = tree_flatten(b)[0] - value = all((_b == _d).all() for (_b, _d) in zip(b_flat, d_flat)) - if value: - break - else: - raise RuntimeError("did not find match") - - data2 = self._get_data(datatype, size=2 * size + 2) - cond = ( - OLD_TORCH - and writer is not TensorDictMaxValueWriter - and size < len(data2) - and isinstance(rb.storage, TensorStorage) - ) - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - 
rb.extend(data2) - - @pytest.mark.skipif( - TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" - ) - # Compiling on Windows requires "cl" compiler to be installed. - # - # Our Windows CI jobs do not have "cl", so skip this test. - @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") - @pytest.mark.skipif( - sys.version_info >= (3, 14), - reason="torch.compile is not supported on Python 3.14+", - ) - @pytest.mark.parametrize("avoid_max_size", [False, True]) - def test_extend_sample_recompile( - self, rb_type, sampler, writer, storage, size, datatype, avoid_max_size - ): - if rb_type is not ReplayBuffer: - pytest.skip( - "Only replay buffer of type 'ReplayBuffer' is currently supported." - ) - if sampler is not RandomSampler: - pytest.skip("Only sampler of type 'RandomSampler' is currently supported.") - if storage is not LazyTensorStorage: - pytest.skip( - "Only storage of type 'LazyTensorStorage' is currently supported." - ) - if writer is not RoundRobinWriter: - pytest.skip( - "Only writer of type 'RoundRobinWriter' is currently supported." - ) - if datatype == "tensordict": - pytest.skip("'tensordict' datatype is not currently supported.") - - torch._dynamo.reset_code_caches() - - # Number of times to extend the replay buffer - num_extend = 10 - data_size = size - - # These two cases are separated because when the max storage size is - # reached, the code execution path changes, causing necessary - # recompiles. 
- if avoid_max_size: - storage_size = (num_extend + 1) * data_size - else: - storage_size = 2 * data_size - - rb = self._get_rb( - rb_type=rb_type, - sampler=sampler, - writer=writer, - storage=storage, - size=storage_size, - compilable=True, - ) - data = self._get_data(datatype, size=data_size) - - @torch.compile - def extend_and_sample(data): - rb.extend(data) - return rb.sample() - - # NOTE: The first three calls to 'extend' and 'sample' can currently - # cause recompilations, so avoid capturing those. - num_extend_before_capture = 3 - - for _ in range(num_extend_before_capture): - extend_and_sample(data) - - try: - torch._logging.set_logs(recompiles=True) - records = [] - capture_log_records(records, "torch._dynamo", "recompiles") - - for _ in range(num_extend - num_extend_before_capture): - extend_and_sample(data) - - finally: - torch._logging.set_logs() - - assert len(rb) == min((num_extend * data_size), storage_size) - assert len(records) == 0 - - def test_sample(self, rb_type, sampler, writer, storage, size, datatype): - if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: - pytest.skip( - "Distributed package support on Windows is a prototype feature and is subject to changes." 
- ) - torch.manual_seed(0) - rb = self._get_rb( - rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size - ) - data = self._get_data(datatype, size=5) - cond = ( - OLD_TORCH - and writer is not TensorDictMaxValueWriter - and size < len(data) - and isinstance(rb.storage, TensorStorage) - ) - if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: - with pytest.raises( - RuntimeError, match="expects data to be a tensor collection" - ): - rb.extend(data) - return - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - rb.extend(data) - rb_sample = rb.sample() - # if not isinstance(new_data, (torch.Tensor, TensorDictBase)): - # new_data = new_data[0] - - if is_tensor_collection(data) or isinstance(data, torch.Tensor): - rb_sample_iter = rb_sample - else: - - def data_iter_func(maxval, data=data): - for t in range(maxval): - yield tree_map(lambda x, t=t: x[t], data) - - rb_sample_iter = data_iter_func(rb._batch_size, rb_sample) - - for single_sample in rb_sample_iter: - if is_tensor_collection(data) or isinstance(data, torch.Tensor): - data_iter = data - else: - data_iter = data_iter_func(5, data) - - for data_sample in data_iter: - if isinstance(data_sample, TensorDictBase): - keys = set(single_sample.keys()).intersection(data_sample.keys()) - data_sample = data_sample.exclude("index").select( - *keys, strict=False - ) - keys = set(single_sample.keys()).intersection(data_sample.keys()) - single_sample = single_sample.select(*keys, strict=False) - - if isinstance(data_sample, (torch.Tensor, TensorDictBase)): - value = data_sample == single_sample - value = value.all() - else: - d_flat = tree_flatten(single_sample)[0] - b_flat = tree_flatten(data_sample)[0] - value = all((_b == _d).all() for (_b, _d) in zip(b_flat, d_flat)) - - if value: - break - else: - raise RuntimeError("did not find match") - - def test_index(self, 
rb_type, sampler, writer, storage, size, datatype): - if rb_type is RemoteTensorDictReplayBuffer and _os_is_windows: - pytest.skip( - "Distributed package support on Windows is a prototype feature and is subject to changes." - ) - torch.manual_seed(0) - rb = self._get_rb( - rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size - ) - data = self._get_data(datatype, size=5) - cond = ( - OLD_TORCH - and writer is not TensorDictMaxValueWriter - and size < len(data) - and isinstance(rb.storage, TensorStorage) - ) - if not is_tensor_collection(data) and writer is TensorDictMaxValueWriter: - with pytest.raises( - RuntimeError, match="expects data to be a tensor collection" - ): - rb.extend(data) - return - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - rb.extend(data) - d1 = rb[2] - d2 = rb.storage[2] - if type(d1) is not type(d2): - d1 = d1[0] - if is_tensor_collection(data) or isinstance(data, torch.Tensor): - b = d1 == d2 - if not isinstance(b, bool): - b = b.all() - else: - d1_flat = tree_flatten(d1)[0] - d2_flat = tree_flatten(d2)[0] - b = all((_d1 == _d2).all() for (_d1, _d2) in zip(d1_flat, d2_flat)) - assert b - - def test_pickable(self, rb_type, sampler, writer, storage, size, datatype): - rb = self._get_rb( - rb_type=rb_type, - sampler=sampler, - writer=writer, - storage=storage, - size=size, - delayed_init=False, - ) - serialized = pickle.dumps(rb) - rb2 = pickle.loads(serialized) - assert rb.__dict__.keys() == rb2.__dict__.keys() - for key in sorted(rb.__dict__.keys()): - assert isinstance(rb.__dict__[key], type(rb2.__dict__[key])) - - -class TestStorages: - def _get_tensor(self): - return torch.randn(10, 11) - - def _get_tensordict(self): - return TensorDict( - {"data": torch.randn(10, 11), ("nested", "data"): torch.randn(10, 11, 3)}, - [10, 11], - ) - - def _get_pytree(self): - return { - "a": torch.randint(100, (10, 
11, 1)), - "b": {"c": [torch.zeros(10, 11), (torch.ones(10, 11),)]}, - 30: torch.zeros(10, 11), - } - - def _get_tensorclass(self): - data = self._get_tensordict() - return make_tc(data)(**data, batch_size=data.shape) - - @pytest.mark.parametrize("storage_type", [TensorStorage]) - def test_errors(self, storage_type): - with pytest.raises(ValueError, match="Expected storage to be non-null"): - storage_type(None) - data = torch.randn(3) - with pytest.raises( - ValueError, match="The max-size and the storage shape mismatch" - ): - storage_type(data, max_size=4) - - def test_existsok_lazymemmap(self, tmpdir): - storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir) - rb = ReplayBuffer(storage=storage0) - rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) - - storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir) - rb = ReplayBuffer(storage=storage1) - with pytest.raises(RuntimeError, match="existsok"): - rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) - - storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True) - rb = ReplayBuffer(storage=storage2) - rb.extend(TensorDict(a=torch.randn(3), batch_size=[3])) - - @pytest.mark.parametrize( - "data_type", ["tensor", "tensordict", "tensorclass", "pytree"] - ) - @pytest.mark.parametrize("storage_type", [TensorStorage]) - def test_get_set(self, storage_type, data_type): - if data_type == "tensor": - data = self._get_tensor() - elif data_type == "tensorclass": - data = self._get_tensorclass() - elif data_type == "tensordict": - data = self._get_tensordict() - elif data_type == "pytree": - data = self._get_pytree() - else: - raise NotImplementedError - storage = storage_type(data) - if data_type == "pytree": - storage.set(range(10), tree_map(torch.zeros_like, data)) - - def check(x): - assert (x == 0).all() - - tree_map(check, storage.get(range(10))) - else: - storage.set(range(10), torch.zeros_like(data)) - assert (storage.get(range(10)) == 0).all() - - @pytest.mark.parametrize( - "data_type", ["tensor", 
"tensordict", "tensorclass", "pytree"] - ) - @pytest.mark.parametrize("storage_type", [TensorStorage]) - def test_state_dict(self, storage_type, data_type): - if data_type == "tensor": - data = self._get_tensor() - elif data_type == "tensorclass": - data = self._get_tensorclass() - elif data_type == "tensordict": - data = self._get_tensordict() - elif data_type == "pytree": - data = self._get_pytree() - else: - raise NotImplementedError - storage = storage_type(data) - if data_type == "pytree": - with pytest.raises(TypeError, match="are not supported by"): - storage.state_dict() - return - sd = storage.state_dict() - storage2 = storage_type(torch.zeros_like(data)) - storage2.load_state_dict(sd) - assert (storage.get(range(10)) == storage2.get(range(10))).all() - assert type(storage.get(range(10))) is type( # noqa: E721 - storage2.get(range(10)) - ) - - @pytest.mark.gpu - @pytest.mark.skipif( - not torch.cuda.device_count(), - reason="not cuda device found to test rb storage.", - ) - @pytest.mark.parametrize( - "device_data,device_storage", - [ - [torch.device("cuda"), torch.device("cpu")], - [torch.device("cpu"), torch.device("cuda")], - [torch.device("cpu"), "auto"], - [torch.device("cuda"), "auto"], - ], - ) - @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - @pytest.mark.parametrize("data_type", ["tensor", "tc", "td"]) - def test_storage_device(self, device_data, device_storage, storage_type, data_type): - @tensorclass - class TC: - a: torch.Tensor - - if data_type == "tensor": - data = torch.randn(3, device=device_data) - elif data_type == "td": - data = TensorDict( - {"a": torch.randn(3, device=device_data)}, [], device=device_data - ) - elif data_type == "tc": - data = TC( - a=torch.randn(3, device=device_data), - batch_size=[], - device=device_data, - ) - else: - raise NotImplementedError - - if ( - storage_type is LazyMemmapStorage - and device_storage != "auto" - and device_storage.type != "cpu" - ): - with 
pytest.raises(ValueError, match="Memory map device other than CPU"): - storage_type(max_size=10, device=device_storage) - return - storage = storage_type(max_size=10, device=device_storage) - storage.set(0, data) - if device_storage != "auto": - assert storage.get(0).device.type == device_storage.type - else: - assert storage.get(0).device.type == storage.device.type - - @pytest.mark.parametrize("storage_in", ["tensor", "memmap"]) - @pytest.mark.parametrize("storage_out", ["tensor", "memmap"]) - @pytest.mark.parametrize("init_out", [True, False]) - @pytest.mark.parametrize( - "backend", ["torch"] + (["torchsnapshot"] if _has_snapshot else []) - ) - def test_storage_state_dict(self, storage_in, storage_out, init_out, backend): - os.environ["CKPT_BACKEND"] = backend - buffer_size = 100 - if storage_in == "memmap": - storage_in = LazyMemmapStorage(buffer_size, device="cpu") - elif storage_in == "tensor": - storage_in = LazyTensorStorage(buffer_size, device="cpu") - if storage_out == "memmap": - storage_out = LazyMemmapStorage(buffer_size, device="cpu") - elif storage_out == "tensor": - storage_out = LazyTensorStorage(buffer_size, device="cpu") - - replay_buffer = TensorDictReplayBuffer( - pin_memory=False, prefetch=3, storage=storage_in, batch_size=3 - ) - # fill replay buffer with random data - transition = TensorDict( - { - "observation": torch.ones(1, 4), - "action": torch.ones(1, 2), - "reward": torch.ones(1, 1), - "dones": torch.ones(1, 1), - "next": {"observation": torch.ones(1, 4)}, - }, - batch_size=1, - ) - for _ in range(3): - replay_buffer.extend(transition) - - state_dict = replay_buffer.state_dict() - - new_replay_buffer = TensorDictReplayBuffer( - pin_memory=False, - prefetch=3, - storage=storage_out, - batch_size=state_dict["_batch_size"], - ) - if init_out: - new_replay_buffer.extend(transition) - - new_replay_buffer.load_state_dict(state_dict) - s = new_replay_buffer.sample() - assert (s.exclude("index") == 1).all() - - @pytest.mark.skipif( - 
TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" - ) - @pytest.mark.skipif(_os_is_windows, reason="windows tests do not support compile") - @pytest.mark.skipif( - sys.version_info >= (3, 14), - reason="torch.compile is not supported on Python 3.14+", - ) - # This test checks if the `torch._dynamo.disable` wrapper around - # `TensorStorage._rand_given_ndim` is still necessary. - def test__rand_given_ndim_recompile(self): - torch._dynamo.reset_code_caches() - - # Number of times to extend the replay buffer - num_extend = 5 - data_size = 50 - storage_size = (num_extend + 1) * data_size - sample_size = 3 - - storage = LazyTensorStorage(storage_size, compilable=True) - sampler = RandomSampler() - - # Override to avoid the `torch._dynamo.disable` wrapper - storage._rand_given_ndim = storage._rand_given_ndim_impl - - @torch.compile - def extend_and_sample(data): - storage.set(torch.arange(data_size) + len(storage), data) - return sampler.sample(storage, sample_size) - - data = torch.randint(100, (data_size, 1)) - - try: - torch._logging.set_logs(recompiles=True) - records = [] - capture_log_records(records, "torch._dynamo", "recompiles") - - for _ in range(num_extend): - extend_and_sample(data) - - finally: - torch._logging.set_logs() - - assert len(storage) == num_extend * data_size - assert len(records) <= 8, ( - "Excessive recompilations detected. Expected 8 or fewer, but got " - f"{len(records)}. This suggests the `torch.compiler.disable` " - "decorators may not be working properly or new recompilation " - "sources have been introduced." 
- ) - - @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - def test_extend_lazystack(self, storage_type): - rb = ReplayBuffer( - storage=storage_type(6), - batch_size=2, - ) - td1 = TensorDict(a=torch.rand(5, 4, 8), batch_size=5) - td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5) - ltd = LazyStackedTensorDict(td1, td2, stack_dim=1) - rb.extend(ltd) - rb.sample(3) - assert len(rb) == 5 - - @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - def test_extend_lazystack_direct_write(self, storage_type): - """Test that lazy stacks can be extended to storage correctly. - - This tests that lazy stacks from collectors are properly stored in - replay buffers and that the data integrity is preserved. Also verifies - that the update_() optimization is used for tensor indices. - """ - rb = ReplayBuffer( - storage=storage_type(100), - batch_size=10, - ) - # Create a list of tensordicts (like a collector would produce) - tensordicts = [ - TensorDict( - {"obs": torch.rand(4, 8), "action": torch.rand(2)}, batch_size=() - ) - for _ in range(10) - ] - # Create lazy stack with stack_dim=0 (the batch dimension) - lazy_td = LazyStackedTensorDict.lazy_stack(tensordicts, dim=0) - assert isinstance(lazy_td, LazyStackedTensorDict) - - # Track calls to update_at_() - used for tensor indices - update_at_called = [] - original_update_at = TensorDictBase.update_at_ - - def mock_update_at(self, *args, **kwargs): - update_at_called.append(True) - return original_update_at(self, *args, **kwargs) - - # Extend with lazy stack and verify update_at_() is called - # (rb.extend uses tensor indices, so update_at_() path is taken) - with mock.patch.object(TensorDictBase, "update_at_", mock_update_at): - rb.extend(lazy_td) - - # Verify update_at_() was called (optimization was used) - assert len(update_at_called) > 0, "update_at_() should have been called" - - # Verify data integrity - assert len(rb) == 10 - sample = rb.sample(5) - assert 
sample["obs"].shape == (5, 4, 8) - assert sample["action"].shape == (5, 2) - - # Verify all data is accessible by reading the entire storage - all_data = rb[:] - assert all_data["obs"].shape == (10, 4, 8) - assert all_data["action"].shape == (10, 2) - - # Verify data values are preserved (check against original stacked data) - expected = lazy_td.to_tensordict() - assert torch.allclose(all_data["obs"], expected["obs"]) - assert torch.allclose(all_data["action"], expected["action"]) - - @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - def test_extend_lazystack_2d_storage(self, storage_type): - """Test lazy stack optimization for 2D storage (parallel envs). - - When using parallel environments, the storage is 2D [max_size, n_steps] - and the lazy stack has stack_dim=1 (time dimension). This test verifies - the optimization handles this case correctly. - """ - n_envs = 4 - n_steps = 10 - img_shape = (3, 32, 32) - - # Create 2D storage - capacity is 100 * n_steps when ndim=2 - storage = storage_type(100 * n_steps, ndim=2) - - # Pre-initialize storage with correct shape by setting first element - init_td = TensorDict( - {"pixels": torch.zeros(n_steps, *img_shape)}, - batch_size=[n_steps], - ) - storage.set(0, init_td, set_cursor=False) - - # Expand storage to full size - full_init = TensorDict( - {"pixels": torch.zeros(100, n_steps, *img_shape)}, - batch_size=[100, n_steps], - ) - storage.set(slice(0, 100), full_init, set_cursor=False) - - # Create lazy stack simulating parallel env output - # stack_dim=1 means stacked along time dimension - time_tds = [ - TensorDict( - {"pixels": torch.rand(n_envs, *img_shape)}, - batch_size=[n_envs], - ) - for _ in range(n_steps) - ] - lazy_td = LazyStackedTensorDict.lazy_stack(time_tds, dim=1) - assert lazy_td.stack_dim == 1 - assert lazy_td.batch_size == torch.Size([n_envs, n_steps]) - - # Write using tensor indices (simulating circular buffer behavior) - cursor = torch.tensor([0, 1, 2, 3]) - 
storage.set(cursor, lazy_td) - - # Verify data integrity - for i in range(n_envs): - stored = storage[i] - expected = lazy_td[i].to_tensordict() - assert torch.allclose( - stored["pixels"], expected["pixels"] - ), f"Data mismatch for env {i}" - - @pytest.mark.parametrize("device_data", get_default_devices()) - @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - @pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"]) - @pytest.mark.parametrize("isinit", [True, False]) - def test_storage_dumps_loads( - self, device_data, storage_type, data_type, isinit, tmpdir - ): - torch.manual_seed(0) - - dir_rb = tmpdir / "rb" - dir_save = tmpdir / "save" - dir_rb.mkdir() - dir_save.mkdir() - torch.manual_seed(0) - - @tensorclass - class TC: - tensor: torch.Tensor - td: TensorDict - text: str - - if data_type == "tensor": - data = torch.randint(10, (3,), device=device_data) - elif data_type == "pytree": - data = { - "a": torch.randint(10, (3,), device=device_data), - "b": {"c": [torch.ones(3), (-torch.ones(3, 2),)]}, - 30: -torch.ones(3, 1), - } - elif data_type == "td": - data = TensorDict( - { - "a": torch.randint(10, (3,), device=device_data), - "b": TensorDict( - {"c": torch.randint(10, (3,), device=device_data)}, - batch_size=[3], - ), - }, - batch_size=[3], - device=device_data, - ) - elif data_type == "tc": - data = TC( - tensor=torch.randint(10, (3,), device=device_data), - td=TensorDict( - {"c": torch.randint(10, (3,), device=device_data)}, batch_size=[3] - ), - text="some text", - batch_size=[3], - device=device_data, - ) - else: - raise NotImplementedError - - if storage_type in (LazyMemmapStorage,): - storage = storage_type(max_size=10, scratch_dir=dir_rb) - else: - storage = storage_type(max_size=10) - - # We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index - if data_type == "pytree": - storage.set(range(3), tree_map(lambda x: x.cpu(), data)) - else: - storage.set(range(3), 
data.cpu()) - - storage.dumps(dir_save) - # check we can dump twice - storage.dumps(dir_save) - - storage_recover = storage_type(max_size=10) - if isinit: - if data_type == "pytree": - storage_recover.set( - range(3), tree_map(lambda x: x.cpu().clone().zero_(), data) - ) - else: - storage_recover.set(range(3), data.cpu().clone().zero_()) - - if data_type in ("tensor", "pytree") and not isinit: - with pytest.raises( - RuntimeError, - match="Cannot fill a non-initialized pytree-based TensorStorage", - ): - storage_recover.loads(dir_save) - return - storage_recover.loads(dir_save) - # tree_map with more than one pytree is only available in torch >= 2.3 - if torch_2_3: - if data_type in ("tensor", "pytree"): - tree_map( - torch.testing.assert_close, - tree_flatten(storage[:])[0], - tree_flatten(storage_recover[:])[0], - ) - else: - assert_allclose_td(storage[:], storage_recover[:]) - if data == "tc": - assert storage._storage.text == storage_recover._storage.text - - def test_add_list_of_tds(self): - rb = ReplayBuffer(storage=LazyTensorStorage(100)) - rb.extend([TensorDict({"a": torch.randn(2, 3)}, [2])]) - assert len(rb) == 1 - assert rb[:].shape == torch.Size([1, 2]) - - @pytest.mark.parametrize( - "storage_type,collate_fn", - [ - (LazyTensorStorage, None), - (LazyMemmapStorage, None), - (ListStorage, torch.stack), - ], - ) - def test_storage_inplace_writing(self, storage_type, collate_fn): - rb = ReplayBuffer(storage=storage_type(102), collate_fn=collate_fn) - data = TensorDict( - {"a": torch.arange(100), ("b", "c"): torch.arange(100)}, [100] - ) - rb.extend(data) - assert len(rb) == 100 - rb[3:4] = TensorDict( - {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] - ) - assert (rb[3:4] == 0).all() - assert len(rb) == 100 - assert rb.writer._cursor == 100 - rb[10:20] = TensorDict( - {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] - ) - assert (rb[10:20] == 0).all() - assert len(rb) == 100 - assert rb.writer._cursor == 100 - 
rb[torch.arange(30, 40)] = TensorDict( - {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] - ) - assert (rb[30:40] == 0).all() - assert len(rb) == 100 - - @pytest.mark.parametrize( - "storage_type,collate_fn", - [ - (LazyTensorStorage, None), - (LazyMemmapStorage, None), - (ListStorage, torch.stack), - ], - ) - def test_storage_inplace_writing_transform(self, storage_type, collate_fn): - rb = ReplayBuffer(storage=storage_type(102), collate_fn=collate_fn) - rb.append_transform(lambda x: x + 1, invert=True) - rb.append_transform(lambda x: x + 1) - data = TensorDict( - {"a": torch.arange(100), ("b", "c"): torch.arange(100)}, [100] - ) - rb.extend(data) - assert len(rb) == 100 - rb[3:4] = TensorDict( - {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] - ) - assert (rb[3:4] == 2).all(), rb[3:4]["a"] - assert len(rb) == 100 - assert rb.writer._cursor == 100 - rb[10:20] = TensorDict( - {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] - ) - assert (rb[10:20] == 2).all() - assert len(rb) == 100 - assert rb.writer._cursor == 100 - rb[torch.arange(30, 40)] = TensorDict( - {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] - ) - assert (rb[30:40] == 2).all() - assert len(rb) == 100 - - @pytest.mark.parametrize( - "storage_type,collate_fn", - [ - (LazyTensorStorage, None), - # (LazyMemmapStorage, None), - (ListStorage, TensorDict.maybe_dense_stack), - ], - ) - def test_storage_inplace_writing_newkey(self, storage_type, collate_fn): - rb = ReplayBuffer(storage=storage_type(102), collate_fn=collate_fn) - data = TensorDict( - {"a": torch.arange(100), ("b", "c"): torch.arange(100)}, [100] - ) - rb.extend(data) - assert len(rb) == 100 - rb[3:4] = TensorDict( - {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0]), "d": torch.ones(1)}, - [1], - ) - assert "d" in rb[3] - assert "d" in rb[3:4] - if storage_type is not ListStorage: - assert "d" in rb[3:5] - else: - # a lazy stack doesn't show exclusive fields 
- assert "d" not in rb[3:5] - - @pytest.mark.parametrize("storage_type", [LazyTensorStorage, LazyMemmapStorage]) - def test_storage_inplace_writing_ndim(self, storage_type): - rb = ReplayBuffer(storage=storage_type(102, ndim=2)) - data = TensorDict( - { - "a": torch.arange(50).expand(2, 50), - ("b", "c"): torch.arange(50).expand(2, 50), - }, - [2, 50], - ) - rb.extend(data) - assert len(rb) == 100 - rb[0, 3:4] = TensorDict( - {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] - ) - assert (rb[0, 3:4] == 0).all() - assert (rb[1, 3:4] != 0).all() - assert rb.writer._cursor == 50 - rb[1, 5:6] = TensorDict( - {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] - ) - assert (rb[1, 5:6] == 0).all() - assert rb.writer._cursor == 50 - rb[:, 7:8] = TensorDict( - {"a": torch.tensor([0]), ("b", "c"): torch.tensor([0])}, [1] - ).expand(2, 1) - assert (rb[:, 7:8] == 0).all() - assert rb.writer._cursor == 50 - # test broadcasting - rb[:, 10:20] = TensorDict( - {"a": torch.tensor([0] * 10), ("b", "c"): torch.tensor([0] * 10)}, [10] - ) - assert (rb[:, 10:20] == 0).all() - assert len(rb) == 100 - - @pytest.mark.skipif( - TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" - ) - @pytest.mark.parametrize("max_size", [1000, None]) - @pytest.mark.parametrize("stack_dim", [-1, 0]) - def test_lazy_stack_storage(self, max_size, stack_dim): - # Create an instance of LazyStackStorage with given parameters - storage = LazyStackStorage(max_size=max_size, stack_dim=stack_dim) - # Create a ReplayBuffer using the created storage - rb = ReplayBuffer(storage=storage) - # Generate some random data to add to the buffer - torch.manual_seed(0) - data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!") - data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!") - # Add the data to the buffer - rb.add(data0) - rb.add(data1) - # Sample from the buffer - sample = rb.sample(10) - # Check that the sampled data has the correct 
shape and type - assert isinstance(sample, LazyStackedTensorDict) - assert sample["b"].shape[0] == 10 - assert all(isinstance(item, str) for item in sample["c"]) - # If densify is True, check that the sampled data is dense - sample = sample.densify(layout=torch.jagged) - assert isinstance(sample["a"], torch.Tensor) - assert sample["a"].shape[0] == 10 - - -@pytest.mark.parametrize("max_size", [1000]) -@pytest.mark.parametrize("shape", [[3, 4]]) -@pytest.mark.parametrize("storage", [LazyTensorStorage, LazyMemmapStorage]) -class TestLazyStorages: - def _get_nested_tensorclass(self, shape): - @tensorclass - class NestedTensorClass: - key1: torch.Tensor - key2: torch.Tensor - - @tensorclass - class TensorClass: - key1: torch.Tensor - key2: torch.Tensor - next: NestedTensorClass - - return TensorClass( - key1=torch.ones(*shape), - key2=torch.ones(*shape), - next=NestedTensorClass( - key1=torch.ones(*shape), key2=torch.ones(*shape), batch_size=shape - ), - batch_size=shape, - ) - - def _get_nested_td(self, shape): - nested_td = TensorDict( - { - "key1": torch.ones(*shape), - "key2": torch.ones(*shape), - "next": TensorDict( - { - "key1": torch.ones(*shape), - "key2": torch.ones(*shape), - }, - shape, - ), - }, - shape, - ) - return nested_td - - def test_init(self, max_size, shape, storage): - td = self._get_nested_td(shape) - mystorage = storage(max_size=max_size) - mystorage._init(td) - assert mystorage._storage.shape == (max_size, *shape) - - def test_set(self, max_size, shape, storage): - td = self._get_nested_td(shape) - mystorage = storage(max_size=max_size) - mystorage.set(list(range(td.shape[0])), td) - assert mystorage._storage.shape == (max_size, *shape[1:]) - idx = list(range(1, td.shape[0] - 1)) - tc_sample = mystorage.get(idx) - assert tc_sample.shape == torch.Size([td.shape[0] - 2, *td.shape[1:]]) - - def test_init_tensorclass(self, max_size, shape, storage): - tc = self._get_nested_tensorclass(shape) - mystorage = storage(max_size=max_size) - 
mystorage._init(tc) - assert is_tensorclass(mystorage._storage) - assert mystorage._storage.shape == (max_size, *shape) - - def test_set_tensorclass(self, max_size, shape, storage): - tc = self._get_nested_tensorclass(shape) - mystorage = storage(max_size=max_size) - mystorage.set(list(range(tc.shape[0])), tc) - assert mystorage._storage.shape == (max_size, *shape[1:]) - idx = list(range(1, tc.shape[0] - 1)) - tc_sample = mystorage.get(idx) - assert tc_sample.shape == torch.Size([tc.shape[0] - 2, *tc.shape[1:]]) - - def test_extend_list_pytree(self, max_size, shape, storage): - memory = ReplayBuffer( - storage=storage(max_size=max_size), - sampler=SamplerWithoutReplacement(), - ) - data = [ - ( - torch.full(shape, i), - {"a": torch.full(shape, i), "b": (torch.full(shape, i))}, - [torch.full(shape, i)], - ) - for i in range(10) - ] - memory.extend(data) - assert len(memory) == 10 - assert len(memory._storage) == 10 - sample = memory.sample(10) - for leaf in tree_iter(sample): - assert (leaf.unique(sorted=True) == torch.arange(10)).all() - memory = ReplayBuffer( - storage=storage(max_size=max_size), - sampler=SamplerWithoutReplacement(), - ) - t1x4 = torch.Tensor([0.1, 0.2, 0.3, 0.4]) - t1x1 = torch.Tensor([0.01]) - with pytest.raises( - RuntimeError, match="Stacking the elements of the list resulted in an error" - ): - memory.extend([t1x4, t1x1, t1x4 + 0.4, t1x1 + 0.01]) - - -@pytest.mark.parametrize("priority_key", ["pk", "td_error"]) -@pytest.mark.parametrize("contiguous", [True, False]) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("alpha", [0.0, 0.7]) -def test_ptdrb(priority_key, contiguous, alpha, device): - torch.manual_seed(0) - np.random.seed(0) - rb = TensorDictReplayBuffer( - sampler=samplers.PrioritizedSampler(5, alpha=alpha, beta=0.9), - priority_key=priority_key, - batch_size=5, - ) - td1 = TensorDict( - source={ - "a": torch.randn(3, 1), - priority_key: torch.rand(3, 1) / 10, - "_idx": torch.arange(3).view(3, 1), 
- }, - batch_size=[3], - device=device, - ) - rb.extend(td1) - s = rb.sample() - assert s.batch_size == torch.Size([5]) - assert (td1[s.get("_idx").squeeze()].get("a") == s.get("a")).all() - assert_allclose_td(td1[s.get("_idx").squeeze()].select("a"), s.select("a")) - - # test replacement - td2 = TensorDict( - source={ - "a": torch.randn(5, 1), - priority_key: torch.rand(5, 1) / 10, - "_idx": torch.arange(5).view(5, 1), - }, - batch_size=[5], - device=device, - ) - rb.extend(td2) - s = rb.sample() - assert s.batch_size == torch.Size([5]) - assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all() - assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a")) - - if ( - alpha == 0.0 - ): # when alpha is 0.0, sampling is uniform, so no need to check priority sampling - return - - # test strong update - # get all indices that match first item - idx = s.get("_idx") - idx_match = (idx == idx[0]).nonzero()[:, 0] - s.set_at_( - priority_key, - torch.ones(idx_match.numel(), 1, device=device) * 100000000, - idx_match, - ) - val = s.get("a")[0] - - idx0 = s.get("_idx")[0] - rb.update_tensordict_priority(s) - s = rb.sample() - assert (val == s.get("a")).sum() >= 1 - torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) - - # test updating values of original td - td2.set_("a", torch.ones_like(td2.get("a"))) - s = rb.sample() - torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) - - -@pytest.mark.gpu -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_cuda_segment_tree_parity(): - ext = pytest.importorskip("torchrl._torchrl") - if not hasattr(ext, "CudaSumSegmentTreeFp32"): - pytest.skip("TorchRL was not built with CUDA segment tree support") - CudaMinSegmentTreeFp32 = ext.CudaMinSegmentTreeFp32 - CudaSumSegmentTreeFp32 = ext.CudaSumSegmentTreeFp32 - MinSegmentTreeFp32 = ext.MinSegmentTreeFp32 - SumSegmentTreeFp32 = ext.SumSegmentTreeFp32 - - device = 
torch.device("cuda:0") - size = 16 - index = torch.tensor([0, 3, 4, 7, 12, 15], device=device) - value = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0, 32.0], device=device) - - cpu_sum = SumSegmentTreeFp32(size) - cpu_min = MinSegmentTreeFp32(size) - cuda_sum = CudaSumSegmentTreeFp32(size, device) - cuda_min = CudaMinSegmentTreeFp32(size, device) - - cpu_sum[index.cpu()] = value.cpu() - cpu_min[index.cpu()] = value.cpu() - cuda_sum[index] = value - cuda_min[index] = value - - left = torch.tensor([0, 3, 4, 7], device=device) - right = torch.tensor([16, 8, 13, 16], device=device) - torch.testing.assert_close( - cuda_sum.query(left, right).cpu(), cpu_sum.query(left.cpu(), right.cpu()) - ) - torch.testing.assert_close( - cuda_min.query(left, right).cpu(), cpu_min.query(left.cpu(), right.cpu()) - ) - - mass = torch.tensor([0.5, 1.0, 2.9, 7.1, 30.0], device=device) - torch.testing.assert_close( - cuda_sum.scan_lower_bound(mass).cpu(), - cpu_sum.scan_lower_bound(mass.cpu()), - ) - - -@pytest.mark.gpu -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_cuda_prioritized_replay_buffer_samples_on_cuda(): - ext = pytest.importorskip("torchrl._torchrl") - if not hasattr(ext, "CudaSumSegmentTreeFp32"): - pytest.skip("TorchRL was not built with CUDA segment tree support") - device = torch.device("cuda:0") - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(32, device=device), - sampler=PrioritizedSampler(max_capacity=32, alpha=0.7, beta=0.5), - batch_size=8, - priority_key="td_error", - ) - data = TensorDict( - { - "obs": torch.arange(16, device=device).float().unsqueeze(-1), - "td_error": torch.linspace(0.1, 1.0, 16, device=device), - }, - batch_size=[16], - device=device, - ) - - rb.extend(data) - sample = rb.sample() - - assert sample.device == device - assert sample["index"].device == device - assert sample["priority_weight"].device == device - - sample["td_error"] = torch.ones_like(sample["td_error"]) * 10 - 
rb.update_tensordict_priority(sample) - sample = rb.sample() - assert sample["index"].device == device - assert sample["priority_weight"].device == device - - -def test_tensordict_prioritized_replay_buffer_sampler_device_cpu(): - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - storage=LazyTensorStorage(32), - sampler_device="cpu", - batch_size=8, - priority_key="td_error", - ) - data = TensorDict( - { - "obs": torch.arange(16).float().unsqueeze(-1), - "td_error": torch.linspace(0.1, 1.0, 16), - }, - batch_size=[16], - ) - - rb.extend(data) - sample = rb.sample() - - assert rb._sampler.device == torch.device("cpu") - assert sample["index"].device == torch.device("cpu") - assert sample["priority_weight"].device == torch.device("cpu") - - -@pytest.mark.gpu -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_tensordict_prioritized_replay_buffer_memmap_storage_cuda_sampler(tmpdir): - ext = pytest.importorskip("torchrl._torchrl") - if not hasattr(ext, "CudaSumSegmentTreeFp32"): - pytest.skip("TorchRL was not built with CUDA segment tree support") - - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - storage=LazyMemmapStorage(32, scratch_dir=tmpdir), - sampler_device="cuda:0", - batch_size=8, - priority_key="td_error", - ) - data = TensorDict( - { - "obs": torch.arange(16).float().unsqueeze(-1), - "td_error": torch.linspace(0.1, 1.0, 16), - }, - batch_size=[16], - ) - - rb.extend(data) - sample = rb.sample() - - assert rb._sampler.device == torch.device("cuda:0") - assert sample["obs"].device.type == "cpu" - assert sample["index"].device.type == "cpu" - assert sample["priority_weight"].device.type == "cpu" - - sample["td_error"] = torch.ones_like(sample["td_error"]) * 10 - rb.update_tensordict_priority(sample) - sample = rb.sample() - assert sample["index"].device.type == "cpu" - assert rb._sampler.device == torch.device("cuda:0") - - -@pytest.mark.gpu -@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA is not available") -def test_tensordict_prioritized_replay_buffer_cuda_storage_cpu_sampler(): - device = torch.device("cuda:0") - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - storage=LazyTensorStorage(32, device=device), - sampler_device="cpu", - batch_size=8, - priority_key="td_error", - ) - data = TensorDict( - { - "obs": torch.arange(16, device=device).float().unsqueeze(-1), - "td_error": torch.linspace(0.1, 1.0, 16, device=device), - }, - batch_size=[16], - device=device, - ) - - rb.extend(data) - sample = rb.sample() - - assert rb._sampler.device == torch.device("cpu") - assert sample.device == device - assert sample["index"].device == device - assert sample["priority_weight"].device == device - - sample["td_error"] = torch.ones_like(sample["td_error"]) * 10 - rb.update_tensordict_priority(sample) - sample = rb.sample() - assert sample.device == device - assert rb._sampler.device == torch.device("cpu") - - -@pytest.mark.gpu -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") -def test_cuda_prioritized_replay_buffer_weight_matches_cpu_formula(): - ext = pytest.importorskip("torchrl._torchrl") - if not hasattr(ext, "CudaSumSegmentTreeFp32"): - pytest.skip("TorchRL was not built with CUDA segment tree support") - - size = 64 - batch_size = 16 - alpha = 0.7 - beta = 0.5 - eps = 1e-8 - priorities = torch.linspace(0.1, 2.0, size) - expected_tree_priority = (priorities + eps).pow(alpha) - min_tree_priority = expected_tree_priority.min() - - def make_rb(device): - device = torch.device(device) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(size, device=device), - sampler=PrioritizedSampler( - max_capacity=size, - alpha=alpha, - beta=beta, - eps=eps, - device=device, - ), - batch_size=batch_size, - priority_key="td_error", - ) - data = TensorDict( - { - "obs": torch.arange(size, device=device), - "td_error": priorities.to(device), - }, - batch_size=[size], - 
device=device, - ) - rb.extend(data) - return rb - - cpu_rb = make_rb("cpu") - cuda_rb = make_rb("cuda:0") - - for rb, device in ( - (cpu_rb, torch.device("cpu")), - (cuda_rb, torch.device("cuda:0")), - ): - for _ in range(8): - sample = rb.sample() - index = sample["index"].to("cpu") - expected_weight = (expected_tree_priority[index] / min_tree_priority).pow( - -beta - ) - torch.testing.assert_close(sample["obs"].to("cpu"), index) - torch.testing.assert_close(sample["td_error"].to("cpu"), priorities[index]) - torch.testing.assert_close( - sample["priority_weight"].to("cpu"), expected_weight - ) - assert sample["index"].device == device - assert sample["priority_weight"].device == device - - -@pytest.mark.parametrize("stack", [False, True]) -@pytest.mark.parametrize("datatype", ["tc", "tb"]) -@pytest.mark.parametrize("reduction", ["min", "max", "median", "mean"]) -def test_replay_buffer_trajectories(stack, reduction, datatype): - traj_td = TensorDict( - {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)}, - batch_size=[3, 4], - ) - rbcls = functools.partial(TensorDictReplayBuffer, priority_key="td_error") - if datatype == "tc": - c = make_tc(traj_td) - rbcls = functools.partial(ReplayBuffer, storage=LazyTensorStorage(100)) - traj_td = c(**traj_td, batch_size=traj_td.batch_size) - assert is_tensorclass(traj_td) - elif datatype != "tb": - raise NotImplementedError - - if stack: - traj_td = torch.stack(list(traj_td), 0) - - rb = rbcls( - sampler=samplers.PrioritizedSampler( - 5, - alpha=0.7, - beta=0.9, - reduction=reduction, - ), - batch_size=3, - ) - rb.extend(traj_td) - if datatype == "tc": - sampled_td, info = rb.sample(return_info=True) - index = info["index"] - else: - sampled_td = rb.sample() - if datatype == "tc": - assert is_tensorclass(traj_td) - return - - sampled_td.set("td_error", torch.rand(sampled_td.shape)) - if datatype == "tc": - rb.update_priority(index, sampled_td) - sampled_td, info = rb.sample(return_info=True) - assert 
(info["priority_weight"] > 0).all() - assert sampled_td.batch_size == torch.Size([3, 4]) - else: - rb.update_tensordict_priority(sampled_td) - sampled_td = rb.sample(include_info=True) - assert (sampled_td.get("priority_weight") > 0).all() - assert sampled_td.batch_size == torch.Size([3, 4]) - - # # set back the trajectory length - # sampled_td_filtered = sampled_td.to_tensordict().exclude( - # "priority_weight", "index", "td_error" - # ) - # sampled_td_filtered.batch_size = [3, 4] - - -class TestRNG: - def test_rb_rng(self): - state = torch.random.get_rng_state() - rb = ReplayBufferRNG( - sampler=RandomSampler(), storage=LazyTensorStorage(100), delayed_init=False - ) - assert rb.initialized - rb.extend(torch.arange(100)) - rb._rng.set_state(state) - a = rb.sample(32) - rb._rng.set_state(state) - b = rb.sample(32) - assert (a == b).all() - c = rb.sample(32) - assert (a != c).any() - - def test_prb_rng(self): - state = torch.random.get_rng_state() - rb = ReplayBuffer( - sampler=PrioritizedSampler(100, 1.0, 1.0), - storage=LazyTensorStorage(100), - generator=torch.Generator(), - ) - rb.extend(torch.arange(100)) - rb.update_priority(index=torch.arange(100), priority=torch.arange(1, 101)) - - rb._rng.set_state(state) - a = rb.sample(32) - - rb._rng.set_state(state) - b = rb.sample(32) - assert (a == b).all() - - c = rb.sample(32) - assert (a != c).any() - - def test_slice_rng(self): - state = torch.random.get_rng_state() - rb = ReplayBuffer( - sampler=SliceSampler(num_slices=4), - storage=LazyTensorStorage(100), - generator=torch.Generator(), - ) - done = torch.zeros(100, 1, dtype=torch.bool) - done[49] = 1 - done[-1] = 1 - data = TensorDict( - { - "data": torch.arange(100), - ("next", "done"): done, - }, - batch_size=[100], - ) - rb.extend(data) - - rb._rng.set_state(state) - a = rb.sample(32) - - rb._rng.set_state(state) - b = rb.sample(32) - assert (a == b).all() - - c = rb.sample(32) - assert (a != c).any() - - def test_rng_state_dict(self): - state = 
torch.random.get_rng_state() - rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100)) - rb.extend(torch.arange(100)) - rb._rng.set_state(state) - sd = rb.state_dict() - assert sd.get("_rng") is not None - a = rb.sample(32) - - rb.load_state_dict(sd) - b = rb.sample(32) - assert (a == b).all() - c = rb.sample(32) - assert (a != c).any() - - def test_rng_dumps(self, tmpdir): - state = torch.random.get_rng_state() - rb = ReplayBufferRNG(sampler=RandomSampler(), storage=LazyTensorStorage(100)) - rb.extend(torch.arange(100)) - rb._rng.set_state(state) - rb.dumps(tmpdir) - a = rb.sample(32) - - rb.loads(tmpdir) - b = rb.sample(32) - assert (a == b).all() - c = rb.sample(32) - assert (a != c).any() - - -@pytest.mark.parametrize( - "rbtype,storage", - [ - (ReplayBuffer, None), - (ReplayBuffer, ListStorage), - (ReplayBufferRNG, None), - (ReplayBufferRNG, ListStorage), - (PrioritizedReplayBuffer, None), - (PrioritizedReplayBuffer, ListStorage), - (TensorDictReplayBuffer, None), - (TensorDictReplayBuffer, ListStorage), - (TensorDictReplayBuffer, LazyTensorStorage), - (TensorDictReplayBuffer, LazyMemmapStorage), - (TensorDictReplayBufferRNG, None), - (TensorDictReplayBufferRNG, ListStorage), - (TensorDictReplayBufferRNG, LazyTensorStorage), - (TensorDictReplayBufferRNG, LazyMemmapStorage), - (TensorDictPrioritizedReplayBuffer, None), - (TensorDictPrioritizedReplayBuffer, ListStorage), - (TensorDictPrioritizedReplayBuffer, LazyTensorStorage), - (TensorDictPrioritizedReplayBuffer, LazyMemmapStorage), - ], -) -@pytest.mark.parametrize("size", [3, 5, 100]) -@pytest.mark.parametrize("prefetch", [0]) -class TestBuffers: - default_constr = { - ReplayBuffer: ReplayBuffer, - PrioritizedReplayBuffer: functools.partial( - PrioritizedReplayBuffer, alpha=0.8, beta=0.9 - ), - TensorDictReplayBuffer: TensorDictReplayBuffer, - TensorDictPrioritizedReplayBuffer: functools.partial( - TensorDictPrioritizedReplayBuffer, alpha=0.8, beta=0.9 - ), - TensorDictReplayBufferRNG: 
TensorDictReplayBufferRNG, - ReplayBufferRNG: ReplayBufferRNG, - } - - def _get_rb(self, rbtype, size, storage, prefetch): - if storage is not None: - storage = storage(size) - rb = self.default_constr[rbtype]( - storage=storage, prefetch=prefetch, batch_size=3 - ) - return rb - - def _get_datum(self, rbtype): - if rbtype in (ReplayBuffer, ReplayBufferRNG): - data = torch.randint(100, (1,)) - elif rbtype is PrioritizedReplayBuffer: - data = torch.randint(100, (1,)) - elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG): - data = TensorDict({"a": torch.randint(100, (1,))}, []) - elif rbtype is TensorDictPrioritizedReplayBuffer: - data = TensorDict({"a": torch.randint(100, (1,))}, []) - else: - raise NotImplementedError(rbtype) - return data - - def _get_data(self, rbtype, size): - if rbtype in (ReplayBuffer, ReplayBufferRNG): - data = [torch.randint(100, (1,)) for _ in range(size)] - elif rbtype is PrioritizedReplayBuffer: - data = [torch.randint(100, (1,)) for _ in range(size)] - elif rbtype in (TensorDictReplayBuffer, TensorDictReplayBufferRNG): - data = TensorDict( - { - "a": torch.randint(100, (size,)), - "b": TensorDict({"c": torch.randint(100, (size,))}, [size]), - }, - [size], - ) - elif rbtype is TensorDictPrioritizedReplayBuffer: - data = TensorDict( - { - "a": torch.randint(100, (size,)), - "b": TensorDict({"c": torch.randint(100, (size,))}, [size]), - }, - [size], - ) - else: - raise NotImplementedError(rbtype) - return data - - def test_cursor_position2(self, rbtype, storage, size, prefetch): - torch.manual_seed(0) - rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) - batch1 = self._get_data(rbtype, size=5) - cond = ( - OLD_TORCH and size < len(batch1) and isinstance(rb.storage, TensorStorage) - ) - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - rb.extend(batch1) - - # Added fewer data than storage 
max size - if size > 5 or storage is None: - assert rb.writer._cursor == 5 - # Added more data than storage max size - elif size < 5: - assert rb.writer._cursor == 5 - size - # Added as data as storage max size - else: - assert rb.writer._cursor == 0 - batch2 = self._get_data(rbtype, size=size - 1) - rb.extend(batch2) - assert rb.writer._cursor == size - 1 - - def test_add(self, rbtype, storage, size, prefetch): - torch.manual_seed(0) - rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) - data = self._get_datum(rbtype) - rb.add(data) - s = rb.sample(1)[0] - if isinstance(s, TensorDictBase): - s = s.select(*data.keys(True), strict=False) - data = data.select(*s.keys(True), strict=False) - assert (s == data).all() - assert list(s.keys(True, True)) - else: - assert (s == data).all() - - def test_empty(self, rbtype, storage, size, prefetch): - torch.manual_seed(0) - rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) - data = self._get_datum(rbtype) - for _ in range(2): - rb.add(data) - s = rb.sample(1)[0] - if isinstance(s, TensorDictBase): - s = s.select(*data.keys(True), strict=False) - data = data.select(*s.keys(True), strict=False) - assert (s == data).all() - assert list(s.keys(True, True)) - else: - assert (s == data).all() - rb.empty() - with pytest.raises( - RuntimeError, match="Cannot sample from an empty storage" - ): - rb.sample() - - def test_extend(self, rbtype, storage, size, prefetch): - torch.manual_seed(0) - rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) - data = self._get_data(rbtype, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - rb.extend(data) - length = len(rb) - for d in data[-length:]: - for b in rb.storage: - if isinstance(b, TensorDictBase): - keys = 
set(d.keys()).intersection(b.keys()) - b = b.exclude("index").select(*keys, strict=False) - keys = set(d.keys()).intersection(b.keys()) - d = d.select(*keys, strict=False) - - value = b == d - if isinstance(value, (torch.Tensor, TensorDictBase)): - value = value.all() - if value: - break - else: - raise RuntimeError("did not find match") - - def test_sample(self, rbtype, storage, size, prefetch): - torch.manual_seed(0) - rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) - data = self._get_data(rbtype, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - rb.extend(data) - new_data = rb.sample() - if not isinstance(new_data, (torch.Tensor, TensorDictBase)): - new_data = new_data[0] - - for d in new_data: - for b in data: - if isinstance(b, TensorDictBase): - keys = set(d.keys()).intersection(b.keys()) - b = b.exclude("index").select(*keys, strict=False) - keys = set(d.keys()).intersection(b.keys()) - d = d.select(*keys, strict=False) - - value = b == d - if isinstance(value, (torch.Tensor, TensorDictBase)): - value = value.all() - if value: - break - else: - raise RuntimeError("did not find matching value") - - def test_index(self, rbtype, storage, size, prefetch): - torch.manual_seed(0) - rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) - data = self._get_data(rbtype, size=5) - cond = OLD_TORCH and size < len(data) and isinstance(rb.storage, TensorStorage) - with ( - pytest.warns( - UserWarning, - match="A cursor of length superior to the storage capacity was provided", - ) - if cond - else contextlib.nullcontext() - ): - rb.extend(data) - d1 = rb[2] - d2 = rb.storage[2] - if type(d1) is not type(d2): - d1 = d1[0] - b = d1 == d2 - if not isinstance(b, bool): - b = b.all() - assert b - - def test_index_nonfull(self, 
rbtype, storage, size, prefetch): - # checks that indexing the buffer before it's full gives the accurate view of the data - rb = self._get_rb(rbtype, storage=storage, size=size, prefetch=prefetch) - data = self._get_data(rbtype, size=size - 1) - rb.extend(data) - assert len(rb[: size - 1]) == size - 1 - assert len(rb[size - 2 :]) == 1 - - -def test_replay_buffer_set_at_(): - """Tests that set_at_ writes through to storage in-place.""" - rb = ReplayBuffer( - storage=LazyTensorStorage(10), - batch_size=5, - ) - data = TensorDict({"a": torch.zeros(10), "b": torch.ones(10)}, batch_size=[10]) - rb.extend(data) - # Modify key "a" at indices [2, 5] - rb.set_at_("a", torch.tensor([99.0, 99.0]), torch.tensor([2, 5])) - assert rb["a"][2] == 99.0 - assert rb["a"][5] == 99.0 - assert rb["a"][0] == 0.0 # unchanged - assert rb["b"][2] == 1.0 # other key unchanged - - -def test_replay_buffer_set_(): - """Tests that set_ writes through to storage in-place.""" - rb = ReplayBuffer( - storage=LazyTensorStorage(10), - batch_size=5, - ) - data = TensorDict({"a": torch.zeros(10), "b": torch.ones(10)}, batch_size=[10]) - rb.extend(data) - rb.set_("a", torch.full((10,), 42.0)) - assert (rb["a"] == 42.0).all() - assert (rb["b"] == 1.0).all() # other key unchanged - - -def test_replay_buffer_update_(): - """Tests that update_ writes through to storage in-place.""" - rb = ReplayBuffer( - storage=LazyTensorStorage(10), - batch_size=5, - ) - data = TensorDict({"a": torch.zeros(10), "b": torch.ones(10)}, batch_size=[10]) - rb.extend(data) - update = TensorDict( - {"a": torch.full((10,), 7.0), "b": torch.full((10,), 8.0)}, - batch_size=[10], - ) - rb.update_(update) - assert (rb["a"] == 7.0).all() - assert (rb["b"] == 8.0).all() - - -def test_multi_loops(): - """Tests that one can iterate multiple times over a buffer without rep.""" - rb = ReplayBuffer( - batch_size=5, storage=ListStorage(10), sampler=SamplerWithoutReplacement() - ) - rb.extend(torch.zeros(10)) - for i, d in enumerate(rb): # 
noqa: B007 - assert (d == 0).all() - assert i == 1 - for i, d in enumerate(rb): # noqa: B007 - assert (d == 0).all() - assert i == 1 - - -def test_batch_errors(): - """Tests error messages related to batch-size""" - rb = ReplayBuffer( - storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=False) - ) - rb.extend(torch.zeros(10)) - rb.sample(3) # that works - with pytest.raises( - RuntimeError, - match="Cannot iterate over the replay buffer. Batch_size was not specified", - ): - for _ in rb: - pass - with pytest.raises(RuntimeError, match="batch_size not specified"): - rb.sample() - with pytest.raises(ValueError, match="Samplers with drop_last=True"): - ReplayBuffer( - storage=ListStorage(10), sampler=SamplerWithoutReplacement(drop_last=True) - ) - # that works - ReplayBuffer( - storage=ListStorage(10), - ) - rb = ReplayBuffer( - storage=ListStorage(10), - sampler=SamplerWithoutReplacement(drop_last=False), - batch_size=3, - ) - rb.extend(torch.zeros(10)) - for _ in rb: - pass - rb.sample() - - -def test_storage_save_hook(tmpdir): - observed = {} - - class SaveHook: - shift = None - is_full = None - - def __call__(self, data, path=None): - observed["shift"] = self.shift - observed["is_full"] = self.is_full - return data - - hook = SaveHook() - rb = ReplayBuffer(storage=LazyMemmapStorage(10)) - rb.register_save_hook(hook) - rb.extend(torch.arange(5)) - rb.dumps(tmpdir) - - assert hook.shift == 5, f"Expected shift=5, got {hook.shift}" - assert hook.is_full is False, f"Expected is_full=False, got {hook.is_full}" - assert observed["shift"] == 5 - assert observed["is_full"] is False - - -@pytest.mark.skipif(not torchrl._utils.RL_WARNINGS, reason="RL_WARNINGS is not set") -def test_add_warning(): - if not rl_warnings(): - return - rb = ReplayBuffer(storage=ListStorage(10), batch_size=3) - with pytest.warns( - UserWarning, - match=r"Using `add\(\)` with a TensorDict that has batch_size", - ): - rb.add(TensorDict(batch_size=[1])) - - 
-@pytest.mark.parametrize("priority_key", ["pk", "td_error"]) -@pytest.mark.parametrize("contiguous", [True, False]) -@pytest.mark.parametrize("device", get_default_devices()) -def test_prb(priority_key, contiguous, device): - torch.manual_seed(0) - np.random.seed(0) - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.9, - priority_key=priority_key, - storage=ListStorage(5), - batch_size=5, - ) - td1 = TensorDict( - source={ - "a": torch.randn(3, 1), - priority_key: torch.rand(3, 1) / 10, - "_idx": torch.arange(3).view(3, 1), - }, - batch_size=[3], - ).to(device) - - rb.extend(td1) - s = rb.sample() - assert s.batch_size == torch.Size([5]) - assert (td1[s.get("_idx").squeeze()].get("a") == s.get("a")).all() - assert_allclose_td(td1[s.get("_idx").squeeze()].select("a"), s.select("a")) - - # test replacement - td2 = TensorDict( - source={ - "a": torch.randn(5, 1), - priority_key: torch.rand(5, 1) / 10, - "_idx": torch.arange(5).view(5, 1), - }, - batch_size=[5], - ).to(device) - rb.extend(td2) - s = rb.sample() - assert s.batch_size == torch.Size([5]) - assert (td2[s.get("_idx").squeeze()].get("a") == s.get("a")).all() - assert_allclose_td(td2[s.get("_idx").squeeze()].select("a"), s.select("a")) - - # test strong update - # get all indices that match first item - idx = s.get("_idx") - idx_match = (idx == idx[0]).nonzero()[:, 0] - s.set_at_( - priority_key, - torch.ones(idx_match.numel(), 1, device=device) * 100000000, - idx_match, - ) - val = s.get("a")[0] - - idx0 = s.get("_idx")[0] - rb.update_tensordict_priority(s) - s = rb.sample() - assert (val == s.get("a")).sum() >= 1 - torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) - - # test updating values of original td - td2.set_("a", torch.ones_like(td2.get("a"))) - s = rb.sample() - torch.testing.assert_close(td2[idx0].get("a").view(1), s.get("a").unique().view(1)) - - -@pytest.mark.parametrize("stack", [False, True]) -@pytest.mark.parametrize("reduction", ["min", "max", 
"mean", "median"]) -def test_rb_trajectories(stack, reduction): - traj_td = TensorDict( - {"obs": torch.randn(3, 4, 5), "actions": torch.randn(3, 4, 2)}, - batch_size=[3, 4], - ) - if stack: - traj_td = torch.stack([td.to_tensordict() for td in traj_td], 0) - - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.9, - priority_key="td_error", - storage=ListStorage(5), - batch_size=3, - ) - rb.extend(traj_td) - sampled_td = rb.sample() - sampled_td.set("td_error", torch.rand(3, 4)) - rb.update_tensordict_priority(sampled_td) - sampled_td = rb.sample(include_info=True) - assert (sampled_td.get("priority_weight") > 0).all() - assert sampled_td.batch_size == torch.Size([3, 4]) - - # set back the trajectory length - sampled_td_filtered = sampled_td.to_tensordict().exclude( - "priority_weight", "index", "td_error" - ) - sampled_td_filtered.batch_size = [3, 4] - - -def test_shared_storage_prioritized_sampler(): - n = 100 - - storage = LazyMemmapStorage(n) - writer = RoundRobinWriter() - sampler0 = RandomSampler() - sampler1 = PrioritizedSampler(max_capacity=n, alpha=0.7, beta=1.1) - - rb0 = ReplayBuffer(storage=storage, writer=writer, sampler=sampler0, batch_size=10) - rb1 = ReplayBuffer(storage=storage, writer=writer, sampler=sampler1, batch_size=10) - - data = TensorDict({"a": torch.arange(50)}, [50]) - - # Extend rb0. rb1 should be aware of changes to storage. 
- rb0.extend(data) - - assert len(rb0) == 50 - assert len(storage) == 50 - assert len(rb1) == 50 - - rb0.sample() - rb1.sample() - - assert rb1._sampler._sum_tree.query(0, 10) == 10 - assert rb1._sampler._sum_tree.query(0, 50) == 50 - assert rb1._sampler._sum_tree.query(0, 70) == 50 - - -class TestTransforms: - def test_append_transform(self): - rb = ReplayBuffer(collate_fn=lambda x: torch.stack(x, 0), batch_size=1) - td = TensorDict( - { - "observation": torch.randn(2, 4, 3, 16), - "observation2": torch.randn(2, 4, 3, 16), - }, - [], - ) - rb.add(td) - flatten = CatTensors( - in_keys=["observation", "observation2"], out_key="observation_cat" - ) - - rb.append_transform(flatten) - - sampled = rb.sample() - assert sampled.get("observation_cat").shape[-1] == 32 - - def test_init_transform(self): - flatten = FlattenObservation( - -2, -1, in_keys=["observation"], out_keys=["flattened"] - ) - - rb = ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 - ) - - td = TensorDict({"observation": torch.randn(2, 4, 3, 16)}, []) - rb.add(td) - sampled = rb.sample() - assert sampled.get("flattened").shape[-1] == 48 - - def test_insert_transform(self): - flatten = FlattenObservation( - -2, -1, in_keys=["observation"], out_keys=["flattened"] - ) - rb = ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=flatten, batch_size=1 - ) - td = TensorDict({"observation": torch.randn(2, 4, 3, 16, 1)}, []) - rb.add(td) - - rb.insert_transform(0, SqueezeTransform(-1, in_keys=["observation"])) - - sampled = rb.sample() - assert sampled.get("flattened").shape[-1] == 48 - - with pytest.raises(ValueError): - rb.insert_transform(10, SqueezeTransform(-1, in_keys=["observation"])) - - transforms = [ - ToTensorImage, - pytest.param( - partial(RewardClipping, clamp_min=0.1, clamp_max=0.9), id="RewardClipping" - ), - BinarizeReward, - pytest.param( - partial(Resize, w=2, h=2), - id="Resize", - marks=pytest.mark.skipif( - not _has_tv, reason="needs 
torchvision dependency" - ), - ), - pytest.param( - partial(CenterCrop, w=1), - id="CenterCrop", - marks=pytest.mark.skipif( - not _has_tv, reason="needs torchvision dependency" - ), - ), - pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), - pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), - GrayScale, - pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"), - pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), - pytest.param(partial(RewardScaling, loc=1, scale=2), id="RewardScaling"), - DoubleToFloat, - VecNorm, - ] - - @pytest.mark.parametrize("transform", transforms) - def test_smoke_replay_buffer_transform(self, transform): - rb = TensorDictReplayBuffer( - transform=transform(in_keys=["observation"]), batch_size=1 - ) - - # td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 1), "action": torch.randn(3)}, []) - td = TensorDict({"observation": torch.randn(3, 3, 3, 16, 3)}, []) - rb.add(td) - - m = mock.Mock() - m.side_effect = [td.unsqueeze(0)] - rb._transform.forward = m - # rb._transform.__len__ = lambda *args: 3 - rb.sample() - assert rb._transform.forward.called - - # was_called = [False] - # forward = rb._transform.forward - # def new_forward(*args, **kwargs): - # was_called[0] = True - # return forward(*args, **kwargs) - # rb._transform.forward = new_forward - # rb.sample() - # assert was_called[0] - - transforms2 = [ - partial(DiscreteActionProjection, num_actions_effective=1, max_actions=3), - FiniteTensorDictCheck, - gSDENoise, - PinMemoryTransform, - ] - - @pytest.mark.parametrize("transform", transforms2) - def test_smoke_replay_buffer_transform_no_inkeys(self, transform): - if transform == PinMemoryTransform and not torch.cuda.is_available(): - raise pytest.skip("No CUDA device detected, skipping PinMemory") - rb = ReplayBuffer( - collate_fn=lambda x: torch.stack(x, 0), transform=transform(), batch_size=1 - ) - - action = torch.zeros(3) - action[..., 0] = 1 - td = 
TensorDict( - {"observation": torch.randn(3, 3, 3, 16, 1), "action": action}, [] - ) - rb.add(td) - rb.sample() - - rb._transform = mock.MagicMock() - rb._transform.__len__ = lambda *args: 3 - rb.sample() - assert rb._transform.called - - @pytest.mark.parametrize("at_init", [True, False]) - def test_transform_nontensor(self, at_init): - def t(x): - return tree_map(lambda y: y * 0, x) - - if at_init: - rb = ReplayBuffer(storage=LazyMemmapStorage(100), transform=t) - else: - rb = ReplayBuffer(storage=LazyMemmapStorage(100)) - rb.append_transform(t) - data = { - "a": torch.randn(3), - "b": {"c": (torch.zeros(2), [torch.ones(1)])}, - 30: -torch.ones(()), - } - rb.add(data) - - def assert0(x): - assert (x == 0).all() - - s = rb.sample(10) - tree_map(assert0, s) - - def test_transform_inv(self): - rb = ReplayBuffer(storage=LazyMemmapStorage(10), batch_size=4) - data = TensorDict({"a": torch.zeros(10)}, [10]) - - def t(data): - data += 1 - return data - - rb.append_transform(t, invert=True) - rb.extend(data) - assert (data == 1).all() - - -@pytest.mark.parametrize("size", [10, 15, 20]) -@pytest.mark.parametrize("samples", [5, 9, 11, 14, 16]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_samplerwithoutrep(size, samples, drop_last): - torch.manual_seed(0) - storage = ListStorage(size) - storage.set(range(size), range(size)) - assert len(storage) == size - sampler = SamplerWithoutReplacement(drop_last=drop_last) - visited = False - for _ in range(10): - _n_left = ( - sampler._sample_list.numel() if sampler._sample_list is not None else size - ) - if samples > size and drop_last: - with pytest.raises( - ValueError, - match=r"The batch size .* is greater than the storage capacity", - ): - idx, _ = sampler.sample(storage, samples) - break - idx, _ = sampler.sample(storage, samples) - if drop_last or _n_left >= samples: - assert idx.numel() == samples - assert idx.unique().numel() == idx.numel() - else: - assert idx.numel() == _n_left - visited = True - if not 
drop_last and (size % samples > 0): - assert visited - else: - assert not visited - - -@pytest.mark.parametrize("size", [10, 15, 20]) -@pytest.mark.parametrize("drop_last", [True, False]) -def test_replay_buffer_iter(size, drop_last): - torch.manual_seed(0) - storage = ListStorage(size) - sampler = SamplerWithoutReplacement(drop_last=drop_last) - writer = RoundRobinWriter() - - rb = ReplayBuffer(storage=storage, sampler=sampler, writer=writer, batch_size=3) - rb.extend([torch.randint(100, (1,)) for _ in range(size)]) - - for i, _ in enumerate(rb): - if i == 20: - # guard against infinite loop if error is introduced - raise RuntimeError("Iteration didn't terminate") - - if drop_last: - assert i == size // 3 - 1 - else: - assert i == (size - 1) // 3 - - -class TestNextStateReconstructor: - """Tests for :class:`~torchrl.envs.transforms.NextStateReconstructor`.""" - - _DEFAULT_TRAJ_KEY = ("collector", "traj_ids") - - @classmethod - def _make_data( - cls, - n_traj=3, - traj_len=4, - obs_dim=2, - traj_key: tuple | str | None = None, - ): - if traj_key is None: - traj_key = cls._DEFAULT_TRAJ_KEY - n = n_traj * traj_len - obs = torch.arange(n * obs_dim, dtype=torch.float32).reshape(n, obs_dim) - done = torch.zeros(n, 1, dtype=torch.bool) - done[traj_len - 1 :: traj_len] = True - traj_ids = torch.repeat_interleave(torch.arange(n_traj), traj_len) - return TensorDict( - { - "observation": obs, - ("next", "done"): done, - ("next", "reward"): torch.zeros(n, 1), - traj_key: traj_ids, - }, - batch_size=[n], - ) - - def test_slice_sampler_default(self): - """With ``SliceSampler`` + default ``traj_key``, slices mirror cleanly.""" - data = self._make_data(n_traj=3, traj_len=4) - rb = ReplayBuffer( - storage=LazyTensorStorage(12), - sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), - transform=NextStateReconstructor(), - batch_size=8, - ) - rb.extend(data) - sample = rb.sample() - assert sample.batch_size == torch.Size([8]) - next_obs = sample.get(("next", 
"observation")) - root_obs = sample.get("observation") - traj = sample.get(self._DEFAULT_TRAJ_KEY) - # Within each slice (4 entries), positions 0..2 mirror to 1..3 of the same traj. - for slice_start in (0, 4): - assert (traj[slice_start : slice_start + 4] == traj[slice_start]).all() - for i in range(3): - torch.testing.assert_close( - next_obs[slice_start + i], root_obs[slice_start + i + 1] - ) - # Last position of each slice belongs to a different trajectory - # in the (i, i+1) pair (or has no i+1 at all) → NaN. - assert torch.isnan(next_obs[slice_start + 3]).all() - - def test_single_trajectory_full_batch(self): - """Whole trajectory as one batch: every transition reconstructed, last NaN.""" - n = 6 - td = TensorDict( - { - "observation": torch.arange(n, dtype=torch.float32).view(n, 1), - self._DEFAULT_TRAJ_KEY: torch.zeros(n, dtype=torch.long), - # No terminal in the middle; explicit final done for completeness. - ("next", "done"): torch.tensor([[False]] * (n - 1) + [[True]]), - }, - batch_size=[n], - ) - out = NextStateReconstructor()(td) - next_obs = out.get(("next", "observation")) - torch.testing.assert_close(next_obs[:-1], td.get("observation")[1:]) - assert torch.isnan(next_obs[-1]).all() - - def test_done_catches_slice_repetition(self): - """SliceSampler can place two slices of the same trajectory in one batch. - - Trajectory ids match across the splice; ``done`` at the slice end of the - first copy disambiguates. Without the done check, the first slice's - last position would silently borrow the *second slice's first frame* - (same trajectory, but not its temporal successor) and the user would - never know. 
- """ - n = 8 # two identical trajectories of length 4, glued together - obs = torch.tensor([[0.0], [1.0], [2.0], [3.0]] * 2, dtype=torch.float32) - td = TensorDict( - { - "observation": obs, - self._DEFAULT_TRAJ_KEY: torch.tensor([0] * 8), # all same id - ("next", "done"): torch.tensor([[False], [False], [False], [True]] * 2), - }, - batch_size=[n], - ) - out = NextStateReconstructor()(td) - next_obs = out.get(("next", "observation")) - # Position 3: traj id matches position 4, but done[3]=True → NaN - assert torch.isnan(next_obs[3]).all() - # Positions 0..2 mirror to 1..3 - torch.testing.assert_close(next_obs[:3], obs[1:4]) - # Positions 4..6 mirror to 5..7 - torch.testing.assert_close(next_obs[4:7], obs[5:8]) - # Position 7: no i+1 → NaN - assert torch.isnan(next_obs[7]).all() - - def test_random_sampler_is_mostly_nan(self): - """Random sampling yields mismatched traj ids between neighbors → NaN. - - Documents the honest failure mode: when the user picks a sampler that - doesn't preserve trajectory adjacency, the transform refuses to invent - a next observation. - """ - data = self._make_data(n_traj=8, traj_len=4) # 32 entries - rb = ReplayBuffer( - storage=LazyTensorStorage(32), - sampler=RandomSampler(), - transform=NextStateReconstructor(), - batch_size=16, - ) - rb.extend(data) - torch.manual_seed(0) - sample = rb.sample() - next_obs = sample.get(("next", "observation")) - # With 8 trajectories random-sampled into a 16-batch, the chance that - # two adjacent picks share a trajectory id (≈ 1/8) is low. Assert that - # the *vast majority* of positions are NaN — both that the check is - # firing and that we aren't accidentally fabricating next obs. 
- nan_frac = torch.isnan(next_obs).all(dim=-1).float().mean().item() - assert nan_frac > 0.7, f"expected mostly-NaN, got nan_frac={nan_frac:.2f}" - - def test_nested_keys(self): - n = 8 - td = TensorDict( - { - "agents": TensorDict( - { - "pos": torch.arange(n * 3, dtype=torch.float32).reshape(n, 3), - "vel": torch.arange(n * 2, dtype=torch.float32).reshape(n, 2), - }, - [n], - ), - ("next", "done"): torch.tensor([[False], [False], [False], [True]] * 2), - ("next", "reward"): torch.zeros(n, 1), - self._DEFAULT_TRAJ_KEY: torch.tensor([0] * 4 + [1] * 4), - }, - batch_size=[n], - ) - rb = ReplayBuffer( - storage=LazyTensorStorage(n), - sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), - transform=NextStateReconstructor( - keys=[("agents", "pos"), ("agents", "vel")], - ), - batch_size=4, - ) - rb.extend(td) - sample = rb.sample() - for k in (("agents", "pos"), ("agents", "vel")): - next_k = ("next", *k) - torch.testing.assert_close(sample.get(next_k)[:3], sample.get(k)[1:4]) - assert torch.isnan(sample.get(next_k)[3]).all() - - def test_explicit_fill_value(self): - data = self._make_data(n_traj=2, traj_len=4) - rb = ReplayBuffer( - storage=LazyTensorStorage(8), - sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), - transform=NextStateReconstructor(fill_value=-1.0), - batch_size=8, - ) - rb.extend(data) - sample = rb.sample() - next_obs = sample.get(("next", "observation")) - # The last position of each slice belongs to a different trajectory - # in (i, i+1), so it gets the fill value. 
- for slice_start in (0, 4): - assert (next_obs[slice_start + 3] == -1.0).all() - - def test_overwrites_existing_next_obs(self): - """If ``("next", k)`` is already in storage, the transform overwrites it.""" - n = 8 - td = TensorDict( - { - "observation": torch.arange(n, dtype=torch.float32).view(n, 1), - ("next", "observation"): torch.full( - (n, 1), -999.0, dtype=torch.float32 - ), - ("next", "done"): torch.tensor([[False], [False], [False], [True]] * 2), - ("next", "reward"): torch.zeros(n, 1), - self._DEFAULT_TRAJ_KEY: torch.tensor([0] * 4 + [1] * 4), - }, - batch_size=[n], - ) - rb = ReplayBuffer( - storage=LazyTensorStorage(n), - sampler=SliceSampler(slice_len=4, traj_key=self._DEFAULT_TRAJ_KEY), - transform=NextStateReconstructor(), - batch_size=8, - ) - rb.extend(td) - sample = rb.sample() - assert not (sample.get(("next", "observation")) == -999.0).any() - - def test_step_count_cross_check(self): - """``step_count_key`` adds a stricter "consecutive in time" check.""" - n = 4 - td = TensorDict( - { - "observation": torch.arange(n, dtype=torch.float32).view(n, 1), - self._DEFAULT_TRAJ_KEY: torch.zeros(n, dtype=torch.long), - ("next", "done"): torch.zeros(n, 1, dtype=torch.bool), - # Same traj id and no done, but step counts disagree at i=1 - # (jumps from 0 to 5, then 5 -> 6 -> 7). - ("collector", "step_count"): torch.tensor([0, 5, 6, 7]), - }, - batch_size=[n], - ) - t = NextStateReconstructor(step_count_key=("collector", "step_count")) - out = t(td) - next_obs = out.get(("next", "observation")) - # Position 0 → step_count[1] - step_count[0] = 5 ≠ 1, so NaN. - assert torch.isnan(next_obs[0]).all() - # Positions 1 and 2 are consecutive (5→6, 6→7) → reconstructed. - torch.testing.assert_close(next_obs[1], td.get("observation")[2]) - torch.testing.assert_close(next_obs[2], td.get("observation")[3]) - # Position 3 has no i+1 → NaN. 
- assert torch.isnan(next_obs[3]).all() - - def test_strict_missing_traj_key_raises(self): - td = TensorDict( - {"observation": torch.arange(4, dtype=torch.float32).view(4, 1)}, - batch_size=[4], - ) - with pytest.raises(KeyError, match="trajectory key"): - NextStateReconstructor()(td) - - def test_strict_missing_done_key_raises(self): - td = TensorDict( - { - "observation": torch.arange(4, dtype=torch.float32).view(4, 1), - self._DEFAULT_TRAJ_KEY: torch.zeros(4, dtype=torch.long), - }, - batch_size=[4], - ) - with pytest.raises(KeyError, match="done key"): - NextStateReconstructor()(td) - - def test_strict_false_single_traj_fallback(self): - td = TensorDict( - {"observation": torch.arange(4, dtype=torch.float32).view(4, 1)}, - batch_size=[4], - ) - out = NextStateReconstructor(strict=False)(td) - next_obs = out.get(("next", "observation")) - torch.testing.assert_close(next_obs[:-1], td.get("observation")[1:]) - assert torch.isnan(next_obs[-1]).all() - - def test_traj_key_none_disables_check(self): - td = TensorDict( - { - "observation": torch.arange(4, dtype=torch.float32).view(4, 1), - # Different traj ids, but check is disabled → all-shift, no NaN - # except the last position. 
- self._DEFAULT_TRAJ_KEY: torch.tensor([0, 1, 2, 3]), - }, - batch_size=[4], - ) - out = NextStateReconstructor(traj_key=None, done_key=None)(td) - next_obs = out.get(("next", "observation")) - torch.testing.assert_close(next_obs[:-1], td.get("observation")[1:]) - assert torch.isnan(next_obs[-1]).all() - - def test_int_obs_requires_explicit_fill_value(self): - td = TensorDict( - { - "observation": torch.arange(4, dtype=torch.int64).view(4, 1), - self._DEFAULT_TRAJ_KEY: torch.zeros(4, dtype=torch.long), - ("next", "done"): torch.zeros(4, 1, dtype=torch.bool), - }, - batch_size=[4], - ) - with pytest.raises(TypeError, match="non-floating dtype"): - NextStateReconstructor()(td) - # Explicit integer fill works - out = NextStateReconstructor(fill_value=-1)(td) - next_obs = out.get(("next", "observation")) - assert next_obs[-1].item() == -1 - - def test_bad_batch_dims_errors(self): - td = TensorDict( - { - "observation": torch.arange(8, dtype=torch.float32).view(2, 4, 1), - self._DEFAULT_TRAJ_KEY: torch.zeros(2, 4, dtype=torch.long), - }, - batch_size=[2, 4], - ) - with pytest.raises(ValueError, match="flat"): - NextStateReconstructor()(td) - - -class TestMaxValueWriter: - @pytest.mark.parametrize("size", [20, 25, 30]) - @pytest.mark.parametrize("batch_size", [1, 10, 15]) - @pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) - @pytest.mark.parametrize("device", get_default_devices()) - def test_max_value_writer(self, size, batch_size, reward_ranges, device): - torch.manual_seed(0) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(size, device=device), - sampler=SamplerWithoutReplacement(), - batch_size=batch_size, - writer=TensorDictMaxValueWriter(rank_key="key"), - ) - - max_reward1, max_reward2, max_reward3 = reward_ranges - - td = TensorDict( - { - "key": torch.clamp_max(torch.rand(size), max=max_reward1), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") <= 
max_reward1).all() - assert (0 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) - - td = TensorDict( - { - "key": torch.clamp(torch.rand(size), min=max_reward1, max=max_reward2), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") <= max_reward2).all() - assert (max_reward1 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) - - td = TensorDict( - { - "key": torch.clamp(torch.rand(size), min=max_reward2, max=max_reward3), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - - for sample in td: - rb.add(sample) - - sample = rb.sample() - assert (sample.get("key") <= max_reward3).all() - assert (max_reward2 <= sample.get("key")).all() - assert len(sample.get("index").unique()) == len(sample.get("index")) - - # Finally, test the case when no obs should be added - td = TensorDict( - { - "key": torch.zeros(size), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - assert (sample.get("key") != 0).all() - - @pytest.mark.parametrize("size", [20, 25, 30]) - @pytest.mark.parametrize("batch_size", [1, 10, 15]) - @pytest.mark.parametrize("reward_ranges", [(0.25, 0.5, 1.0)]) - @pytest.mark.parametrize("device", get_default_devices()) - def test_max_value_writer_serialize( - self, size, batch_size, reward_ranges, device, tmpdir - ): - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(size, device=device), - sampler=SamplerWithoutReplacement(), - batch_size=batch_size, - writer=TensorDictMaxValueWriter(rank_key="key"), - ) - - max_reward1, max_reward2, max_reward3 = reward_ranges - - td = TensorDict( - { - "key": torch.clamp_max(torch.rand(size), max=max_reward1), - "obs": torch.rand(size), - }, - batch_size=size, - device=device, - ) - rb.extend(td) - rb.writer.dumps(tmpdir) - # check we can dump twice - 
rb.writer.dumps(tmpdir) - other = TensorDictMaxValueWriter(rank_key="key") - other.loads(tmpdir) - assert len(rb.writer._current_top_values) == len(other._current_top_values) - torch.testing.assert_close( - torch.tensor(rb.writer._current_top_values), - torch.tensor(other._current_top_values), - ) - - @pytest.mark.parametrize("size", [[], [1], [2, 3]]) - @pytest.mark.parametrize("device", get_default_devices()) - @pytest.mark.parametrize("reduction", ["max", "min", "mean", "median", "sum"]) - def test_max_value_writer_reduce(self, size, device, reduction): - torch.manual_seed(0) - batch_size = 4 - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(1, device=device), - sampler=SamplerWithoutReplacement(), - batch_size=batch_size, - writer=TensorDictMaxValueWriter(rank_key="key", reduction=reduction), - ) - - key = torch.rand(batch_size, *size, device=device) - obs = torch.rand(batch_size, *size, device=device) - td = TensorDict( - {"key": key, "obs": obs}, - batch_size=batch_size, - device=device, - ) - rb.extend(td) - sample = rb.sample() - if reduction == "max": - rank_key = torch.stack([k.max() for k in key.unbind(0)]) - elif reduction == "min": - rank_key = torch.stack([k.min() for k in key.unbind(0)]) - elif reduction == "mean": - rank_key = torch.stack([k.mean() for k in key.unbind(0)]) - elif reduction == "median": - rank_key = torch.stack([k.median() for k in key.unbind(0)]) - elif reduction == "sum": - rank_key = torch.stack([k.sum() for k in key.unbind(0)]) - - top_rank = torch.argmax(rank_key) - assert (sample.get("obs") == obs[top_rank]).all() - - -class TestMultiProc: - @staticmethod - def worker(rb, q0, q1): - td = TensorDict({"a": torch.ones(10), "next": {"reward": torch.ones(10)}}, [10]) - rb.extend(td) - q0.put("extended") - extended = q1.get(timeout=5) - assert extended == "extended" - assert len(rb) == 21, len(rb) - assert (rb["a"][:9] == 2).all() - q0.put("finish") - - @staticmethod - def async_prb_worker(rb, worker_id, q): - td = 
TensorDict( - { - "obs": torch.full((4, 1), worker_id, dtype=torch.float32), - "prio": {"td_error": torch.linspace(0.1, 1.0, 4) + worker_id}, - }, - [4], - ) - rb.extend(td) - q.put("finish") - - @staticmethod - def async_generic_prb_worker(rb, worker_id, q): - data = TensorDict( - {"obs": torch.full((4, 1), worker_id, dtype=torch.float32)}, - [4], - ) - rb.extend(data) - q.put("finish") - - def exec_multiproc_rb( - self, - storage_type=LazyMemmapStorage, - init=True, - writer_type=TensorDictRoundRobinWriter, - sampler_type=RandomSampler, - device=None, - ): - rb = TensorDictReplayBuffer( - storage=storage_type(21), writer=writer_type(), sampler=sampler_type() - ) - if init: - td = TensorDict( - {"a": torch.zeros(10), "next": {"reward": torch.ones(10)}}, - [10], - device=device, - ) - rb.extend(td) - q0 = mp.Queue(1) - q1 = mp.Queue(1) - proc = mp.Process(target=self.worker, args=(rb, q0, q1)) - proc.start() - try: - extended = q0.get(timeout=100) - assert extended == "extended" - assert len(rb) == 20 - assert (rb["a"][10:20] == 1).all() - td = TensorDict({"a": torch.zeros(10) + 2}, [10]) - rb.extend(td) - q1.put("extended") - finish = q0.get(timeout=5) - assert finish == "finish" - finally: - proc.join() - - def test_multiproc_rb(self): - return self.exec_multiproc_rb() - - def test_error_list(self): - # list storage cannot be shared - with pytest.raises(RuntimeError, match="Cannot share a storage of type"): - self.exec_multiproc_rb(storage_type=ListStorage) - - def test_error_maxwriter(self): - # TensorDictMaxValueWriter cannot be shared - with pytest.raises(RuntimeError, match="cannot be shared between processes"): - self.exec_multiproc_rb(writer_type=TensorDictMaxValueWriter) - - def test_error_prb(self): - # PrioritizedSampler cannot be shared - if samplers.SumSegmentTreeFp32 is None: - pytest.skip("PrioritizedSampler extension is unavailable.") - with pytest.raises( - RuntimeError, - match="cannot be shared between processes.*sync=False", - ): - 
self.exec_multiproc_rb( - sampler_type=lambda: PrioritizedSampler(21, alpha=1.1, beta=0.5) - ) - - def test_prioritized_sampler_shared_error_mentions_sync_false(self, monkeypatch): - sampler = PrioritizedSampler.__new__(PrioritizedSampler) - monkeypatch.setattr(samplers, "get_spawning_popen", lambda: object()) - with pytest.raises(RuntimeError, match="sync=False"): - sampler.__getstate__() - - def test_shared_prefetch_error_mentions_fix(self): - with pytest.raises( - ValueError, - match="Cannot share prefetched replay buffers.*prefetch=0.*shared=False", - ): - TensorDictReplayBuffer( - storage=LazyTensorStorage(10), - batch_size=2, - prefetch=1, - shared=True, - ) - - def test_async_prioritized_rb_multiproc_writes(self): - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - priority_key=("prio", "td_error"), - storage=LazyMemmapStorage(32, shared_init=True), - batch_size=4, - shared=True, - sync=False, - ) - q = mp.Queue() - processes = [] - for worker_id in range(2): - proc = mp.Process( - target=self.async_prb_worker, - args=(rb, worker_id, q), - ) - processes.append(proc) - proc.start() - - for proc in processes: - proc.join() - assert proc.exitcode == 0 - assert q.get(timeout=5) == "finish" - - assert rb.write_count == 8 - sample = rb.sample() - assert rb._prioritized_sampler_write_count == 8 - assert sample["obs"].shape == (4, 1) - assert "priority_weight" in sample.keys() - assert "index" in sample.keys() - - sample["prio", "td_error"] = torch.ones(sample.shape) * 10 - rb.update_tensordict_priority(sample) - assert rb.prioritized_sampler._max_priority[0] is not None - - def test_async_generic_prioritized_rb_multiproc_writes(self): - rb = PrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - storage=LazyMemmapStorage(32), - batch_size=4, - sync=False, - ) - rb.extend(TensorDict({"obs": torch.zeros((1, 1))}, [1])) - rb.empty() - rb.share(True) - q = mp.Queue() - processes = [] - for worker_id in range(2): - proc = mp.Process( - 
target=self.async_generic_prb_worker, - args=(rb, worker_id, q), - ) - processes.append(proc) - proc.start() - - for proc in processes: - proc.join() - assert proc.exitcode == 0 - assert q.get(timeout=5) == "finish" - - assert rb.write_count == 8 - sample, info = rb.sample(return_info=True) - assert rb._prioritized_sampler_write_count == 8 - assert sample["obs"].shape == (4, 1) - assert "priority_weight" in info - assert "index" in info - - rb.update_priority(info["index"], torch.ones(4) * 10) - assert rb.prioritized_sampler._max_priority[0] is not None - - def test_error_noninit(self): - # list storage cannot be shared - with pytest.raises(RuntimeError, match="it has not been initialized yet"): - self.exec_multiproc_rb(init=False) - - -class TestSamplers: - @pytest.mark.parametrize( - "backend", ["torch"] + (["torchsnapshot"] if _has_snapshot else []) - ) - def test_sampler_without_rep_state_dict(self, backend): - os.environ["CKPT_BACKEND"] = backend - torch.manual_seed(0) - - n_samples = 3 - buffer_size = 100 - storage_in = LazyTensorStorage(buffer_size, device="cpu") - storage_out = LazyTensorStorage(buffer_size, device="cpu") - - replay_buffer = TensorDictReplayBuffer( - storage=storage_in, - sampler=SamplerWithoutReplacement(), - ) - # fill replay buffer with random data - transition = TensorDict( - { - "observation": torch.ones(1, 4), - "action": torch.ones(1, 2), - "reward": torch.ones(1, 1), - "dones": torch.ones(1, 1), - "next": {"observation": torch.ones(1, 4)}, - }, - batch_size=1, - ) - for _ in range(n_samples): - replay_buffer.extend(transition.clone()) - for _ in range(n_samples): - s = replay_buffer.sample(batch_size=1) - assert (s.exclude("index") == 1).all() - - replay_buffer.extend(torch.zeros_like(transition)) - - state_dict = replay_buffer.state_dict() - - new_replay_buffer = TensorDictReplayBuffer( - storage=storage_out, - batch_size=state_dict["_batch_size"], - sampler=SamplerWithoutReplacement(), - ) - - 
new_replay_buffer.load_state_dict(state_dict) - s = new_replay_buffer.sample(batch_size=1) - assert (s.exclude("index") == 0).all() - - def test_sampler_without_rep_dumps_loads(self, tmpdir): - d0 = tmpdir + "/save0" - d1 = tmpdir + "/save1" - d2 = tmpdir + "/dump" - replay_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(max_size=100, scratch_dir=d0, device="cpu"), - sampler=SamplerWithoutReplacement(drop_last=True), - batch_size=8, - ) - replay_buffer2 = TensorDictReplayBuffer( - storage=LazyMemmapStorage(max_size=100, scratch_dir=d1, device="cpu"), - sampler=SamplerWithoutReplacement(drop_last=True), - batch_size=8, - ) - td = TensorDict( - {"a": torch.arange(0, 27), ("b", "c"): torch.arange(1, 28)}, batch_size=[27] - ) - replay_buffer.extend(td) - for _ in replay_buffer: - break - replay_buffer.dumps(d2) - replay_buffer2.loads(d2) - assert ( - replay_buffer.sampler._sample_list == replay_buffer2.sampler._sample_list - ).all() - s = replay_buffer2.sample(3) - assert (s["a"] == s["b", "c"] - 1).all() - - @pytest.mark.parametrize("drop_last", [False, True]) - def test_sampler_without_replacement_cap_prefetch(self, drop_last): - torch.manual_seed(0) - data = TensorDict({"a": torch.arange(11)}, batch_size=[11]) - rb = ReplayBuffer( - storage=LazyTensorStorage(11), - sampler=SamplerWithoutReplacement(drop_last=drop_last), - batch_size=2, - prefetch=3, - ) - rb.extend(data) - - for _ in range(100): - s = set() - for i, d in enumerate(rb): - assert i <= (4 + int(not drop_last)), i - s = s.union(set(d["a"].tolist())) - assert i == (4 + int(not drop_last)), i - if drop_last: - assert s != set(range(11)) - else: - assert s == set(range(11)) - - @pytest.mark.parametrize( - "batch_size,num_slices,slice_len,prioritized", - [ - [100, 20, None, True], - [100, 20, None, False], - [120, 30, None, False], - [100, None, 5, False], - [120, None, 4, False], - [101, None, 101, False], - ], - ) - @pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")]) - 
@pytest.mark.parametrize("done_key", ["done", ("some", "done")]) - @pytest.mark.parametrize("match_episode", [True, False]) - @pytest.mark.parametrize("device", get_default_devices()) - def test_slice_sampler( - self, - batch_size, - num_slices, - slice_len, - prioritized, - episode_key, - done_key, - match_episode, - device, - ): - torch.manual_seed(0) - storage = LazyMemmapStorage(100) - episode = torch.zeros(100, dtype=torch.int, device=device) - episode[:30] = 1 - episode[30:55] = 2 - episode[55:70] = 3 - episode[70:] = 4 - steps = torch.cat( - [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0 - ) - - done = torch.zeros(100, 1, dtype=torch.bool) - done[torch.tensor([29, 54, 69, 99])] = 1 - - data = TensorDict( - { - # we only use episode_key if we want the sampler to access it - episode_key if match_episode else "whatever_episode": episode, - "another_episode": episode, - "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5), - "act": torch.randn((20,)).expand(100, 20), - "steps": steps, - "count": torch.arange(100), - "other": torch.randn((20, 50)).expand(100, 20, 50), - done_key: done, - _replace_last(done_key, "terminated"): done, - }, - [100], - device=device, - ) - storage.set(range(100), data) - if slice_len is not None and slice_len > 15: - # we may have to sample trajs shorter than slice_len - strict_length = False - else: - strict_length = True - - if prioritized: - num_steps = data.shape[0] - sampler = PrioritizedSliceSampler( - max_capacity=num_steps, - alpha=0.7, - beta=0.9, - num_slices=num_slices, - traj_key=episode_key, - end_key=done_key, - slice_len=slice_len, - strict_length=strict_length, - truncated_key=_replace_last(done_key, "truncated"), - ) - index = torch.arange(0, num_steps, 1) - sampler.extend(index) - sampler.update_priority(index, 1) - else: - sampler = SliceSampler( - num_slices=num_slices, - traj_key=episode_key, - end_key=done_key, - slice_len=slice_len, - strict_length=strict_length, - 
truncated_key=_replace_last(done_key, "truncated"), - ) - if slice_len is not None: - num_slices = batch_size // slice_len - trajs_unique_id = set() - too_short = False - count_unique = set() - for _ in range(50): - index, info = sampler.sample(storage, batch_size=batch_size) - samples = storage._storage[index] - if strict_length: - # check that trajs are ok - samples = samples.view(num_slices, -1) - - unique_another_episode = ( - samples["another_episode"].unique(dim=1).squeeze() - ) - assert unique_another_episode.shape == torch.Size([num_slices]), ( - num_slices, - samples, - ) - assert ( - samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1] - ).all() - if isinstance(index, tuple): - index_numel = index[0].numel() - else: - index_numel = index.numel() - - too_short = too_short or index_numel < batch_size - trajs_unique_id = trajs_unique_id.union( - samples["another_episode"].view(-1).tolist() - ) - count_unique = count_unique.union(samples.get("count").view(-1).tolist()) - - truncated = info[_replace_last(done_key, "truncated")] - terminated = info[_replace_last(done_key, "terminated")] - assert (truncated | terminated).view(num_slices, -1)[:, -1].all() - assert ( - terminated - == samples[_replace_last(done_key, "terminated")].view_as(terminated) - ).all() - done = info[done_key] - assert done.view(num_slices, -1)[:, -1].all() - - if len(count_unique) == 100: - # all items have been sampled - break - else: - raise AssertionError( - f"Not all items can be sampled: {set(range(100)) - count_unique} are missing" - ) - - if strict_length: - assert not too_short - else: - assert too_short - - assert len(trajs_unique_id) == 4 - - @pytest.mark.parametrize("sampler", [SliceSampler, SliceSamplerWithoutReplacement]) - def test_slice_sampler_at_capacity(self, sampler): - torch.manual_seed(0) - - trajectory0 = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]) - trajectory1 = torch.arange(2).repeat_interleave(6) - trajectory = torch.stack([trajectory0, trajectory1], 
0) - - td = TensorDict( - {"trajectory": trajectory, "steps": torch.arange(12).expand(2, 12)}, [2, 12] - ) - - rb = ReplayBuffer( - sampler=sampler(traj_key="trajectory", num_slices=2), - storage=LazyTensorStorage(20, ndim=2), - batch_size=6, - ) - - rb.extend(td) - - for s in rb: - if (s["steps"] == 9).any(): - break - else: - raise AssertionError - - def test_slice_sampler_errors(self): - device = "cpu" - batch_size, num_slices = 100, 20 - - episode = torch.zeros(100, dtype=torch.int, device=device) - episode[:30] = 1 - episode[30:55] = 2 - episode[55:70] = 3 - episode[70:] = 4 - steps = torch.cat( - [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0 - ) - - done = torch.zeros(100, 1, dtype=torch.bool) - done[torch.tensor([29, 54, 69])] = 1 - - data = TensorDict( - { - # we only use episode_key if we want the sampler to access it - "episode": episode, - "another_episode": episode, - "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5), - "act": torch.randn((20,)).expand(100, 20), - "steps": steps, - "other": torch.randn((20, 50)).expand(100, 20, 50), - ("next", "done"): done, - }, - [100], - device=device, - ) - - data_wrong_done = data.clone(False) - data_wrong_done.rename_key_("episode", "_") - data_wrong_done["next", "done"] = done.unsqueeze(1).expand(100, 5, 1) - storage = LazyMemmapStorage(100) - storage.set(range(100), data_wrong_done) - sampler = SliceSampler(num_slices=num_slices) - with pytest.raises( - RuntimeError, - match="Expected the end-of-trajectory signal to be 1-dimensional", - ): - index, _ = sampler.sample(storage, batch_size=batch_size) - - storage = ListStorage(100) - storage.set(range(100), data) - sampler = SliceSampler(num_slices=num_slices) - with pytest.raises( - RuntimeError, - match="Could not get a tensordict out of the storage, which is required for SliceSampler to compute the trajectories.", - ): - index, _ = sampler.sample(storage, batch_size=batch_size) - - @pytest.mark.parametrize("batch_size,num_slices", 
[[20, 4], [4, 2]]) - @pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")]) - @pytest.mark.parametrize("done_key", ["done", ("some", "done")]) - @pytest.mark.parametrize("match_episode", [True, False]) - @pytest.mark.parametrize("device", get_default_devices()) - def test_slice_sampler_without_replacement( - self, - batch_size, - num_slices, - episode_key, - done_key, - match_episode, - device, - ): - torch.manual_seed(0) - storage = LazyMemmapStorage(100) - episode = torch.zeros(100, dtype=torch.int, device=device) - steps = [] - done = torch.zeros(100, 1, dtype=torch.bool) - for i in range(0, 100, 5): - episode[i : i + 5] = i // 5 - steps.append(torch.arange(5)) - done[i + 4] = 1 - steps = torch.cat(steps) - - data = TensorDict( - { - # we only use episode_key if we want the sampler to access it - episode_key if match_episode else "whatever_episode": episode, - "another_episode": episode, - "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5), - "act": torch.randn((20,)).expand(100, 20), - "steps": steps, - "other": torch.randn((20, 50)).expand(100, 20, 50), - done_key: done, - }, - [100], - device=device, - ) - storage.set(range(100), data) - sampler = SliceSamplerWithoutReplacement( - num_slices=num_slices, traj_key=episode_key, end_key=done_key - ) - trajs_unique_id = set() - for i in range(5): - index, info = sampler.sample(storage, batch_size=batch_size) - samples = storage._storage[index] - - # check that trajs are ok - samples = samples.view(num_slices, -1) - assert samples["another_episode"].unique( - dim=1 - ).squeeze().shape == torch.Size([num_slices]) - assert (samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1]).all() - cur_episodes = samples["another_episode"].view(-1).tolist() - for ep in cur_episodes: - assert ep not in trajs_unique_id, i - trajs_unique_id = trajs_unique_id.union( - cur_episodes, - ) - done_recon = info[("next", "truncated")] | info[("next", "terminated")] - assert done_recon.view(num_slices, -1)[:, 
-1].all() - done = info[("next", "done")] - assert done.view(num_slices, -1)[:, -1].all() - - def test_slice_sampler_left_right(self): - torch.manual_seed(0) - data = TensorDict( - {"obs": torch.arange(1, 11).repeat(10), "eps": torch.arange(100) // 10 + 1}, - [100], - ) - - for N in (2, 4): - rb = TensorDictReplayBuffer( - sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)), - batch_size=50, - storage=LazyMemmapStorage(100), - ) - rb.extend(data) - - for _ in range(10): - sample = rb.sample() - sample = split_trajectories(sample) - assert (sample["next", "truncated"].squeeze(-1).sum(-1) == 1).all() - assert ((sample["obs"] == 0).sum(-1) <= N).all(), sample["obs"] - assert ((sample["eps"] == 0).sum(-1) <= N).all() - for i in range(sample.shape[0]): - curr_eps = sample[i]["eps"] - curr_eps = curr_eps[curr_eps != 0] - assert curr_eps.unique().numel() == 1 - - def test_slice_sampler_left_right_ndim(self): - torch.manual_seed(0) - data = TensorDict( - {"obs": torch.arange(1, 11).repeat(12), "eps": torch.arange(120) // 10 + 1}, - [120], - ) - data = data.reshape(4, 30) - - for N in (2, 4): - rb = TensorDictReplayBuffer( - sampler=SliceSampler(num_slices=10, traj_key="eps", span=(N, N)), - batch_size=50, - storage=LazyMemmapStorage(100, ndim=2), - ) - rb.extend(data) - - for _ in range(10): - sample = rb.sample() - sample = split_trajectories(sample) - assert (sample["next", "truncated"].squeeze(-1).sum(-1) <= 1).all() - assert ((sample["obs"] == 0).sum(-1) <= N).all(), sample["obs"] - assert ((sample["eps"] == 0).sum(-1) <= N).all() - for i in range(sample.shape[0]): - curr_eps = sample[i]["eps"] - curr_eps = curr_eps[curr_eps != 0] - assert curr_eps.unique().numel() == 1 - - def test_slice_sampler_strictlength(self): - torch.manual_seed(0) - - data = TensorDict( - { - "traj": torch.cat( - [ - torch.ones(2, dtype=torch.int), - torch.zeros(10, dtype=torch.int), - ], - dim=0, - ), - "x": torch.arange(12), - }, - [12], - ) - - buffer = ReplayBuffer( - 
storage=LazyTensorStorage(12), - sampler=SliceSampler(num_slices=2, strict_length=True, traj_key="traj"), - batch_size=8, - ) - buffer.extend(data) - - for _ in range(50): - sample = buffer.sample() - assert sample.shape == torch.Size([8]) - assert (sample["traj"] == 0).all() - - buffer = ReplayBuffer( - storage=LazyTensorStorage(12), - sampler=SliceSampler(num_slices=2, strict_length=False, traj_key="traj"), - batch_size=8, - ) - buffer.extend(data) - - for _ in range(50): - sample = buffer.sample() - if sample.shape == torch.Size([6]): - assert (sample["traj"] != 0).any() - else: - assert len(sample["traj"].unique()) == 1 - - # ------------------------------------------------------------------ - # traj_key auto-detection tests - # ------------------------------------------------------------------ - - def test_slice_sampler_auto_traj_key_collector_ids(self): - """Auto-detection should prefer ("collector", "traj_ids") over "episode".""" - torch.manual_seed(0) - # Build data with both keys present; sampler should pick collector key - # and warn that this changes the pre-0.13 default. - traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) - data = TensorDict( - { - ("collector", "traj_ids"): traj_ids, - "episode": torch.zeros(8, dtype=torch.int), # wrong, should be ignored - "obs": torch.arange(8).float(), - }, - batch_size=[8], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(8), - sampler=SliceSampler(num_slices=2), - batch_size=6, - ) - rb.extend(data) - # Force resolution — with both keys present we must see a FutureWarning. 
- with pytest.warns(FutureWarning, match="auto-detected"): - sample = rb.sample() - assert rb.sampler.traj_key == ("collector", "traj_ids") - assert rb.sampler._fetch_traj is True - assert rb.sampler._traj_key_auto is False - # Each slice should come from a single trajectory - sample_reshaped = sample.reshape(2, 3) - for i in range(2): - traj_vals = sample_reshaped[i][("collector", "traj_ids")] - assert traj_vals.unique().numel() == 1 - - def test_slice_sampler_auto_traj_key_no_warning_single_key(self): - """No FutureWarning when only one of the two candidate keys is present.""" - torch.manual_seed(0) - traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) - data = TensorDict( - { - ("collector", "traj_ids"): traj_ids, - "obs": torch.arange(8).float(), - }, - batch_size=[8], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(8), - sampler=SliceSampler(num_slices=2), - batch_size=6, - ) - rb.extend(data) - with warnings.catch_warnings(): - warnings.simplefilter("error", FutureWarning) - rb.sample() - assert rb.sampler.traj_key == ("collector", "traj_ids") - - def test_slice_sampler_auto_traj_key_episode(self): - """Auto-detection falls back to 'episode' when collector key is absent.""" - torch.manual_seed(0) - traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) - data = TensorDict( - { - "episode": traj_ids, - "obs": torch.arange(8).float(), - }, - batch_size=[8], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(8), - sampler=SliceSampler(num_slices=2), - batch_size=6, - ) - rb.extend(data) - rb.sample() - assert rb.sampler.traj_key == "episode" - assert rb.sampler._fetch_traj is True - - def test_slice_sampler_auto_traj_key_fallback_to_done(self): - """Auto-detection falls back to end_key reconstruction when no traj key.""" - torch.manual_seed(0) - done = torch.zeros(9, 1, dtype=torch.bool) - done[[2, 5, 8]] = True - data = TensorDict( - { - ("next", "done"): done, - ("next", "truncated"): done, - ("next", 
"terminated"): done, - "obs": torch.arange(9).float(), - }, - batch_size=[9], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(9), - sampler=SliceSampler(num_slices=3), - batch_size=9, - ) - rb.extend(data) - rb.sample() - assert rb.sampler._fetch_traj is False - - def test_slice_sampler_explicit_traj_key_no_auto(self): - """Explicit traj_key should bypass auto-detection entirely.""" - torch.manual_seed(0) - traj_ids = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2], dtype=torch.int) - data = TensorDict( - { - "my_traj": traj_ids, - ("collector", "traj_ids"): torch.zeros(8, dtype=torch.int), - "obs": torch.arange(8).float(), - }, - batch_size=[8], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(8), - sampler=SliceSampler(num_slices=2, traj_key="my_traj"), - batch_size=6, - ) - rb.extend(data) - rb.sample() - assert rb.sampler.traj_key == "my_traj" - assert getattr(rb.sampler, "_traj_key_auto", False) is False - - # ------------------------------------------------------------------ - # mask / lengths tests (strict_length=False) - # ------------------------------------------------------------------ - - def _make_rb_with_short_trajs(self, traj_lengths, slice_len, num_slices): - """Helper: build a TensorDictReplayBuffer with trajectories of given lengths.""" - parts = [] - for t_id, length in enumerate(traj_lengths): - is_init = torch.zeros(length, 1, dtype=torch.bool) - is_init[0] = True # episode reset at the first step of each trajectory - parts.append( - TensorDict( - { - "traj": torch.full((length,), t_id, dtype=torch.int), - "obs": torch.arange(length).float(), - "is_init": is_init, - }, - batch_size=[length], - ) - ) - data = torch.cat(parts) - total = sum(traj_lengths) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(total), - sampler=SliceSampler( - slice_len=slice_len, - traj_key="traj", - strict_length=False, - pad_output=True, - ), - batch_size=num_slices * slice_len, - ) - rb.extend(data) - return rb - - def 
test_slice_sampler_mask_present_when_short_trajs(self): - """mask appears in output when short trajectories force padding.""" - torch.manual_seed(0) - rb = self._make_rb_with_short_trajs( - traj_lengths=[3, 6, 2], slice_len=5, num_slices=3 - ) - sample = rb.sample() - assert ("collector", "mask") in sample.keys(True) - - def test_slice_sampler_mask_shape_dtype(self): - """mask is bool with shape [B*T] (matches batch shape, no trailing 1).""" - torch.manual_seed(0) - B, T = 4, 6 - rb = self._make_rb_with_short_trajs( - traj_lengths=[2, 5, 3, 4], slice_len=T, num_slices=B - ) - sample = rb.sample() - mask = sample[("collector", "mask")] - assert mask.shape == torch.Size([B * T]) - assert mask.dtype == torch.bool - # mask must match the leading batch dim so trainer code can index - # batch[batch.get(("collector", "mask"))] without broadcasting tricks. - assert mask.shape[0] == sample.batch_size[0] - - def test_slice_sampler_mask_correctness(self): - """mask rows are contiguous: True prefix followed by False suffix.""" - torch.manual_seed(0) - B, T = 6, 8 - rb = self._make_rb_with_short_trajs( - traj_lengths=[3, 8, 2, 7, 1, 5], slice_len=T, num_slices=B - ) - for _ in range(20): - sample = rb.sample() - mask = sample[("collector", "mask")].reshape(B, T) - # derive lengths from the mask itself - lengths = mask.sum(-1) # [B] - for i in range(B): - length = lengths[i].item() - assert length >= 1 - assert length <= T - assert mask[ - i, :length - ].all(), f"slice {i}: first {length} steps should be True" - assert not mask[ - i, length: - ].any(), f"slice {i}: steps after {length} should be False" - - def test_slice_sampler_mask_padded_obs_is_valid(self): - """Padded positions repeat the last real index — obs values must be finite.""" - torch.manual_seed(0) - rb = self._make_rb_with_short_trajs( - traj_lengths=[2, 6, 3], slice_len=5, num_slices=3 - ) - sample = rb.sample() - assert torch.isfinite(sample["obs"]).all() - - def test_slice_sampler_strict_length_no_mask(self): - 
"""With pad_output=False, no mask is emitted regardless of strict_length.""" - torch.manual_seed(0) - data = TensorDict( - { - "traj": torch.cat( - [torch.zeros(6, dtype=torch.int), torch.ones(6, dtype=torch.int)] - ), - "obs": torch.arange(12).float(), - }, - batch_size=[12], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(12), - sampler=SliceSampler( - slice_len=4, traj_key="traj", strict_length=True, pad_output=False - ), - batch_size=8, - ) - rb.extend(data) - sample = rb.sample() - assert ("collector", "mask") not in sample.keys(True) - - def test_slice_sampler_pad_output_strict_length_raises(self): - """pad_output=True + strict_length=True is rejected at construction.""" - with pytest.raises(ValueError, match="pad_output=True is incompatible"): - SliceSampler( - slice_len=4, traj_key="traj", strict_length=True, pad_output=True - ) - - def test_slice_sampler_pad_output_marks_slice_starts(self): - """pad_output=True writes is_init=True at every slice start. - - This is what lets a recurrent policy in `set_recurrent_mode("recurrent")` - consume the flat [B*T] sample directly: the RNN splits on `is_init` - and uses each slice's stored hidden state at position 0. - """ - torch.manual_seed(0) - B, T = 4, 8 - rb = self._make_rb_with_short_trajs( - traj_lengths=[3, 8, 2, 7, 1, 5], slice_len=T, num_slices=B - ) - for _ in range(10): - sample = rb.sample() - is_init = sample["is_init"].reshape(B, T) - # Position 0 of every slice must be True regardless of where the - # slice landed within its source trajectory. - assert is_init[:, 0].all(), "every slice must start with is_init=True" - - def test_slice_sampler_marks_slice_starts_no_pad(self): - """Default (no pad_output) flow: is_init=True at every slice start. - - This is the workflow most users will hit: trajectories are written - end-to-end into the buffer, the sampler returns concatenated - variable-length slices, and the RNN splits on `is_init`. No mask, no - padding involved. 
- """ - torch.manual_seed(0) - traj_lengths = [3, 8, 2, 7, 5] - parts = [] - for t_id, length in enumerate(traj_lengths): - init = torch.zeros(length, 1, dtype=torch.bool) - init[0] = True - parts.append( - TensorDict( - { - "traj": torch.full((length,), t_id, dtype=torch.int), - "is_init": init, - }, - batch_size=[length], - ) - ) - data = torch.cat(parts) - B = 4 - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(data.numel()), - sampler=SliceSampler(num_slices=B, traj_key="traj", strict_length=False), - batch_size=B * 6, - ) - rb.extend(data) - for _ in range(10): - sample = rb.sample() - assert "is_init" in sample.keys(True) - is_init = sample["is_init"].squeeze(-1) - trunc = sample[("next", "truncated")].squeeze(-1) - # Slice 0 always starts at position 0. - assert is_init[0].item(), "first slice must start with is_init=True" - # Every position right after a truncated flag must be is_init=True - # (next slice's start). The last truncated marks the end of the - # batch; nothing follows it. - slice_ends = trunc.nonzero().squeeze(-1).tolist() - for end in slice_ends[:-1]: - assert is_init[ - end + 1 - ].item(), f"slice starting at index {end + 1} missing is_init=True" - - def test_slice_sampler_pad_output_no_is_init_no_marker(self): - """Without is_init in the storage we don't introduce one out of thin air.""" - torch.manual_seed(0) - # Build a buffer *without* is_init. 
- data = TensorDict( - { - "traj": torch.cat( - [ - torch.full((3,), 0, dtype=torch.int), - torch.full((6,), 1, dtype=torch.int), - torch.full((2,), 2, dtype=torch.int), - ] - ), - "obs": torch.arange(11).float(), - }, - batch_size=[11], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(11), - sampler=SliceSampler( - slice_len=5, traj_key="traj", strict_length=False, pad_output=True - ), - batch_size=15, - ) - rb.extend(data) - sample = rb.sample() - # is_init must not appear if it wasn't in the storage - assert "is_init" not in sample.keys(True) - - def test_slice_sampler_flat_sample_matches_batched_recurrent_module(self): - """A flat padded sample must match an explicit [B, T] recurrent call.""" - torch.manual_seed(0) - B, T = 4, 5 - input_size, hidden_size = 3, 7 - parts = [] - for traj_id, length in enumerate([11, 9, 10, 12]): - is_init = torch.zeros(length, 1, dtype=torch.bool) - is_init[0] = True - parts.append( - TensorDict( - { - "traj": torch.full((length,), traj_id, dtype=torch.int), - "embed": torch.randn(length, input_size), - "recurrent_state": torch.randn(length, 1, hidden_size), - "is_init": is_init, - }, - batch_size=[length], - ) - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(sum(part.shape[0] for part in parts)), - sampler=SliceSampler( - slice_len=T, - traj_key="traj", - strict_length=False, - pad_output=True, - ), - batch_size=B * T, - ) - rb.extend(torch.cat(parts)) - sample = rb.sample() - assert sample["is_init"].reshape(B, T)[:, 0].all() - - gru = GRUModule( - input_size=input_size, - hidden_size=hidden_size, - num_layers=1, - in_keys=["embed", "recurrent_state", "is_init"], - out_keys=["features", ("next", "recurrent_state")], - ) - with set_recurrent_mode("recurrent"): - flat_out = gru(sample.clone()) - batched_out = gru(sample.clone().reshape(B, T)) - - torch.testing.assert_close( - flat_out["features"].reshape(B, T, hidden_size), batched_out["features"] - ) - torch.testing.assert_close( - flat_out[("next", 
"recurrent_state")].reshape(B, T, 1, hidden_size), - batched_out[("next", "recurrent_state")], - ) - - def test_slice_sampler_mask_all_long_trajs_no_mask(self): - """When all trajs >= slice_len, pad_output=True still emits no mask (nothing to pad).""" - torch.manual_seed(0) - data = TensorDict( - { - "traj": torch.cat( - [torch.zeros(8, dtype=torch.int), torch.ones(8, dtype=torch.int)] - ), - "obs": torch.arange(16).float(), - }, - batch_size=[16], - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(16), - sampler=SliceSampler( - slice_len=4, traj_key="traj", strict_length=False, pad_output=True - ), - batch_size=8, - ) - rb.extend(data) - sample = rb.sample() - # No short trajectories → no padding needed → no mask emitted - assert ("collector", "mask") not in sample.keys(True) - - def test_slice_sampler_truncated_marks_last_real_step(self): - """truncated flag should sit at the last *real* timestep, not the padded end.""" - torch.manual_seed(0) - B, T = 4, 6 - rb = self._make_rb_with_short_trajs( - traj_lengths=[2, 5, 3, 4], slice_len=T, num_slices=B - ) - sample = rb.sample() - mask = sample[("collector", "mask")].reshape(B, T) - lengths = mask.sum(-1) # [B] — derived from mask - trunc = sample[("next", "truncated")].reshape(B, T) - for i in range(B): - length = lengths[i].item() - # truncated should be True exactly at position length-1 - assert trunc[ - i, length - 1 - ].item(), f"slice {i}: truncated missing at last real step" - # no truncated flag in padded region - if length < T: - assert not trunc[ - i, length: - ].any(), f"slice {i}: spurious truncated in padding" - - @pytest.mark.parametrize("ndim", [1, 2]) - @pytest.mark.parametrize("strict_length", [True, False]) - @pytest.mark.parametrize("circ", [False, True]) - @pytest.mark.parametrize("at_capacity", [False, True]) - def test_slice_sampler_prioritized(self, ndim, strict_length, circ, at_capacity): - torch.manual_seed(0) - out = [] - for t in range(5): - length = (t + 1) * 5 - done = 
torch.zeros(length, 1, dtype=torch.bool) - done[-1] = 1 - priority = 10 if t == 0 else 1 - traj = TensorDict( - { - "traj": torch.full((length,), t), - "step_count": torch.arange(length), - "done": done, - "priority": torch.full((length,), priority), - }, - batch_size=length, - ) - out.append(traj) - data = torch.cat(out) - if ndim == 2: - data = torch.stack([data, data]) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(data.numel() - at_capacity, ndim=ndim), - sampler=PrioritizedSliceSampler( - max_capacity=data.numel() - at_capacity, - alpha=1.0, - beta=1.0, - end_key="done", - slice_len=10, - strict_length=strict_length, - cache_values=True, - ), - batch_size=50, - ) - if not circ: - # Simplest case: the buffer is full but no overlap - index = rb.extend(data, update_priority=False) - else: - # The buffer is 2/3 -> 1/3 overlapping - rb.extend(data[..., : data.shape[-1] // 3], update_priority=False) - index = rb.extend(data, update_priority=False) - rb.update_priority(index, data["priority"]) - samples = [] - found_shorter_batch = False - for _ in range(100): - samples.append(rb.sample()) - if samples[-1].numel() < 50: - found_shorter_batch = True - samples = torch.cat(samples) - if strict_length: - assert not found_shorter_batch - else: - assert found_shorter_batch - # the first trajectory has a very high priority, but should only appear - # if strict_length=False. 
- if strict_length: - assert (samples["traj"] != 0).all(), samples["traj"].unique() - else: - assert (samples["traj"] == 0).any() - # Check that all samples of the first traj contain all elements (since it's too short to fulfill 10 elts) - sc = samples[samples["traj"] == 0]["step_count"] - assert (sc == 1).sum() == (sc == 2).sum() - assert (sc == 1).sum() == (sc == 4).sum() - assert rb.sampler._cache - rb.extend(data, update_priority=False) - assert not rb.sampler._cache - - @pytest.mark.parametrize("ndim", [1, 2]) - @pytest.mark.parametrize("strict_length", [True, False]) - @pytest.mark.parametrize("circ", [False, True]) - @pytest.mark.parametrize( - "span", [False, [False, False], [False, True], 3, [False, 3]] - ) - def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span): - torch.manual_seed(0) - out = [] - # 5 trajs of length 3, 6, 9, 12 and 15 - for t in range(5): - length = (t + 1) * 3 - done = torch.zeros(length, 1, dtype=torch.bool) - done[-1] = 1 - priority = 1 - traj = TensorDict( - { - "traj": torch.full((length,), t), - "step_count": torch.arange(length), - "done": done, - "priority": torch.full((length,), priority), - }, - batch_size=length, - ) - out.append(traj) - data = torch.cat(out) - if ndim == 2: - data = torch.stack([data, data]) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(data.numel(), ndim=ndim), - sampler=PrioritizedSliceSampler( - max_capacity=data.numel(), - alpha=1.0, - beta=1.0, - end_key="done", - slice_len=5, - strict_length=strict_length, - cache_values=True, - span=span, - ), - batch_size=5, - ) - if not circ: - # Simplest case: the buffer is full but no overlap - index = rb.extend(data) - else: - # The buffer is 2/3 -> 1/3 overlapping - rb.extend(data[..., : data.shape[-1] // 3]) - index = rb.extend(data) - rb.update_priority(index, data["priority"]) - found_traj_0 = False - found_traj_4_truncated_right = False - for i, s in enumerate(rb): - t = s["traj"].unique().tolist() - assert len(t) == 1 - 
t = t[0] - if t == 0: - found_traj_0 = True - if t == 4 and s.numel() < 5: - if s["step_count"][0] > 10: - found_traj_4_truncated_right = True - if s["step_count"][0] == 0: - pass - if i == 1000: - break - assert not rb.sampler.span[0] - # if rb.sampler.span[0]: - # assert found_traj_4_truncated_left - if rb.sampler.span[1]: - assert found_traj_4_truncated_right - else: - assert not found_traj_4_truncated_right - if strict_length and not rb.sampler.span[1]: - assert not found_traj_0 - else: - assert found_traj_0 - - @pytest.mark.parametrize("max_priority_within_buffer", [True, False]) - def test_prb_update_max_priority(self, max_priority_within_buffer): - rb = ReplayBuffer( - storage=LazyTensorStorage(11), - sampler=PrioritizedSampler( - max_capacity=11, - alpha=1.0, - beta=1.0, - max_priority_within_buffer=max_priority_within_buffer, - ), - ) - for data in torch.arange(20): - idx = rb.add(data) - rb.update_priority(idx, 21 - data) - if data <= 10: - # The max is always going to be the first value - assert rb.sampler._max_priority[0] == 21 - assert rb.sampler._max_priority[1] == 0 - elif not max_priority_within_buffer: - # The max is the historical max, which was at idx 0 - assert rb.sampler._max_priority[0] == 21 - assert rb.sampler._max_priority[1] == 0 - else: - # the max is the current max. 
Find it and compare - sumtree = torch.as_tensor( - [rb.sampler._sum_tree[i] for i in range(rb.sampler._max_capacity)] - ) - assert rb.sampler._max_priority[0] == sumtree.max() - assert rb.sampler._max_priority[1] == sumtree.argmax() - idx = rb.extend(torch.arange(10)) - rb.update_priority(idx, 12) - if max_priority_within_buffer: - assert rb.sampler._max_priority[0] == 12 - assert rb.sampler._max_priority[1] == 0 - else: - assert rb.sampler._max_priority[0] == 21 - assert rb.sampler._max_priority[1] == 0 - - @pytest.mark.skipif( - TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" - ) - def test_prb_serialization(self, tmpdir): - rb = ReplayBuffer( - storage=LazyMemmapStorage(max_size=10), - sampler=PrioritizedSampler(max_capacity=10, alpha=0.8, beta=0.6), - ) - - td = TensorDict( - { - "observations": torch.zeros(1, 3), - "actions": torch.zeros(1, 1), - "rewards": torch.zeros(1, 1), - "next_observations": torch.zeros(1, 3), - "terminations": torch.zeros(1, 1, dtype=torch.bool), - }, - batch_size=[1], - ) - rb.extend(td) - - rb.save(tmpdir) - - rb2 = ReplayBuffer( - storage=LazyMemmapStorage(max_size=10), - sampler=PrioritizedSampler(max_capacity=10, alpha=0.5, beta=0.5), - ) - - td = TensorDict( - { - "observations": torch.ones(1, 3), - "actions": torch.ones(1, 1), - "rewards": torch.ones(1, 1), - "next_observations": torch.ones(1, 3), - "terminations": torch.ones(1, 1, dtype=torch.bool), - }, - batch_size=[1], - ) - rb2.extend(td) - rb2.load(tmpdir) - assert len(rb) == 1 - assert rb.sampler._alpha == rb2.sampler._alpha - assert rb.sampler._beta == rb2.sampler._beta - assert rb.sampler._max_priority[0] == rb2.sampler._max_priority[0] - assert rb.sampler._max_priority[1] == rb2.sampler._max_priority[1] - - @pytest.mark.skipif( - TORCH_VERSION < version.parse("2.5.0"), reason="requires Torch >= 2.5.0" - ) - def test_prb_new_sampler_with_loaded_storage(self, tmpdir): - """Test that creating a new PrioritizedSampler with loaded storage works 
correctly. - - This test reproduces the issue from scratch8.py where creating a new - PrioritizedSampler instance with storage that already contains data - would fail with "RuntimeError: non-positive p_sum". - """ - device = torch.device("cpu") - - # Create and populate original buffer - original_rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(10, device=device), - sampler=PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.5), - batch_size=2, - priority_key="td_error", - ) - - data = TensorDict( - { - "state": torch.ones(4, 2, dtype=torch.float32, device=device), - "td_error": torch.ones(4) * 0.5, - }, - batch_size=torch.Size((4,)), - ) - original_rb.extend(data) - - # Update priorities - td = original_rb.sample() - td["td_error"] = torch.arange(2, device=device) + 1.0 - original_rb.update_tensordict_priority(td) - - # Get original priorities for comparison - original_priorities = torch.tensor( - [original_rb._sampler._sum_tree[i] for i in range(len(original_rb))] - ) - - # Save and load normally - original_rb.dumps(tmpdir) - del original_rb - - loaded_rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(10, device=device), - sampler=PrioritizedSampler(max_capacity=10, alpha=0.7, beta=0.5), - batch_size=2, - priority_key="td_error", - ) - loaded_rb.loads(tmpdir) - - # Create a new buffer with the loaded storage but NEW sampler - # This was failing before the fix with "RuntimeError: non-positive p_sum" - new_rb_with_loaded_storage = TensorDictReplayBuffer( - storage=loaded_rb.storage, # Use the loaded storage - sampler=PrioritizedSampler( # But create a NEW sampler instance - max_capacity=len(loaded_rb), alpha=0.7, beta=0.5 - ), - batch_size=2, - priority_key="td_error", - ) - - # This should work now thanks to our fix - td = new_rb_with_loaded_storage.sample() - assert td.batch_size == torch.Size([2]) - - # Verify the storage has the expected data - assert len(new_rb_with_loaded_storage) == 4 - - # Verify priorities were properly initialized with 
default values - # When creating a new sampler with existing storage, it should initialize with default priorities - new_priorities = torch.tensor( - [ - new_rb_with_loaded_storage._sampler._sum_tree[i] - for i in range(len(new_rb_with_loaded_storage)) - ] - ) - expected_default_priority = new_rb_with_loaded_storage._sampler.default_priority - expected_priorities = torch.full( - (len(new_rb_with_loaded_storage),), - expected_default_priority, - dtype=torch.float, - ) - - # All priorities should be positive and equal to the default priority - assert (new_priorities > 0).all(), "All priorities should be positive" - torch.testing.assert_close( - new_priorities, - expected_priorities, - msg="New sampler should initialize with default priorities", - ) - - # Also verify that the loaded buffer maintains the original priorities - loaded_priorities = torch.tensor( - [loaded_rb._sampler._sum_tree[i] for i in range(len(loaded_rb))] - ) - torch.testing.assert_close( - loaded_priorities, - original_priorities, - msg="Loaded buffer should maintain original priorities", - ) - - def test_prb_ndim(self): - """This test lists all the possible ways of updating the priority of a PRB with RB, TRB and TPRB. - - All tests are done for 1d and 2d TDs. 
- - """ - torch.manual_seed(0) - np.random.seed(0) - - # first case: 1d, RB - rb = ReplayBuffer( - sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), - storage=LazyTensorStorage(100), - batch_size=4, - ) - data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) - idx = rb.extend(data) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() - rb.update_priority(idx, 2) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() - s, info = rb.sample(return_info=True) - rb.update_priority(info["index"], 3) - assert ( - torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[info["index"]] - == 3 - ).all() - - # second case: 1d, TRB - rb = TensorDictReplayBuffer( - sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), - storage=LazyTensorStorage(100), - batch_size=4, - ) - data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) - idx = rb.extend(data) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() - rb.update_priority(idx, 2) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() - s = rb.sample() - rb.update_priority(s["index"], 3) - assert ( - torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[s["index"]] == 3 - ).all() - - # third case: 1d TPRB - rb = TensorDictPrioritizedReplayBuffer( - alpha=1.0, - beta=1.0, - storage=LazyTensorStorage(100), - batch_size=4, - priority_key="p", - ) - data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10]) - idx = rb.extend(data) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 0.5).all() - rb.update_priority(idx, 2) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() - s = rb.sample() - - s["p"] = torch.ones(4) * 10_000 - rb.update_tensordict_priority(s) - assert ( - torch.tensor([rb.sampler._sum_tree[i] for i in range(10)])[s["index"]] - == 10_000 - ).all() - - s2 = rb.sample() - # 
All indices in s2 must be from s since we set a very high priority to these items - assert (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).any(0).all() - - # fourth case: 2d RB - rb = ReplayBuffer( - sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), - storage=LazyTensorStorage(100, ndim=2), - batch_size=4, - ) - data = TensorDict( - {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] - ) - idx = rb.extend(data) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() - rb.update_priority(idx, 2) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() - - s, info = rb.sample(return_info=True) - rb.update_priority(info["index"], 3) - priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( - (5, 2) - ) - assert (priorities[info["index"]] == 3).all() - - # fifth case: 2d TRB - # 2d - rb = TensorDictReplayBuffer( - sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0), - storage=LazyTensorStorage(100, ndim=2), - batch_size=4, - ) - data = TensorDict( - {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] - ) - idx = rb.extend(data) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 1).all() - rb.update_priority(idx, 2) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() - - s = rb.sample() - rb.update_priority(s["index"], 10_000) - priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( - (5, 2) - ) - assert (priorities[s["index"].unbind(-1)] == 10_000).all() - - s2 = rb.sample() - assert ( - (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all() - ) - - # Sixth case: 2d TDPRB - rb = TensorDictPrioritizedReplayBuffer( - alpha=1.0, - beta=1.0, - storage=LazyTensorStorage(100, ndim=2), - batch_size=4, - priority_key="p", - ) - data = TensorDict( - {"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5] - ) - idx = 
rb.extend(data) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 0.5).all() - rb.update_priority(idx, torch.ones(()) * 2) - assert (torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]) == 2).all() - s = rb.sample() - # setting the priorities to a value that is so big that the buffer will resample them - s["p"] = torch.ones(4) * 10_000 - rb.update_tensordict_priority(s) - priorities = torch.tensor([rb.sampler._sum_tree[i] for i in range(10)]).reshape( - (5, 2) - ) - assert (priorities[s["index"].unbind(-1)] == 10_000).all() - - s2 = rb.sample() - assert ( - (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all() - ) - - def test_replacement_kwarg_random(self): - # RandomSampler(replacement=True) is a regular RandomSampler - s = RandomSampler() - assert type(s) is RandomSampler - s = RandomSampler(replacement=True) - assert type(s) is RandomSampler - - # RandomSampler(replacement=False) dispatches to SamplerWithoutReplacement - s = RandomSampler(replacement=False) - assert type(s) is SamplerWithoutReplacement - # default kwargs propagate - assert s.drop_last is False - assert s.shuffle is True - - # Extra kwargs are forwarded to SamplerWithoutReplacement - s = RandomSampler(replacement=False, drop_last=True, shuffle=False) - assert type(s) is SamplerWithoutReplacement - assert s.drop_last is True - assert s.shuffle is False - - # isinstance is preserved - assert isinstance(s, Sampler) - assert isinstance(s, SamplerWithoutReplacement) - - def test_replacement_kwarg_slice(self): - # SliceSampler(replacement=True) is a regular SliceSampler - s = SliceSampler(slice_len=5) - assert type(s) is SliceSampler - s = SliceSampler(replacement=True, slice_len=5) - assert type(s) is SliceSampler - - # SliceSampler(replacement=False) dispatches to SliceSamplerWithoutReplacement - s = SliceSampler(replacement=False, slice_len=5) - assert type(s) is SliceSamplerWithoutReplacement - assert s.slice_len == 5 - assert s.drop_last is False - 
assert s.shuffle is True - - # Extra without-replacement kwargs forward correctly - s = SliceSampler( - replacement=False, - slice_len=5, - drop_last=True, - shuffle=False, - traj_key="episode", - strict_length=False, - ) - assert type(s) is SliceSamplerWithoutReplacement - assert s.slice_len == 5 - assert s.drop_last is True - assert s.shuffle is False - assert s.traj_key == "episode" - assert s.strict_length is False - - # isinstance preserves the SliceSampler hierarchy - assert isinstance(s, SliceSampler) - assert isinstance(s, SamplerWithoutReplacement) - - def test_replacement_kwarg_subclass_unaffected(self): - # PrioritizedSliceSampler inherits from SliceSampler but should NOT dispatch - s = PrioritizedSliceSampler( - slice_len=5, max_capacity=10, alpha=0.5, beta=0.5 - ) - assert type(s) is PrioritizedSliceSampler - - # SamplerWithoutReplacement(replacement=...) is a no-op pop - s = SamplerWithoutReplacement(replacement=False, drop_last=True) - assert type(s) is SamplerWithoutReplacement - assert s.drop_last is True - s = SliceSamplerWithoutReplacement(replacement=False, slice_len=5) - assert type(s) is SliceSamplerWithoutReplacement - assert s.slice_len == 5 - - def test_replacement_kwarg_no_variant_errors(self): - # PrioritizedSampler has no without-replacement variant -> TypeError - with pytest.raises(TypeError, match="no without-replacement variant"): - PrioritizedSampler( - max_capacity=10, alpha=0.5, beta=0.5, replacement=False - ) - - def test_replacement_kwarg_in_replay_buffer(self): - # End-to-end: a buffer using RandomSampler(replacement=False) should - # exhaust the storage without duplicate indices (like SamplerWithoutReplacement). 
- torch.manual_seed(0) - data = TensorDict({"a": torch.arange(11)}, batch_size=[11]) - rb = ReplayBuffer( - storage=LazyTensorStorage(11), - sampler=RandomSampler(replacement=False, drop_last=False), - batch_size=3, - ) - rb.extend(data) - seen = set() - for _ in range(4): - seen.update(rb.sample()["a"].tolist()) - assert seen == set(range(11)) - - def test_replacement_kwarg_slice_in_replay_buffer(self): - # End-to-end: SliceSampler(replacement=False) returns sub-trajectories - torch.manual_seed(0) - episodes = torch.zeros(60, dtype=torch.long) - episodes[:20] = 0 - episodes[20:40] = 1 - episodes[40:] = 2 - data = TensorDict( - {"episode": episodes, "obs": torch.arange(60)}, - batch_size=[60], - ) - rb = ReplayBuffer( - storage=LazyTensorStorage(60), - sampler=SliceSampler( - replacement=False, - slice_len=5, - traj_key="episode", - strict_length=True, - ), - batch_size=10, - ) - rb.extend(data) - sample = rb.sample() - # batch_size=10, slice_len=5 -> 2 slices of 5 contiguous obs each - obs = sample["obs"].view(2, 5) - diffs = obs[:, 1:] - obs[:, :-1] - assert (diffs == 1).all(), obs - - -class TestStalenessAwareSampler: - """Tests for StalenessAwareSampler.""" - - def _make_buffer_with_versions(self, n_entries=100, version_range=(0, 5)): - """Create a replay buffer populated with data containing policy_version.""" - sampler = StalenessAwareSampler(max_staleness=-1) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(n_entries), - sampler=sampler, - batch_size=16, - ) - # Fill with data having varying policy versions - for v in range(version_range[0], version_range[1] + 1): - batch = TensorDict( - { - "observation": torch.randn(20, 4), - "action": torch.randn(20, 2), - "policy_version": torch.full((20,), float(v)), - }, - batch_size=[20], - ) - rb.extend(batch) - return rb, sampler - - def test_basic_sampling(self): - """Test that StalenessAwareSampler can sample from a buffer.""" - rb, sampler = self._make_buffer_with_versions() - sampler.consumer_version = 
5 - batch = rb.sample() - assert batch is not None - assert batch.shape[0] == 16 - - def test_freshness_weighting(self): - """Test that fresher entries are sampled more frequently.""" - sampler = StalenessAwareSampler(max_staleness=-1) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(200), - sampler=sampler, - batch_size=32, - ) - # Add 100 entries at version 0 (stale) and 100 at version 9 (fresh) - stale = TensorDict( - { - "observation": torch.zeros(100, 4), - "policy_version": torch.full((100,), 0.0), - }, - batch_size=[100], - ) - fresh = TensorDict( - { - "observation": torch.ones(100, 4), - "policy_version": torch.full((100,), 9.0), - }, - batch_size=[100], - ) - rb.extend(stale) - rb.extend(fresh) - sampler.consumer_version = 10 - - # Sample many times and count how often fresh vs stale entries appear - fresh_count = 0 - total = 0 - for _ in range(100): - batch = rb.sample() - # Fresh entries have observation == 1, stale have observation == 0 - fresh_count += (batch["observation"][:, 0] > 0.5).sum().item() - total += batch.shape[0] - - fresh_ratio = fresh_count / total - # Fresh entries (staleness=1) should be sampled ~10x more than stale (staleness=10) - # So fresh_ratio should be significantly above 0.5 - assert ( - fresh_ratio > 0.7 - ), f"Expected fresh entries to dominate, got {fresh_ratio:.2f}" - - def test_hard_staleness_gate(self): - """Test that entries beyond max_staleness are never sampled.""" - sampler = StalenessAwareSampler(max_staleness=3) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(200), - sampler=sampler, - batch_size=32, - ) - # Add entries at version 0 (stale) and version 8 (fresh) - stale = TensorDict( - { - "observation": torch.zeros(100, 4), - "policy_version": torch.full((100,), 0.0), - }, - batch_size=[100], - ) - fresh = TensorDict( - { - "observation": torch.ones(100, 4), - "policy_version": torch.full((100,), 8.0), - }, - batch_size=[100], - ) - rb.extend(stale) - rb.extend(fresh) - sampler.consumer_version 
= 10 - - # All sampled entries should be fresh (staleness=2 <= 3) - # Stale entries have staleness=10 > 3, so they're excluded - for _ in range(50): - batch = rb.sample() - assert ( - batch["observation"][:, 0] > 0.5 - ).all(), ( - "Stale entries should never be sampled when max_staleness is exceeded" - ) - - def test_all_stale_raises(self): - """Test that an error is raised when all entries exceed max_staleness.""" - sampler = StalenessAwareSampler(max_staleness=2) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(50), - sampler=sampler, - batch_size=8, - ) - data = TensorDict( - { - "observation": torch.randn(50, 4), - "policy_version": torch.full((50,), 0.0), - }, - batch_size=[50], - ) - rb.extend(data) - sampler.consumer_version = 100 # Everything is very stale - - with pytest.raises(RuntimeError, match="max_staleness"): - rb.sample() - - def test_consumer_version_increment(self): - """Test consumer version tracking.""" - sampler = StalenessAwareSampler() - assert sampler.consumer_version == 0 - sampler.increment_consumer_version() - assert sampler.consumer_version == 1 - sampler.consumer_version = 42 - assert sampler.consumer_version == 42 - - def test_staleness_in_info(self): - """Test that staleness values are returned in sample info.""" - sampler = StalenessAwareSampler(max_staleness=-1) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(50), - sampler=sampler, - batch_size=8, - ) - data = TensorDict( - { - "observation": torch.randn(50, 4), - "policy_version": torch.full((50,), 3.0), - }, - batch_size=[50], - ) - rb.extend(data) - sampler.consumer_version = 5 - - index, info = sampler.sample(rb._storage, 8) - assert "staleness" in info - assert (info["staleness"] == 2.0).all() # consumer=5 - version=3 = 2 - - def test_missing_version_key_raises(self): - """Test that a clear error is raised when version key is missing.""" - sampler = StalenessAwareSampler() - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(50), - sampler=sampler, 
- batch_size=8, - ) - data = TensorDict( - {"observation": torch.randn(50, 4)}, - batch_size=[50], - ) - rb.extend(data) - - with pytest.raises(KeyError, match="policy_version"): - rb.sample() - - def test_state_dict_roundtrip(self): - """Test that state_dict/load_state_dict preserves sampler state.""" - sampler = StalenessAwareSampler(max_staleness=7) - sampler.consumer_version = 42 - - sd = sampler.state_dict() - assert sd["consumer_version"] == 42 - assert sd["max_staleness"] == 7 - - sampler2 = StalenessAwareSampler() - sampler2.load_state_dict(sd) - assert sampler2.consumer_version == 42 - assert sampler2.max_staleness == 7 - - def test_no_staleness_limit(self): - """Test sampling with max_staleness=-1 (no limit).""" - sampler = StalenessAwareSampler(max_staleness=-1) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(50), - sampler=sampler, - batch_size=8, - ) - data = TensorDict( - { - "observation": torch.randn(50, 4), - "policy_version": torch.full((50,), 0.0), - }, - batch_size=[50], - ) - rb.extend(data) - sampler.consumer_version = 1000 # Very stale, but no limit - - # Should not raise - batch = rb.sample() - assert batch.shape[0] == 8 - - -def test_prioritized_slice_sampler_doc_example(): - sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) - rb = TensorDictReplayBuffer( - storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6 - ) - data = TensorDict( - { - "observation": torch.randn(9, 16), - "action": torch.randn(9, 1), - "episode": torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long), - "steps": torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=torch.long), - ("next", "observation"): torch.randn(9, 16), - ("next", "reward"): torch.randn(9, 1), - ("next", "done"): torch.tensor( - [0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=torch.bool - ).unsqueeze(1), - }, - batch_size=[9], - ) - rb.extend(data) - sample, info = rb.sample(return_info=True) - # print("episode", sample["episode"].tolist()) - # print("steps", 
sample["steps"].tolist()) - # print("weight", info["priority_weight"].tolist()) - - priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1]) - rb.update_priority(torch.arange(0, 9, 1), priority=priority) - sample, info = rb.sample(return_info=True) - # print("episode", sample["episode"].tolist()) - # print("steps", sample["steps"].tolist()) - # print("weight", info["priority_weight"].tolist()) - - -@pytest.mark.parametrize("device", get_default_devices()) -def test_prioritized_slice_sampler_episodes(device): - num_slices = 10 - batch_size = 20 - - episode = torch.zeros(100, dtype=torch.int, device=device) - episode[:30] = 1 - episode[30:55] = 2 - episode[55:70] = 3 - episode[70:] = 4 - steps = torch.cat( - [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0 - ) - done = torch.zeros(100, 1, dtype=torch.bool) - done[torch.tensor([29, 54, 69])] = 1 - - data = TensorDict( - { - "observation": torch.randn(100, 16), - "action": torch.randn(100, 4), - "episode": episode, - "steps": steps, - ("next", "observation"): torch.randn(100, 16), - ("next", "reward"): torch.randn(100, 1), - ("next", "done"): done, - }, - batch_size=[100], - device=device, - ) - - num_steps = data.shape[0] - sampler = PrioritizedSliceSampler( - max_capacity=num_steps, - alpha=0.7, - beta=0.9, - num_slices=num_slices, - ) - - rb = TensorDictReplayBuffer( - storage=LazyMemmapStorage(100), - sampler=sampler, - batch_size=batch_size, - ) - rb.extend(data) - - episodes = [] - for _ in range(10): - sample = rb.sample() - episodes.append(sample["episode"]) - assert {1, 2, 3, 4} == set( - torch.cat(episodes).cpu().tolist() - ), "all episodes are expected to be sampled at least once" - - index = torch.arange(0, num_steps, 1) - new_priorities = torch.cat( - [torch.ones(30), torch.zeros(25), torch.ones(15), torch.zeros(30)], 0 - ) - sampler.update_priority(index, new_priorities) - - episodes = [] - for _ in range(10): - sample = rb.sample() - episodes.append(sample["episode"]) - assert {1, 
3} == set( - torch.cat(episodes).cpu().tolist() - ), "after priority update, only episode 1 and 3 are expected to be sampled" - - -@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)]) -@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)]) -@pytest.mark.parametrize("gamma", [0.1]) -@pytest.mark.parametrize("total_steps", [200]) -@pytest.mark.parametrize("n_annealing_steps", [100]) -@pytest.mark.parametrize("anneal_every_n", [10, 159]) -@pytest.mark.parametrize("alpha_min", [0, 0.2]) -@pytest.mark.parametrize("beta_max", [1, 1.4]) -def test_prioritized_parameter_scheduler( - alpha, - beta, - gamma, - total_steps, - n_annealing_steps, - anneal_every_n, - alpha_min, - beta_max, -): - rb = TensorDictPrioritizedReplayBuffer( - alpha=alpha, beta=beta, storage=ListStorage(max_size=1000) - ) - data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000) - rb.extend(data) - alpha_scheduler = LinearScheduler( - rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps - ) - beta_scheduler = StepScheduler( - rb, - param_name="beta", - gamma=gamma, - n_steps=anneal_every_n, - max_value=beta_max, - mode="additive", - ) - - scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler)) - - alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha) - alpha_min = torch.tensor(alpha_min) - expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1) - expected_alpha_vals = torch.nn.functional.pad( - expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min - ) - - expected_beta_vals = [beta] - annealing_steps = total_steps // anneal_every_n - gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma - expected_beta_vals = ( - (beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max) - ) - for i in range(total_steps): - curr_alpha = rb.sampler.alpha - torch.testing.assert_close( - curr_alpha - if torch.is_tensor(curr_alpha) - else torch.tensor(curr_alpha).float(), - 
expected_alpha_vals[i], - msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}", - ) - curr_beta = rb.sampler.beta - torch.testing.assert_close( - curr_beta - if torch.is_tensor(curr_beta) - else torch.tensor(curr_beta).float(), - expected_beta_vals[i], - msg=f"expected {expected_beta_vals[i]}, got {curr_beta}", - ) - rb.sample(20) - scheduler.step() - - -class TestEnsemble: - def _make_data(self, data_type): - if data_type is torch.Tensor: - return torch.ones(90) - if data_type is TensorDict: - return TensorDict( - { - "root": torch.arange(90), - "nested": TensorDict( - {"data": torch.arange(180).view(90, 2)}, batch_size=[90, 2] - ), - }, - batch_size=[90], - ) - raise NotImplementedError - - def _make_sampler(self, sampler_type): - if sampler_type is SamplerWithoutReplacement: - return SamplerWithoutReplacement(drop_last=True) - if sampler_type is RandomSampler: - return RandomSampler() - raise NotImplementedError - - def _make_storage(self, storage_type, data_type): - if storage_type is LazyMemmapStorage: - return LazyMemmapStorage(max_size=100) - if storage_type is TensorStorage: - if data_type is TensorDict: - return TensorStorage(TensorDict(batch_size=[100])) - elif data_type is torch.Tensor: - return TensorStorage(torch.zeros(100)) - else: - raise NotImplementedError - if storage_type is ListStorage: - return ListStorage(max_size=100) - raise NotImplementedError - - def _make_collate(self, storage_type): - if storage_type is ListStorage: - return torch.stack - else: - return self._robust_stack - - @staticmethod - def _robust_stack(tensor_list): - if not isinstance(tensor_list, (tuple, list)): - return tensor_list - if all(tensor.shape == tensor_list[0].shape for tensor in tensor_list[1:]): - return torch.stack(list(tensor_list)) - if is_tensor_collection(tensor_list[0]): - return torch.cat(list(tensor_list)) - return torch.nested.nested_tensor(list(tensor_list)) - - @pytest.mark.parametrize( - "storage_type", [LazyMemmapStorage, TensorStorage, 
ListStorage] - ) - @pytest.mark.parametrize("data_type", [torch.Tensor, TensorDict]) - @pytest.mark.parametrize("p", [[0.0, 0.9, 0.1], None]) - @pytest.mark.parametrize("num_buffer_sampled", [3, 16, None]) - @pytest.mark.parametrize("batch_size", [48, None]) - @pytest.mark.parametrize("sampler_type", [RandomSampler, SamplerWithoutReplacement]) - def test_rb( - self, storage_type, sampler_type, data_type, p, num_buffer_sampled, batch_size - ): - storages = [self._make_storage(storage_type, data_type) for _ in range(3)] - collate_fn = self._make_collate(storage_type) - data = [self._make_data(data_type) for _ in range(3)] - samplers = [self._make_sampler(sampler_type) for _ in range(3)] - sub_batch_size = ( - batch_size // 3 - if issubclass(sampler_type, SamplerWithoutReplacement) - and batch_size is not None - else None - ) - error_catcher = ( - pytest.raises( - ValueError, - match="Samplers with drop_last=True must work with a predictable batch-size", - ) - if batch_size is None - and issubclass(sampler_type, SamplerWithoutReplacement) - else contextlib.nullcontext() - ) - rbs = None - with error_catcher: - rbs = (rb0, rb1, rb2) = [ - ReplayBuffer( - storage=storage, - sampler=sampler, - collate_fn=collate_fn, - batch_size=sub_batch_size, - ) - for (storage, sampler) in zip(storages, samplers) - ] - if rbs is None: - return - for datum, rb in zip(data, rbs): - rb.extend(datum) - rb = ReplayBufferEnsemble( - *rbs, p=p, num_buffer_sampled=num_buffer_sampled, batch_size=batch_size - ) - if batch_size is not None: - for batch_iter in rb: - assert isinstance(batch_iter, (torch.Tensor, TensorDictBase)) - break - batch_sample, info = rb.sample(return_info=True) - else: - batch_iter = None - batch_sample, info = rb.sample(48, return_info=True) - assert isinstance(batch_sample, (torch.Tensor, TensorDictBase)) - if isinstance(batch_sample, TensorDictBase): - assert "root" in batch_sample.keys() - assert "nested" in batch_sample.keys() - assert ("nested", "data") in 
batch_sample.keys(True) - if p is not None: - if batch_iter is not None: - buffer_ids = batch_iter.get(("index", "buffer_ids")) - assert isinstance(buffer_ids, torch.Tensor), batch_iter - assert 0 not in buffer_ids.unique().tolist() - - buffer_ids = batch_sample.get(("index", "buffer_ids")) - assert isinstance(buffer_ids, torch.Tensor), buffer_ids - assert 0 not in buffer_ids.unique().tolist() - if num_buffer_sampled is not None: - if batch_iter is not None: - assert batch_iter.shape == torch.Size( - [num_buffer_sampled, 48 // num_buffer_sampled] - ) - assert batch_sample.shape == torch.Size( - [num_buffer_sampled, 48 // num_buffer_sampled] - ) - else: - if batch_iter is not None: - assert batch_iter.shape == torch.Size([3, 16]) - assert batch_sample.shape == torch.Size([3, 16]) - - def _prepare_dual_replay_buffer(self, explicit=False): - torch.manual_seed(0) - rb0 = TensorDictReplayBuffer( - storage=LazyMemmapStorage(10), - transform=Compose( - ToTensorImage(in_keys=["pixels", ("next", "pixels")]), - Resize(32, in_keys=["pixels", ("next", "pixels")]), - RenameTransform([("some", "key")], ["renamed"]), - ), - ) - rb1 = TensorDictReplayBuffer( - storage=LazyMemmapStorage(10), - transform=Compose( - ToTensorImage(in_keys=["pixels", ("next", "pixels")]), - Resize(32, in_keys=["pixels", ("next", "pixels")]), - RenameTransform(["another_key"], ["renamed"]), - ), - ) - if explicit: - storages = StorageEnsemble( - rb0._storage, rb1._storage, transforms=[rb0._transform, rb1._transform] - ) - writers = WriterEnsemble(rb0._writer, rb1._writer) - samplers = SamplerEnsemble(rb0._sampler, rb1._sampler, p=[0.5, 0.5]) - collate_fns = [rb0._collate_fn, rb1._collate_fn] - rb = ReplayBufferEnsemble( - storages=storages, - samplers=samplers, - writers=writers, - collate_fns=collate_fns, - transform=Resize(33, in_keys=["pixels"], out_keys=["pixels33"]), - ) - else: - rb = ReplayBufferEnsemble( - rb0, - rb1, - p=[0.5, 0.5], - transform=Resize(33, in_keys=["pixels"], 
out_keys=["pixels33"]), - ) - data0 = TensorDict( - { - "pixels": torch.randint(255, (10, 244, 244, 3)), - ("next", "pixels"): torch.randint(255, (10, 244, 244, 3)), - ("some", "key"): torch.randn(10), - }, - batch_size=[10], - ) - data1 = TensorDict( - { - "pixels": torch.randint(255, (10, 64, 64, 3)), - ("next", "pixels"): torch.randint(255, (10, 64, 64, 3)), - "another_key": torch.randn(10), - }, - batch_size=[10], - ) - rb0.extend(data0) - rb1.extend(data1) - return rb, rb0, rb1 - - @pytest.mark.skipif(not _has_tv, reason="torchvision not found") - def test_rb_transform(self): - rb, rb0, rb1 = self._prepare_dual_replay_buffer() - for _ in range(2): - sample = rb.sample(10) - assert sample["next", "pixels"].shape == torch.Size([2, 5, 3, 32, 32]) - assert sample["pixels"].shape == torch.Size([2, 5, 3, 32, 32]) - assert sample["pixels33"].shape == torch.Size([2, 5, 3, 33, 33]) - assert sample["renamed"].shape == torch.Size([2, 5]) - - @pytest.mark.skipif(not _has_tv, reason="torchvision not found") - @pytest.mark.parametrize("explicit", [False, True]) - def test_rb_indexing(self, explicit): - rb, rb0, rb1 = self._prepare_dual_replay_buffer(explicit=explicit) - if explicit: - # indirect checks - assert rb[0]._storage is rb0._storage - assert rb[1]._storage is rb1._storage - else: - assert rb[0] is rb0 - assert rb[1] is rb1 - assert rb[:] is rb - - torch.manual_seed(0) - sample1 = rb.sample(6) - # tensor - torch.manual_seed(0) - sample0 = rb[torch.tensor([0, 1])].sample(6) - assert_allclose_td(sample0, sample1) - # slice - torch.manual_seed(0) - sample0 = rb[:2].sample(6) - assert_allclose_td(sample0, sample1) - # np.ndarray - torch.manual_seed(0) - sample0 = rb[np.array([0, 1])].sample(6) - assert_allclose_td(sample0, sample1) - # list - torch.manual_seed(0) - sample0 = rb[[0, 1]].sample(6) - assert_allclose_td(sample0, sample1) - - # direct indexing - sample1 = rb[:, :3] - # tensor - sample0 = rb[torch.tensor([0, 1]), :3] - assert_allclose_td(sample0, sample1) - # 
slice - torch.manual_seed(0) - sample0 = rb[:2, :3] - assert_allclose_td(sample0, sample1) - # np.ndarray - torch.manual_seed(0) - sample0 = rb[np.array([0, 1]), :3] - assert_allclose_td(sample0, sample1) - # list - torch.manual_seed(0) - sample0 = rb[[0, 1], :3] - assert_allclose_td(sample0, sample1) - - # check indexing of components - assert isinstance(rb.storage[:], StorageEnsemble) - assert isinstance(rb.storage[:2], StorageEnsemble) - assert isinstance(rb.storage[torch.tensor([0, 1])], StorageEnsemble) - assert isinstance(rb.storage[np.array([0, 1])], StorageEnsemble) - assert isinstance(rb.storage[[0, 1]], StorageEnsemble) - assert isinstance(rb.storage[1], LazyMemmapStorage) - - rb.storage[:, :3] - rb.storage[:2, :3] - rb.storage[torch.tensor([0, 1]), :3] - rb.storage[np.array([0, 1]), :3] - rb.storage[[0, 1], :3] - - assert isinstance(rb.sampler[:], SamplerEnsemble) - assert isinstance(rb.sampler[:2], SamplerEnsemble) - assert isinstance(rb.sampler[torch.tensor([0, 1])], SamplerEnsemble) - assert isinstance(rb.sampler[np.array([0, 1])], SamplerEnsemble) - assert isinstance(rb.sampler[[0, 1]], SamplerEnsemble) - assert isinstance(rb.sampler[1], RandomSampler) - - assert isinstance(rb.writer[:], WriterEnsemble) - assert isinstance(rb.writer[:2], WriterEnsemble) - assert isinstance(rb.writer[torch.tensor([0, 1])], WriterEnsemble) - assert isinstance(rb.writer[np.array([0, 1])], WriterEnsemble) - assert isinstance(rb.writer[[0, 1]], WriterEnsemble) - assert isinstance(rb.writer[0], RoundRobinWriter) - - -def _rbtype(datatype): - if datatype in ("pytree", "tensorclass"): - return [ - (ReplayBuffer, RandomSampler), - (PrioritizedReplayBuffer, RandomSampler), - (ReplayBuffer, SamplerWithoutReplacement), - (PrioritizedReplayBuffer, SamplerWithoutReplacement), - ] - return [ - (ReplayBuffer, RandomSampler), - (ReplayBuffer, SamplerWithoutReplacement), - (PrioritizedReplayBuffer, None), - (TensorDictReplayBuffer, RandomSampler), - (TensorDictReplayBuffer, 
SamplerWithoutReplacement), - (TensorDictPrioritizedReplayBuffer, None), - ] - - -class TestRBMultidim: - @tensorclass - class MyData: - x: torch.Tensor - y: torch.Tensor - z: torch.Tensor - - def _make_data(self, datatype, datadim): - if datadim == 1: - shape = [12] - elif datadim == 2: - shape = [4, 3] - else: - raise NotImplementedError - if datatype == "pytree": - return { - "x": (torch.ones(*shape, 2), (torch.ones(*shape, 3))), - "y": [ - {"z": torch.ones(shape)}, - torch.ones((*shape, 1), dtype=torch.bool), - ], - } - elif datatype == "tensordict": - return TensorDict( - {"x": torch.ones(*shape, 2), "y": {"z": torch.ones(*shape, 3)}}, shape - ) - elif datatype == "tensorclass": - return self.MyData( - x=torch.ones(*shape, 2), - y=torch.ones(*shape, 3), - z=torch.ones((*shape, 1), dtype=torch.bool), - batch_size=shape, - ) - - datatype_rb_tuples = [ - [datatype, *rbtype] - for datatype in ["pytree", "tensordict", "tensorclass"] - for rbtype in _rbtype(datatype) - ] - - @pytest.mark.parametrize("datatype,rbtype,sampler_cls", datatype_rb_tuples) - @pytest.mark.parametrize("datadim", [1, 2]) - @pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage]) - def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls, sampler_cls): - data = self._make_data(datatype, datadim) - if rbtype not in (PrioritizedReplayBuffer, TensorDictPrioritizedReplayBuffer): - rbtype = functools.partial(rbtype, sampler=sampler_cls()) - else: - rbtype = functools.partial(rbtype, alpha=0.9, beta=1.1) - - rb = rbtype(storage=storage_cls(100, ndim=datadim), batch_size=4) - assert str(rb) # check str works - rb.extend(data) - assert str(rb) - assert len(rb) == 12 - data = rb[:] - if datatype in ("tensordict", "tensorclass"): - assert data.numel() == 12 - else: - assert all( - leaf.shape[:datadim].numel() == 12 for leaf in tree_flatten(data)[0] - ) - s = rb.sample() - assert str(rb) - if datatype in ("tensordict", "tensorclass"): - assert (s.exclude("index") == 
1).all() - assert s.numel() == 4 - else: - for leaf in tree_iter(s): - assert leaf.shape[0] == 4 - assert (leaf == 1).all() - - @pytest.mark.skipif(not _has_gym, reason="gym required for this test.") - @pytest.mark.parametrize( - "writer_cls", - [TensorDictMaxValueWriter, RoundRobinWriter, TensorDictRoundRobinWriter], - ) - @pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage]) - @pytest.mark.parametrize( - "rbtype", - [ - functools.partial(ReplayBuffer, batch_size=8), - functools.partial(TensorDictReplayBuffer, batch_size=8), - ], - ) - @pytest.mark.parametrize( - "sampler_cls", - [ - functools.partial(SliceSampler, num_slices=2, strict_length=False), - RandomSampler, - functools.partial( - SliceSamplerWithoutReplacement, num_slices=2, strict_length=False - ), - functools.partial(PrioritizedSampler, alpha=1.0, beta=1.0, max_capacity=10), - functools.partial( - PrioritizedSliceSampler, - alpha=1.0, - beta=1.0, - max_capacity=10, - num_slices=2, - strict_length=False, - ), - ], - ) - @pytest.mark.parametrize( - "transform", - [ - None, - [ - lambda: split_trajectories, - functools.partial(MultiStep, gamma=0.9, n_steps=3), - ], - ], - ) - @pytest.mark.parametrize("env_device", get_default_devices()) - def test_rb_multidim_collector( - self, rbtype, storage_cls, writer_cls, sampler_cls, transform, env_device - ): - torch.manual_seed(0) - env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()), device=env_device) - env.set_seed(0) - collector = Collector( - env, - RandomPolicy(env.action_spec), - frames_per_batch=4, - total_frames=16, - device=env_device, - ) - if writer_cls is TensorDictMaxValueWriter: - with pytest.raises( - ValueError, - match="TensorDictMaxValueWriter is not compatible with storages with more than one dimension", - ): - rb = rbtype( - storage=storage_cls(max_size=10, ndim=2), - sampler=sampler_cls(), - writer=writer_cls(), - delayed_init=False, - ) - return - rb = rbtype( - storage=storage_cls(max_size=10, ndim=2), - 
sampler=sampler_cls(), - writer=writer_cls(), - ) - if not isinstance(rb.sampler, SliceSampler) and transform is not None: - pytest.skip("no need to test this combination") - if transform: - for t in transform: - rb.append_transform(t()) - try: - for i, data in enumerate(collector): # noqa: B007 - assert data.device == torch.device(env_device) - rb.extend(data) - if isinstance(rb, TensorDictReplayBuffer) and transform is not None: - # this should fail bc we can't set the indices after executing the transform. - with pytest.raises( - RuntimeError, match="Failed to set the metadata" - ): - rb.sample() - return - s = rb.sample() - assert s.device == torch.device("cpu") - rbtot = rb[:] - assert rbtot.shape[0] == 2 - assert len(rb) == rbtot.numel() - if transform is not None: - assert s.ndim == 2 - except Exception: - raise - - @pytest.mark.parametrize("strict_length", [True, False]) - def test_done_slicesampler(self, strict_length): - env = SerialEnv( - 3, - [ - lambda: CountingEnv(max_steps=31).add_truncated_keys(), - lambda: CountingEnv(max_steps=32).add_truncated_keys(), - lambda: CountingEnv(max_steps=33).add_truncated_keys(), - ], - ) - full_action_spec = CountingEnv(max_steps=32).full_action_spec - policy = lambda td: td.update( - full_action_spec.zero((3,)).apply_(lambda x: x + 1) - ) - rb = TensorDictReplayBuffer( - storage=LazyTensorStorage(200, ndim=2), - sampler=SliceSampler( - slice_len=32, - strict_length=strict_length, - truncated_key=("next", "truncated"), - ), - batch_size=128, - ) - - # env.add_truncated_keys() - - for i in range(50): - r = env.rollout( - 50, policy=policy, break_when_any_done=False, set_truncated=True - ) - rb.extend(r) - - sample = rb.sample() - - assert sample["next", "done"].sum() == 128 // 32, ( - i, - sample["next", "done"].sum(), - ) - assert (split_trajectories(sample)["next", "done"].sum(-2) == 1).all() - - -@pytest.mark.skipif(not _has_gym, reason="gym required") -class TestCheckpointers: - 
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - @pytest.mark.parametrize( - "checkpointer", - [FlatStorageCheckpointer, H5StorageCheckpointer, NestedStorageCheckpointer], - ) - @pytest.mark.parametrize("frames_per_batch", [22, 122]) - def test_simple_env(self, storage_type, checkpointer, tmpdir, frames_per_batch): - env = GymEnv(CARTPOLE_VERSIONED(), device=None) - env.set_seed(0) - torch.manual_seed(0) - collector = Collector( - env, - policy=env.rand_step, - total_frames=200, - frames_per_batch=frames_per_batch, - ) - rb = ReplayBuffer(storage=storage_type(100)) - rb_test = ReplayBuffer(storage=storage_type(100)) - if torch.__version__ < "2.4.0.dev" and checkpointer in ( - H5StorageCheckpointer, - NestedStorageCheckpointer, - ): - with pytest.raises(ValueError, match="Unsupported torch version"): - checkpointer() - return - rb.storage.checkpointer = checkpointer() - rb_test.storage.checkpointer = checkpointer() - for data in collector: - rb.extend(data) - rb.dumps(tmpdir) - rb_test.loads(tmpdir) - assert_allclose_td(rb_test[:], rb[:]) - assert rb.writer._cursor == rb_test._writer._cursor - - @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) - @pytest.mark.parametrize("frames_per_batch", [22, 122]) - @pytest.mark.parametrize( - "checkpointer", - [FlatStorageCheckpointer, NestedStorageCheckpointer, H5StorageCheckpointer], - ) - def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch): - env = SerialEnv( - 3, - lambda: GymEnv(CARTPOLE_VERSIONED(), device=None).append_transform( - StepCounter() - ), - ) - env.set_seed(0) - torch.manual_seed(0) - collector = Collector( - env, - policy=env.rand_step, - total_frames=200, - frames_per_batch=frames_per_batch, - ) - rb = ReplayBuffer(storage=storage_type(100, ndim=2)) - rb_test = ReplayBuffer(storage=storage_type(100, ndim=2)) - if torch.__version__ < "2.4.0.dev" and checkpointer in ( - H5StorageCheckpointer, - 
NestedStorageCheckpointer, - ): - with pytest.raises(ValueError, match="Unsupported torch version"): - checkpointer() - return - rb.storage.checkpointer = checkpointer() - rb_test.storage.checkpointer = checkpointer() - for data in collector: - rb.extend(data) - assert rb.storage.max_size == 102 - if frames_per_batch > 100: - assert rb.storage._is_full - assert len(rb) == 102 - # Checks that when writing to the buffer with a batch greater than the total - # size, we get the last step written properly. - assert (rb[:]["next", "step_count"][:, -1] != 0).any() - rb.dumps(tmpdir) - rb.dumps(tmpdir) - rb_test.loads(tmpdir) - assert_allclose_td(rb_test[:], rb[:]) - assert rb.writer._cursor == rb_test._writer._cursor - - -@pytest.mark.skipif(not _has_ray, reason="ray required for this test.") -class TestRayRB: - @pytest.fixture(autouse=True, scope="module") - def cleanup(self): - import ray - - ray.shutdown() - torchrl_logger.info("Initializing Ray.") - ray.init(num_cpus=1) - yield - torchrl_logger.info("Shutting down Ray.") - ray.shutdown() - - def test_ray_rb(self): - rb = RayReplayBuffer( - storage=partial(LazyTensorStorage, 100), ray_init_config={"num_cpus": 1} - ) - try: - rb.extend( - TensorDict( - {"x": torch.ones(100, 2), "y": torch.ones(100, 2)}, batch_size=100 - ) - ) - assert rb.write_count == 100 - assert len(rb) == 100 - assert rb.sample(2).shape == (2,) - finally: - rb.close() - - def test_ray_rb_iter(self): - rb = RayReplayBuffer( - storage=partial(LazyTensorStorage, 100), - ray_init_config={"num_cpus": 1}, - sampler=SamplerWithoutReplacement, - batch_size=25, - ) - try: - rb.extend( - TensorDict( - { - "x": torch.ones( - 100, - ), - "y": torch.ones( - 100, - ), - }, - batch_size=100, - ) - ) - for _ in range(2): - for d in rb: - torchrl_logger.info(f"d: {d}") - assert d is not None - assert d.shape == (25,) - finally: - rb.close() - - def test_ray_rb_serialization(self): - import ray - - class Worker: - def __init__(self, rb): - self.rb = rb - - def 
run(self): - self.rb.extend(TensorDict({"x": torch.ones(100)}, batch_size=100)) - - rb = RayReplayBuffer( - storage=partial(LazyTensorStorage, 100), ray_init_config={"num_cpus": 1} - ) - try: - remote_worker = ray.remote(Worker).remote(rb) - ray.get(remote_worker.run.remote()) - finally: - rb.close() - - -class TestSharedStorageInit: - def worker(self, rb, worker_id, queue): - length = len(rb) - data = TensorDict({"x": torch.full((2,), worker_id)}, batch_size=(2,)) - worker_id * 2 - index = rb.extend(data) - assert len(rb) >= length + 2 - assert (rb[index] == data).all() - queue.put("done") - - @pytest.mark.parametrize( - "storage_cls, use_tmpdir", - [ - (LazyTensorStorage, False), - (LazyMemmapStorage, False), - (LazyMemmapStorage, True), - ], - ) - def test_shared_storage_multiprocess(self, storage_cls, use_tmpdir, tmpdir): - if use_tmpdir: - storage_cls = functools.partial(storage_cls, scratch_dir=tmpdir) - storage = storage_cls(max_size=100, shared_init=True) - rb = ReplayBuffer(storage=storage, batch_size=2).share(True) - queue = mp.Queue() - - processes = [] - for i in range(4): - p = mp.Process(target=self.worker, args=(rb, i, queue)) - processes.append(p) - p.start() - - for p in processes: - p.join() - queue.get() - - all_data = storage.get(slice(0, 8)) - values = set(all_data["x"].tolist()) - expected = {0.0, 1.0, 2.0, 3.0} - assert expected.issubset(values) - assert len(storage) >= 8 - - def prioritized_collector_worker(self, rb, worker_id, queue): - data = TensorDict( - { - "obs": torch.full((4, 1), worker_id, dtype=torch.float32), - "td_error": torch.linspace(0.1, 1.0, 4) + worker_id, - }, - batch_size=(4,), - ) - rb.extend(data) - queue.put("done") - - @pytest.mark.gpu - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") - def test_prioritized_memmap_cuda_sampler_after_multiprocess_writes(self, tmpdir): - ext = pytest.importorskip("torchrl._torchrl") - if not hasattr(ext, "CudaSumSegmentTreeFp32"): - 
pytest.skip("TorchRL was not built with CUDA segment tree support") - - storage = LazyMemmapStorage(max_size=32, scratch_dir=tmpdir, shared_init=True) - writer_rb = TensorDictReplayBuffer(storage=storage, batch_size=4).share(True) - queue = mp.Queue() - - processes = [] - for i in range(2): - p = mp.Process( - target=self.prioritized_collector_worker, - args=(writer_rb, i, queue), - ) - processes.append(p) - p.start() - - for p in processes: - p.join() - assert p.exitcode == 0 - assert queue.get(timeout=5) == "done" - - assert len(storage) == 8 - learner_rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.5, - storage=storage, - sampler_device="cuda:0", - batch_size=4, - priority_key="td_error", - ) - - sample = learner_rb.sample() - assert learner_rb._sampler.device == torch.device("cuda:0") - assert sample["obs"].device.type == "cpu" - assert sample["index"].device.type == "cpu" - assert sample["priority_weight"].device.type == "cpu" - - sample["td_error"] = torch.ones_like(sample["td_error"]) * 10 - learner_rb.update_tensordict_priority(sample) - sample = learner_rb.sample() - assert sample["index"].device.type == "cpu" - - -@pytest.mark.skipif(not _has_zstandard, reason="zstandard required for this test.") -class TestCompressedListStorage: - """Test cases for CompressedListStorage.""" - - def test_compressed_storage_initialization(self): - """Test that CompressedListStorage initializes correctly.""" - storage = CompressedListStorage(max_size=100, compression_level=3) - assert storage.max_size == 100 - assert storage.compression_level == 3 - assert len(storage) == 0 - - @pytest.mark.parametrize( - "test_tensor", - [ - torch.rand(1), # 0D scalar - torch.randn(84, dtype=torch.float32), # 1D tensor - torch.randn(84, 84, dtype=torch.float32), # 2D tensor - torch.randn(1, 84, 84, dtype=torch.float32), # 3D tensor - torch.randn(32, 84, 84, dtype=torch.float32), # 3D tensor - ], - ) - def test_compressed_storage_tensor(self, test_tensor): - """Test 
compression and decompression of tensor data of various shapes.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Store tensor - storage.set(0, test_tensor) - - # Retrieve tensor - retrieved_tensor = storage.get(0) - - # Verify data integrity - assert ( - test_tensor.shape == retrieved_tensor.shape - ), f"Expected shape {test_tensor.shape}, got {retrieved_tensor.shape}" - assert ( - test_tensor.dtype == retrieved_tensor.dtype - ), f"Expected dtype {test_tensor.dtype}, got {retrieved_tensor.dtype}" - assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) - - def test_compressed_storage_tensordict(self): - """Test compression and decompression of TensorDict data.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Create test TensorDict - test_td = TensorDict( - { - "obs": torch.randn(3, 84, 84, dtype=torch.float32), - "action": torch.tensor([1, 2, 3]), - "reward": torch.randn(3), - "done": torch.tensor([False, True, False]), - }, - batch_size=[3], - ) - - # Store TensorDict - storage.set(0, test_td) - - # Retrieve TensorDict - retrieved_td = storage.get(0) - - # Verify data integrity - assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) - assert torch.allclose(test_td["action"], retrieved_td["action"]) - assert torch.allclose(test_td["reward"], retrieved_td["reward"], atol=1e-6) - assert torch.allclose(test_td["done"], retrieved_td["done"]) - - def test_compressed_storage_multiple_indices(self): - """Test storing and retrieving multiple items.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Store multiple tensors - tensors = [ - torch.randn(2, 2, dtype=torch.float32), - torch.randn(3, 3, dtype=torch.float32), - torch.randn(4, 4, dtype=torch.float32), - ] - - for i, tensor in enumerate(tensors): - storage.set(i, tensor) - - # Retrieve multiple tensors - retrieved = storage.get([0, 1, 2]) - - # Verify data integrity - for original, retrieved_tensor in 
zip(tensors, retrieved): - assert torch.allclose(original, retrieved_tensor, atol=1e-6) - - def test_compressed_storage_with_replay_buffer(self): - """Test CompressedListStorage with ReplayBuffer.""" - storage = CompressedListStorage(max_size=100, compression_level=3) - rb = ReplayBuffer(storage=storage, batch_size=5) - - # Create test data - data = TensorDict( - { - "obs": torch.randn(10, 3, 84, 84, dtype=torch.float32), - "action": torch.randint(0, 4, (10,)), - "reward": torch.randn(10), - }, - batch_size=[10], - ) - - # Add data to replay buffer - rb.extend(data) - - # Sample from replay buffer - sample = rb.sample(5) - - # Verify sample has correct shape - assert is_tensor_collection(sample), sample - assert sample["obs"].shape[0] == 5 - assert sample["obs"].shape[1:] == (3, 84, 84) - assert sample["action"].shape[0] == 5 - assert sample["reward"].shape[0] == 5 - - def test_compressed_storage_state_dict(self): - """Test saving and loading state dict.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Add some data - test_tensor = torch.randn(3, 3, dtype=torch.float32) - storage.set(0, test_tensor) - - # Save state dict - state_dict = storage.state_dict() - - # Create new storage and load state dict - new_storage = CompressedListStorage(max_size=10, compression_level=3) - new_storage.load_state_dict(state_dict) - - # Verify data integrity - retrieved_tensor = new_storage.get(0) - assert torch.allclose(test_tensor, retrieved_tensor, atol=1e-6) - - def test_compressed_storage_checkpointing(self): - """Test checkpointing functionality.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Add some data - test_td = TensorDict( - { - "obs": torch.randn(3, 84, 84, dtype=torch.float32), - "action": torch.tensor([1, 2, 3]), - }, - batch_size=[3], - ) - storage.set(0, test_td) - - # second batch, different shape - test_td2 = TensorDict( - { - "obs": torch.randn(3, 85, 83, dtype=torch.float32), - "action": torch.tensor([1, 
2, 3]), - "meta": torch.randn(3), - "astring": "a string!", - }, - batch_size=[3], - ) - storage.set(1, test_td) - - # Create temporary directory for checkpointing - with tempfile.TemporaryDirectory() as tmpdir: - checkpoint_path = Path(tmpdir) / "checkpoint" - - # Save checkpoint - storage.dumps(checkpoint_path) - - # Create new storage and load checkpoint - new_storage = CompressedListStorage(max_size=10, compression_level=3) - new_storage.loads(checkpoint_path) - - # Verify data integrity - retrieved_td = new_storage.get(0) - assert torch.allclose(test_td["obs"], retrieved_td["obs"], atol=1e-6) - assert torch.allclose(test_td["action"], retrieved_td["action"]) - - def test_compressed_storage_length(self): - """Test that length is calculated correctly.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Initially empty - assert len(storage) == 0 - - # Add some data - storage.set(0, torch.randn(2, 2)) - assert len(storage) == 1 - - storage.set(1, torch.randn(2, 2)) - assert len(storage) == 2 - - storage.set(2, torch.randn(2, 2)) - assert len(storage) == 3 - - def test_compressed_storage_contains(self): - """Test the contains method.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Initially empty - assert not storage.contains(0) - - # Add data - storage.set(0, torch.randn(2, 2)) - assert storage.contains(0) - assert not storage.contains(1) - - def test_compressed_storage_empty(self): - """Test emptying the storage.""" - storage = CompressedListStorage(max_size=10, compression_level=3) - - # Add some data - storage.set(0, torch.randn(2, 2)) - storage.set(1, torch.randn(2, 2)) - assert len(storage) == 2 - - # Empty storage - storage._empty() - assert len(storage) == 0 - - def test_compressed_storage_custom_compression(self): - """Test custom compression functions.""" - - def custom_compress(tensor): - # Simple compression: just convert to uint8 - return tensor.to(torch.uint8) - - def 
custom_decompress(compressed_tensor, metadata): - # Simple decompression: convert back to original dtype - return compressed_tensor.to(metadata["dtype"]) - - storage = CompressedListStorage( - max_size=10, - compression_fn=custom_compress, - decompression_fn=custom_decompress, - ) - - # Test with tensor - test_tensor = torch.randn(2, 2, dtype=torch.float32) - storage.set(0, test_tensor) - retrieved_tensor = storage.get(0) - - # Note: This will lose precision due to uint8 conversion - # but should still work - assert retrieved_tensor.shape == test_tensor.shape - - def test_compressed_storage_error_handling(self): - """Test error handling for invalid operations.""" - storage = CompressedListStorage(max_size=5, compression_level=3) - - # Test setting data beyond max_size - with pytest.raises(RuntimeError): - storage.set(10, torch.randn(2, 2)) - - # Test getting non-existent data - with pytest.raises(IndexError): - storage.get(0) - - def test_compressed_storage_memory_efficiency(self): - """Test that compression actually reduces memory usage.""" - storage = CompressedListStorage(max_size=100, compression_level=3) - - # Create large tensor data - large_tensor = torch.zeros(100, 3, 84, 84, dtype=torch.int64) - large_tensor.copy_( - torch.arange(large_tensor.numel(), dtype=torch.int32).view_as(large_tensor) - // (3 * 84 * 84) - ) - original_size = large_tensor.numel() * large_tensor.element_size() - - # Store in compressed storage - storage.set(0, large_tensor) - - # Estimate compressed size - compressed_data = storage._storage[0] - compressed_size = compressed_data.numel() # uint8 bytes - - # Verify compression ratio is reasonable (at least 2x for random data) - compression_ratio = original_size / compressed_size - assert ( - compression_ratio > 1.5 - ), f"Compression ratio {compression_ratio} is too low" - - -class TestRBLazyInit: - def test_lazy_init(self): - def transform(td): - return td - - rb = ReplayBuffer( - storage=partial(ListStorage), - 
writer=partial(RoundRobinWriter), - sampler=partial(RandomSampler), - transform_factory=lambda: transform, - ) - assert not rb.initialized - assert not hasattr(rb, "_storage") - assert rb._init_storage is not None - assert not hasattr(rb, "_sampler") - assert rb._init_sampler is not None - assert not hasattr(rb, "_writer") - assert rb._init_writer is not None - rb.extend(TensorDict(batch_size=[2])) - assert rb.initialized - assert rb._storage is not None - assert rb._init_storage is None - assert rb._sampler is not None - assert rb._init_sampler is None - assert rb._writer is not None - assert rb._init_writer is None - - rb = ReplayBuffer( - storage=partial(ListStorage), - writer=partial(RoundRobinWriter), - sampler=partial(RandomSampler), - ) - assert rb.initialized - assert rb._storage is not None - assert rb._init_storage is None - assert rb._sampler is not None - assert rb._init_sampler is None - assert rb._writer is not None - assert rb._init_writer is None - - rb = ReplayBuffer( - storage=partial(ListStorage), - writer=partial(RoundRobinWriter), - sampler=partial(RandomSampler), - delayed_init=False, - ) - assert rb.initialized - assert rb._storage is not None - assert rb._init_storage is None - assert rb._sampler is not None - assert rb._init_sampler is None - assert rb._writer is not None - assert rb._init_writer is None - - -@pytest.mark.skipif( - _os_is_windows, reason="Windows file locking prevents cleanup tests" -) -class TestLazyMemmapStorageCleanup: - """Tests for LazyMemmapStorage automatic cleanup functionality.""" - - def test_cleanup_explicit_scratch_dir(self, tmpdir): - """Test that cleanup removes files when scratch_dir is specified.""" - scratch_dir = str(tmpdir / "memmap_storage") - os.makedirs(scratch_dir, exist_ok=True) - - storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) - rb = ReplayBuffer(storage=storage) - rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) - - # Verify files were created - assert 
os.path.isdir(scratch_dir) - assert len(os.listdir(scratch_dir)) > 0 - - # Cleanup should remove the directory - result = storage.cleanup() - assert result is True - assert not os.path.exists(scratch_dir) - - # Second cleanup should be a no-op - result = storage.cleanup() - assert result is False - - def test_cleanup_temp_dir(self): - """Test cleanup when using default temp directory.""" - storage = LazyMemmapStorage(100, auto_cleanup=True) - rb = ReplayBuffer(storage=storage) - rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) - - # Get the temp directory paths before cleanup - temp_paths = set() - for tensor in storage._storage.values(include_nested=True, leaves_only=True): - try: - if hasattr(tensor, "filename") and tensor.filename: - temp_paths.add(os.path.dirname(tensor.filename)) - except (AttributeError, RuntimeError): - continue - - # Cleanup should remove the files if any were created on disk - result = storage.cleanup() - if len(temp_paths) > 0: - assert result is True - # Paths should no longer exist - for path in temp_paths: - assert not os.path.exists(path) - else: - # If no files were created (e.g. 
anonymous memmap), result should be False - assert result is False - - def test_auto_cleanup_default_behavior(self, tmpdir): - """Test that auto_cleanup defaults correctly based on scratch_dir.""" - # When scratch_dir is None, auto_cleanup should default to True - storage1 = LazyMemmapStorage(100) - assert storage1._auto_cleanup is True - assert storage1._scratch_dir_is_temp is True - - # When scratch_dir is provided, auto_cleanup should default to False - scratch_dir = str(tmpdir / "user_storage") - storage2 = LazyMemmapStorage(100, scratch_dir=scratch_dir) - assert storage2._auto_cleanup is False - assert storage2._scratch_dir_is_temp is False - - # User can override - storage3 = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) - assert storage3._auto_cleanup is True - - storage4 = LazyMemmapStorage(100, auto_cleanup=False) - assert storage4._auto_cleanup is False - - def test_cleanup_idempotent(self, tmpdir): - """Test that cleanup can be called multiple times safely.""" - scratch_dir = str(tmpdir / "memmap_storage") - os.makedirs(scratch_dir, exist_ok=True) - - storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) - rb = ReplayBuffer(storage=storage) - rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) - - # Multiple cleanups should not raise - storage.cleanup() - storage.cleanup() - storage.cleanup() - assert storage._cleaned_up is True - - def test_cleanup_nonexistent_dir(self, tmpdir): - """Test cleanup when directory was already deleted.""" - scratch_dir = str(tmpdir / "memmap_storage") - os.makedirs(scratch_dir, exist_ok=True) - - storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) - rb = ReplayBuffer(storage=storage) - rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) - - # Delete the directory externally - shutil.rmtree(scratch_dir) - assert not os.path.exists(scratch_dir) - - # Cleanup should handle missing directory gracefully - result = storage.cleanup() - assert result 
is False # No cleanup needed since dir is gone - - def test_cleanup_uninitialized_storage(self): - """Test cleanup on storage that was never used.""" - storage = LazyMemmapStorage(100, auto_cleanup=True) - # Storage is not initialized - cleanup should be safe - result = storage.cleanup() - assert result is False - - def test_cleanup_registry(self): - """Test that storages are registered for cleanup.""" - storage = LazyMemmapStorage(100, auto_cleanup=True) - # Check storage is in the registry (avoids race with GC on WeakSet) - assert storage in _MEMMAP_STORAGE_REGISTRY - - # Storage with auto_cleanup=False should not be registered - storage2 = LazyMemmapStorage(100, auto_cleanup=False) - assert storage2 not in _MEMMAP_STORAGE_REGISTRY - # Original storage should still be in the registry - assert storage in _MEMMAP_STORAGE_REGISTRY - - # Cleanup should still work - storage.cleanup() - - def test_cleanup_subprocess(self, tmpdir): - """Test that cleanup works correctly in subprocess scenarios.""" - scratch_dir = str(tmpdir / "subprocess_storage") - - # Create a script that creates a storage and exits normally - script = f""" -import torch -from tensordict import TensorDict -from torchrl.data import ReplayBuffer, LazyMemmapStorage - -storage = LazyMemmapStorage(100, scratch_dir="{scratch_dir}", auto_cleanup=True) -rb = ReplayBuffer(storage=storage) -rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) -print("Storage created") -# Normal exit - atexit handler should clean up -""" - result = subprocess.run( - [sys.executable, "-c", script], - capture_output=True, - text=True, - timeout=30, - ) - - # Script should have succeeded - assert result.returncode == 0, f"Script failed: {result.stderr}" - - # Directory should have been cleaned up on exit - assert not os.path.exists( - scratch_dir - ), f"Directory {scratch_dir} should have been cleaned up" - - def test_cleanup_signal_interrupt(self, tmpdir): - """Test that cleanup happens on SIGINT (Ctrl+C).""" - scratch_dir = 
str(tmpdir / "signal_storage") - - # Create a script that sleeps and can be interrupted - script = f""" -import signal -import time -import torch -from tensordict import TensorDict -from torchrl.data import ReplayBuffer, LazyMemmapStorage - -storage = LazyMemmapStorage(100, scratch_dir="{scratch_dir}", auto_cleanup=True) -rb = ReplayBuffer(storage=storage) -rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) -print("READY", flush=True) -time.sleep(60) # Will be interrupted -""" - proc = subprocess.Popen( - [sys.executable, "-c", script], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - - # Wait for the script to be ready - try: - # Read until we see READY - start = time.time() - while time.time() - start < 10: - line = proc.stdout.readline() - if "READY" in line: - break - else: - proc.kill() - pytest.skip("Script did not start in time") - - # Give it a moment to set up signal handlers - time.sleep(0.5) - - # Verify directory exists - assert os.path.isdir(scratch_dir) - - # Send SIGINT (Ctrl+C) - proc.send_signal(signal.SIGINT) - proc.wait(timeout=5) - - # Directory should have been cleaned up - assert not os.path.exists( - scratch_dir - ), f"Directory {scratch_dir} should have been cleaned up on SIGINT" - finally: - if proc.poll() is None: - proc.kill() - proc.wait() - - def test_cleanup_with_del(self, tmpdir): - """Test that __del__ triggers cleanup.""" - scratch_dir = str(tmpdir / "del_storage") - os.makedirs(scratch_dir, exist_ok=True) - - def create_and_delete(): - storage = LazyMemmapStorage(100, scratch_dir=scratch_dir, auto_cleanup=True) - rb = ReplayBuffer(storage=storage) - rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) - # Storage goes out of scope here - - create_and_delete() - - # Force garbage collection - gc.collect() - - # Note: __del__ is not guaranteed to run immediately, but the cleanup - # infrastructure should still work via atexit - - def test_cleanup_preserves_user_data_by_default(self, tmpdir): - 
"""Test that user-specified directories are NOT cleaned by default.""" - scratch_dir = str(tmpdir / "user_data") - os.makedirs(scratch_dir, exist_ok=True) - - storage = LazyMemmapStorage(100, scratch_dir=scratch_dir) - rb = ReplayBuffer(storage=storage) - rb.extend(TensorDict(a=torch.randn(10), batch_size=[10])) - - # auto_cleanup should be False by default - assert storage._auto_cleanup is False - - # Directory should exist - assert os.path.isdir(scratch_dir) - - # Explicit cleanup should still work - storage.cleanup() - assert not os.path.exists(scratch_dir) - - -if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)