Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 292 additions & 4 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,6 @@ def test_lstm_parallel_env(
def _test_lstm_parallel_env(
self, python_based, parallel, heterogeneous, within, maybe_fork_ParallelEnv
):

from torchrl.envs import InitTracker, TransformedEnv

torch.manual_seed(0)
Expand Down Expand Up @@ -1473,8 +1472,299 @@ def call(data, params):
assert vmap(call, (None, 0))(data, params).shape == torch.Size((2, 50, 11))


def test_safe_specs():
class TestFunctorchIntegration:
    """Test suite for functorch integration with TensorDictModule."""

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_tdmodule_functional_params(self):
        """Parameters of a SafeModule can be pulled out as a TensorDict."""
        torch.manual_seed(0)

        linear = nn.Linear(3, 4)
        module = SafeModule(
            module=linear,
            in_keys=["input"],
            out_keys=["output"],
        )

        extracted = TensorDict.from_module(module)

        # The extracted structure must be non-empty and every leaf populated.
        assert len(extracted.keys()) > 0
        assert all(extracted.get(key) is not None for key in extracted.keys(True))

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_vmap_on_tdmodule_with_params(self):
        """vmap a TensorDictModule over a batch of parameter sets."""
        torch.manual_seed(0)

        mlp = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
        module = TensorDictModule(
            module=mlp,
            in_keys=["input"],
            out_keys=["output"],
        )

        base_params = TensorDict.from_module(module)
        batched_params = base_params.expand(4, *base_params.shape)

        data = TensorDict({"input": torch.randn(2, 3)}, [2])

        def run(td_in, p):
            # Temporarily swap the batched params in for the module's weights.
            with p.to_module(module):
                return module(td_in.clone())

        out = vmap(run, (None, 0))(data, batched_params)

        assert out.shape == torch.Size([4, 2])
        assert "output" in out.keys()
        assert out["output"].shape == torch.Size([4, 2, 2])

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_nested_tdmodule_param_length(self):
        """Nested (probabilistic sequential) modules expose all their params."""
        torch.manual_seed(0)

        backbone = nn.Sequential(nn.Linear(3, 4), NormalParamExtractor())
        inner = TensorDictModule(
            module=backbone,
            in_keys=["input"],
            out_keys=["loc", "scale"],
        )

        head = SafeProbabilisticModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=TanhNormal,
        )

        seq = SafeProbabilisticTensorDictSequential(inner, head)

        extracted = TensorDict.from_module(seq)

        assert len(extracted.keys()) > 0
        assert all(extracted.get(key) is not None for key in extracted.keys(True))

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_nested_tdmodule_param_casting(self):
        """Casting the param TensorDict changes the dtype of every leaf."""
        torch.manual_seed(0)

        backbone = nn.Sequential(nn.Linear(3, 4), NormalParamExtractor())
        inner = TensorDictModule(
            module=backbone,
            in_keys=["input"],
            out_keys=["loc", "scale"],
        )

        head = SafeProbabilisticModule(
            in_keys=["loc", "scale"],
            out_keys=["action"],
            distribution_class=TanhNormal,
        )

        seq = SafeProbabilisticTensorDictSequential(inner, head)

        extracted = TensorDict.from_module(seq)

        casted = extracted.to(torch.float64)

        assert all(
            casted.get(key).dtype == torch.float64 for key in casted.keys(True)
        )

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_tdsequence_vmap_params(self):
        """A SafeSequential works both with plain swapped params and vmap."""
        torch.manual_seed(0)

        first = SafeModule(
            module=nn.Linear(3, 4),
            in_keys=["input"],
            out_keys=["hidden"],
        )

        second = SafeModule(
            module=nn.Linear(4, 2),
            in_keys=["hidden"],
            out_keys=["output"],
        )

        seq = SafeSequential(first, second)

        extracted = TensorDict.from_module(seq)

        data = TensorDict({"input": torch.randn(2, 3)}, [2])
        # Plain (non-vmap) functional call first.
        with extracted.to_module(seq):
            plain_out = seq(data.clone())

        assert "output" in plain_out.keys()
        assert plain_out["output"].shape == torch.Size([2, 2])

        batched_params = extracted.expand(4, *extracted.shape)

        def run(td_in, p):
            with p.to_module(seq):
                return seq(td_in.clone())

        vmap_out = vmap(run, (None, 0))(data, batched_params)

        assert vmap_out.shape == torch.Size([4, 2])
        assert "output" in vmap_out.keys()
        assert vmap_out["output"].shape == torch.Size([4, 2, 2])

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_vmap_multiple_inputs(self):
        """vmap works when the wrapped module reads several input keys."""
        torch.manual_seed(0)

        class TwoBranchNet(nn.Module):
            # Sums two independent linear projections of separate inputs.
            def __init__(self):
                super().__init__()
                self.branch_x = nn.Linear(3, 4)
                self.branch_y = nn.Linear(4, 4)

            def forward(self, x, y):
                return self.branch_x(x) + self.branch_y(y)

        net = TwoBranchNet()
        module = TensorDictModule(
            module=net,
            in_keys=["x", "y"],
            out_keys=["output"],
        )

        base_params = TensorDict.from_module(module)
        batched_params = base_params.expand(3, *base_params.shape)

        data = TensorDict(
            {
                "x": torch.randn(2, 3),
                "y": torch.randn(2, 4),
            },
            [2],
        )

        def run(td_in, p):
            with p.to_module(module):
                return module(td_in.clone())

        out = vmap(run, (None, 0))(data, batched_params)

        assert out.shape == torch.Size([3, 2])
        assert "output" in out.keys()
        assert out["output"].shape == torch.Size([3, 2, 4])

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_vmap_module_class(self):
        """VmapModule batches a wrapped TensorDictModule along vmap_dim."""
        torch.manual_seed(0)

        inner = TensorDictModule(
            module=nn.Linear(3, 4),
            in_keys=["input"],
            out_keys=["output"],
        )

        wrapped = VmapModule(inner, vmap_dim=0)

        data = TensorDict({"input": torch.randn(3, 3)}, [3])
        out = wrapped(data)

        assert out["output"].shape == torch.Size([3, 4])

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_tdsequence_vmap(self):
        """vmap over parameter batches of a two-stage SafeSequential."""
        torch.manual_seed(0)

        first = SafeModule(
            module=nn.Linear(3, 4),
            in_keys=["input"],
            out_keys=["hidden"],
        )

        second = SafeModule(
            module=nn.Linear(4, 2),
            in_keys=["hidden"],
            out_keys=["output"],
        )

        seq = SafeSequential(first, second)

        base_params = TensorDict.from_module(seq)
        batched_params = base_params.expand(4, *base_params.shape)

        data = TensorDict({"input": torch.randn(2, 3)}, [2])

        def run(td_in, p):
            with p.to_module(seq):
                return seq(td_in.clone())

        out = vmap(run, (None, 0))(data, batched_params)

        assert out.shape == torch.Size([4, 2])
        assert "output" in out.keys()
        assert out["output"].shape == torch.Size([4, 2, 2])

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_tdmodule_functional_to_module(self):
        """Extracted params round-trip through the to_module context manager."""
        torch.manual_seed(0)

        mlp = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
        module = TensorDictModule(
            module=mlp,
            in_keys=["input"],
            out_keys=["output"],
        )

        extracted = TensorDict.from_module(module)

        data = TensorDict({"input": torch.randn(3, 3)}, [3])
        with extracted.to_module(module):
            out = module(data.clone())

        assert "output" in out.keys()
        assert out["output"].shape == torch.Size([3, 2])

    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
    def test_nested_sequential_vmap(self):
        """vmap over params of a TensorDictSequential nested in another."""
        torch.manual_seed(0)

        lin_a = nn.Linear(3, 4)
        lin_b = nn.Linear(4, 2)
        lin_c = nn.Linear(2, 1)

        inner_seq = TensorDictSequential(
            TensorDictModule(lin_a, in_keys=["a"], out_keys=["b"]),
            TensorDictModule(lin_b, in_keys=["b"], out_keys=["c"]),
        )

        tail_seq = TensorDictSequential(
            TensorDictModule(lin_c, in_keys=["c"], out_keys=["d"]),
        )

        full_seq = TensorDictSequential(inner_seq, tail_seq)

        base_params = TensorDict.from_module(full_seq)
        batched_params = base_params.expand(2, *base_params.shape)

        data = TensorDict({"a": torch.randn(3, 3)}, [3])

        def run(td_in, p):
            with p.to_module(full_seq):
                return full_seq(td_in.clone())

        out = vmap(run, (None, 0))(data, batched_params)

        assert out.shape == torch.Size([2, 3])
        assert "d" in out.keys()


def test_safe_specs():
out_key = ("a", "b")
spec = Composite(Composite({out_key: Unbounded()}))
original_spec = spec.clone()
Expand Down Expand Up @@ -1645,7 +1935,6 @@ def test_batched_actor_exceptions(self):

@pytest.mark.parametrize("time_steps", [3, 5])
def test_batched_actor_simple(self, time_steps):

batch = 2

# The second env has frequent resets, the first none
Expand Down Expand Up @@ -1686,7 +1975,6 @@ def test_batched_actor_simple(self, time_steps):


def test_get_primers_from_module():

# No primers in the model
module = MLP(in_features=10, out_features=10, num_cells=[])
transform = get_primers_from_module(module)
Expand Down
Loading