Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
1f6f327
Fixed MultiSyncCollector set_seed and split_trajs issue
ParamThakkar123 Jan 19, 2026
e2aaf6b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 20, 2026
40642d5
Revert "Fixed MultiSyncCollector set_seed and split_trajs issue"
ParamThakkar123 Jan 20, 2026
efdc89c
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 21, 2026
628f44b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 23, 2026
a476a77
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 24, 2026
0f565c5
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 25, 2026
7fb086b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 26, 2026
ff72793
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 28, 2026
69001ed
Added Support for index_select in TensorSpec
ParamThakkar123 Jan 28, 2026
4ab13be
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 29, 2026
2e8face
rebase
ParamThakkar123 Jan 29, 2026
56e1529
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 31, 2026
ba6a19f
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 4, 2026
8be545b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 5, 2026
88ab847
[Feature] Extended Support delayed spec initialization for exploratio…
ParamThakkar123 Feb 5, 2026
c38881b
Merge branch 'main' of https://github.com/pytorch/rl into add/delayed…
ParamThakkar123 Feb 7, 2026
50c4c7f
edit
vmoens Feb 7, 2026
0d7d355
Merge branch 'main' of https://github.com/pytorch/rl into add/delayed…
ParamThakkar123 Feb 8, 2026
5fa4a9e
Merge branch 'add/delayed-spec-extended' of https://github.com/ParamT…
ParamThakkar123 Feb 8, 2026
19c9f59
Merge branch 'main' of https://github.com/pytorch/rl into add/delayed…
ParamThakkar123 Feb 17, 2026
3a0e415
Merge branch 'main' of https://github.com/pytorch/rl into add/delayed…
ParamThakkar123 Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_no_spec_error(self):

with pytest.raises(
RuntimeError,
match="Failed while executing module|spec must be provided to the exploration wrapper",
match="Failed while executing module|spec has not been set",
):
explorative_policy(td)

Expand Down Expand Up @@ -421,8 +421,12 @@ def test_nested(
return

def test_no_spec_error(self, device):
with pytest.raises(RuntimeError, match="spec cannot be None."):
OrnsteinUhlenbeckProcessModule(spec=None).to(device)
module = OrnsteinUhlenbeckProcessModule(spec=None, safe=False).to(device)
td = TensorDict(
{"action": torch.randn(3, device=device)}, batch_size=[3], device=device
)
out = module(td)
assert "action" in out.keys()


@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -683,7 +687,7 @@ def test_set_exploration_modules_spec_from_env(device, use_batched_env):
d_obs = env.observation_spec["observation"].shape[-1]
d_act = expected_spec.shape[-1]

# Create a policy with exploration module that has spec=None
# Create a policy with exploration modules that have spec=None
net = nn.Sequential(
nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor()
)
Expand All @@ -699,19 +703,29 @@ def test_set_exploration_modules_spec_from_env(device, use_batched_env):
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
).to(device)
exploration_module = AdditiveGaussianModule(spec=None, device=device)
exploratory_policy = TensorDictSequential(policy, exploration_module)
additive = AdditiveGaussianModule(spec=None, device=device)
egreedy = EGreedyModule(spec=None, device=device)
ou = OrnsteinUhlenbeckProcessModule(spec=None, device=device)
exploratory_policy = TensorDictSequential(policy, additive, egreedy, ou)

assert exploration_module._spec is None
assert additive.spec is None
assert egreedy.spec is None
assert ou.spec is None

set_exploration_modules_spec_from_env(exploratory_policy, env)

# Verify spec is set after configuration and matches the environment's action_spec
assert exploration_module._spec is not None
if isinstance(exploration_module._spec, Composite):
assert exploration_module._spec[exploration_module.action_key] == expected_spec
else:
assert exploration_module._spec == expected_spec
for exploration_module in (additive, egreedy, ou):
assert exploration_module.spec is not None
if isinstance(exploration_module.spec, Composite):
action_key = (
exploration_module.action_key
if hasattr(exploration_module, "action_key")
else exploration_module.ou.key
)
assert exploration_module.spec[action_key] == expected_spec
else:
assert exploration_module.spec == expected_spec

td = env.reset()
result = exploratory_policy(td)
Expand Down
71 changes: 44 additions & 27 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class EGreedyModule(TensorDictModuleBase):

def __init__(
self,
spec: TensorSpec,
spec: TensorSpec | None,
eps_init: float = 1.0,
eps_end: float = 0.1,
annealing_num_steps: int = 1000,
Expand Down Expand Up @@ -123,17 +123,22 @@ def __init__(
"eps", torch.as_tensor(eps_init, dtype=torch.float32, device=device)
)

if spec is not None:
if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
spec = Composite({action_key: spec}, shape=spec.shape[:-1])
if device is not None:
spec = spec.to(device)
self._spec = spec
self.spec = spec

@property
def spec(self):
return self._spec

@spec.setter
def spec(self, value: TensorSpec | None) -> None:
if value is not None:
if not isinstance(value, Composite) and len(self.out_keys) >= 1:
value = Composite({self.action_key: value}, shape=value.shape[:-1])
if self.eps.device is not None:
value = value.to(self.eps.device)

self._spec = value

def step(self, frames: int = 1) -> None:
"""A step of epsilon decay.

Expand Down Expand Up @@ -203,7 +208,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
r = r.to(device)
action = torch.where(cond, r, action)
else:
raise RuntimeError("spec must be provided to the exploration wrapper.")
raise RuntimeError(
"spec has not been set. Pass spec at construction time or set it via "
"the `spec` property before calling forward()."
)
action_tensordict.set(action_key, action)
return tensordict

Expand Down Expand Up @@ -518,7 +526,7 @@ class OrnsteinUhlenbeckProcessModule(TensorDictModuleBase):

def __init__(
self,
spec: TensorSpec,
spec: TensorSpec | None,
eps_init: float = 1.0,
eps_end: float = 0.1,
annealing_num_steps: int = 1000,
Expand Down Expand Up @@ -564,20 +572,7 @@ def __init__(
self.in_keys = [self.ou.key]
self.out_keys = [self.ou.key] + self.ou.out_keys
self.is_init_key = is_init_key
noise_key = self.ou.noise_key
steps_key = self.ou.steps_key

if spec is not None:
if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
spec = Composite({action_key: spec}, shape=spec.shape[:-1])
self._spec = spec
else:
raise RuntimeError("spec cannot be None.")
ou_specs = {
noise_key: None,
steps_key: None,
}
self._spec.update(ou_specs)
self.spec = spec
if len(set(self.out_keys)) != len(self.out_keys):
raise RuntimeError(f"Got multiple identical output keys: {self.out_keys}")
self.safe = safe
Expand All @@ -588,6 +583,21 @@ def __init__(
def spec(self):
return self._spec

@spec.setter
def spec(self, value: TensorSpec | None) -> None:
if value is None:
self._spec = None
return
if not isinstance(value, Composite) and len(self.out_keys) >= 1:
value = Composite({self.ou.key: value}, shape=value.shape[:-1])
ou_specs = {
self.ou.noise_key: None,
self.ou.steps_key: None,
}
value = value.clone()
value.update(ou_specs)
self._spec = value

def step(self, frames: int = 1) -> None:
"""Updates the eps noise factor.

Expand Down Expand Up @@ -799,8 +809,8 @@ def __call__(self, td: TensorDictBase) -> TensorDictBase:
def set_exploration_modules_spec_from_env(policy: nn.Module, env: EnvBase) -> None:
"""Sets exploration module specs from an environment action spec.

This is intended for cases where exploration modules (e.g. AdditiveGaussianModule)
are instantiated with ``spec=None`` and must be configured once the environment
This is intended for cases where exploration modules (e.g. AdditiveGaussianModule,
EGreedyModule, OrnsteinUhlenbeckProcessModule) are instantiated with ``spec=None`` and must be configured once the environment
is known (e.g. inside a collector).
"""
action_spec = (
Expand All @@ -809,6 +819,13 @@ def set_exploration_modules_spec_from_env(policy: nn.Module, env: EnvBase) -> No
else env.action_spec
)

exploration_modules = (
AdditiveGaussianModule,
EGreedyModule,
OrnsteinUhlenbeckProcessModule,
)

for submodule in policy.modules():
if isinstance(submodule, AdditiveGaussianModule) and submodule._spec is None:
submodule.spec = action_spec
if isinstance(submodule, exploration_modules):
if submodule.spec is None:
submodule.spec = action_spec