diff --git a/test/test_tensordictmodules.py b/test/test_tensordictmodules.py
index c7c6b10235b..85fa49083ca 100644
--- a/test/test_tensordictmodules.py
+++ b/test/test_tensordictmodules.py
@@ -893,7 +893,6 @@ def test_lstm_parallel_env(
     def _test_lstm_parallel_env(
         self, python_based, parallel, heterogeneous, within, maybe_fork_ParallelEnv
     ):
-        from torchrl.envs import InitTracker, TransformedEnv
 
         torch.manual_seed(0)
 
@@ -1473,8 +1472,299 @@ def call(data, params):
     assert vmap(call, (None, 0))(data, params).shape == torch.Size((2, 50, 11))
 
 
-def test_safe_specs():
+class TestFunctorchIntegration:
+    """Test suite for functorch integration with TensorDictModule."""
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_tdmodule_functional_params(self):
+        """Test that TDModule functional params can be extracted and used."""
+        torch.manual_seed(0)
+
+        net = nn.Linear(3, 4)
+        td_module = SafeModule(
+            module=net,
+            in_keys=["input"],
+            out_keys=["output"],
+        )
+
+        params = TensorDict.from_module(td_module)
+
+        assert len(params.keys()) > 0
+        for key in params.keys(True):
+            assert params.get(key) is not None
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_vmap_on_tdmodule_with_params(self):
+        """Test vmap on TensorDictModule with functional params."""
+        torch.manual_seed(0)
+
+        net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
+        td_module = TensorDictModule(
+            module=net,
+            in_keys=["input"],
+            out_keys=["output"],
+        )
+
+        params = TensorDict.from_module(td_module)
+        params_expanded = params.expand(4, *params.shape)
+
+        td = TensorDict({"input": torch.randn(2, 3)}, [2])
+        def call(td, params):
+            with params.to_module(td_module):
+                return td_module(td.clone())
+
+        result = vmap(call, (None, 0))(td, params_expanded)
+
+        assert result.shape == torch.Size([4, 2])
+        assert "output" in result.keys()
+        assert result["output"].shape == torch.Size([4, 2, 2])
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_nested_tdmodule_param_length(self):
+        """Test nested TDModules (ProbabilisticTensorDictModule) param length."""
+        torch.manual_seed(0)
+
+        net = nn.Sequential(nn.Linear(3, 4), NormalParamExtractor())
+        base_module = TensorDictModule(
+            module=net,
+            in_keys=["input"],
+            out_keys=["loc", "scale"],
+        )
+
+        prob_module = SafeProbabilisticModule(
+            in_keys=["loc", "scale"],
+            out_keys=["action"],
+            distribution_class=TanhNormal,
+        )
+
+        td_sequence = SafeProbabilisticTensorDictSequential(base_module, prob_module)
+
+        params = TensorDict.from_module(td_sequence)
+
+        assert len(params.keys()) > 0
+
+        for key in params.keys(True):
+            assert params.get(key) is not None
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_nested_tdmodule_param_casting(self):
+        """Test nested TDModules param casting."""
+        torch.manual_seed(0)
+
+        net = nn.Sequential(nn.Linear(3, 4), NormalParamExtractor())
+        base_module = TensorDictModule(
+            module=net,
+            in_keys=["input"],
+            out_keys=["loc", "scale"],
+        )
+
+        prob_module = SafeProbabilisticModule(
+            in_keys=["loc", "scale"],
+            out_keys=["action"],
+            distribution_class=TanhNormal,
+        )
+
+        td_sequence = SafeProbabilisticTensorDictSequential(base_module, prob_module)
+
+        params = TensorDict.from_module(td_sequence)
+
+        params_float64 = params.to(torch.float64)
+
+        for key in params_float64.keys(True):
+            assert params_float64.get(key).dtype == torch.float64
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_tdsequence_vmap_params(self):
+        """Test TDSequence param handling with vmap."""
+        torch.manual_seed(0)
+
+        td_module1 = SafeModule(
+            module=nn.Linear(3, 4),
+            in_keys=["input"],
+            out_keys=["hidden"],
+        )
+
+        td_module2 = SafeModule(
+            module=nn.Linear(4, 2),
+            in_keys=["hidden"],
+            out_keys=["output"],
+        )
+
+        td_sequence = SafeSequential(td_module1, td_module2)
+
+        params = TensorDict.from_module(td_sequence)
+
+        td = TensorDict({"input": torch.randn(2, 3)}, [2])
+        with params.to_module(td_sequence):
+            result = td_sequence(td.clone())
+
+        assert "output" in result.keys()
+        assert result["output"].shape == torch.Size([2, 2])
+
+        params_expanded = params.expand(4, *params.shape)
+
+        def call(td, params):
+            with params.to_module(td_sequence):
+                return td_sequence(td.clone())
+
+        result_vmap = vmap(call, (None, 0))(td, params_expanded)
+
+        assert result_vmap.shape == torch.Size([4, 2])
+        assert "output" in result_vmap.keys()
+        assert result_vmap["output"].shape == torch.Size([4, 2, 2])
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_vmap_multiple_inputs(self):
+        """Test vmap with multiple inputs to module."""
+        torch.manual_seed(0)
+
+        class MultiInputModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(3, 4)
+                self.fc2 = nn.Linear(4, 4)
+
+            def forward(self, x, y):
+                return self.fc1(x) + self.fc2(y)
+
+        net = MultiInputModule()
+        td_module = TensorDictModule(
+            module=net,
+            in_keys=["x", "y"],
+            out_keys=["output"],
+        )
+
+        params = TensorDict.from_module(td_module)
+        params_expanded = params.expand(3, *params.shape)
+
+        td = TensorDict(
+            {
+                "x": torch.randn(2, 3),
+                "y": torch.randn(2, 4),
+            },
+            [2],
+        )
+
+        def call(td, params):
+            with params.to_module(td_module):
+                return td_module(td.clone())
+
+        result = vmap(call, (None, 0))(td, params_expanded)
+
+        assert result.shape == torch.Size([3, 2])
+        assert "output" in result.keys()
+        assert result["output"].shape == torch.Size([3, 2, 4])
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_vmap_module_class(self):
+        """Test VmapModule class."""
+        torch.manual_seed(0)
+
+        td_module = TensorDictModule(
+            module=nn.Linear(3, 4),
+            in_keys=["input"],
+            out_keys=["output"],
+        )
+
+        vmapped_module = VmapModule(td_module, vmap_dim=0)
+
+        td = TensorDict({"input": torch.randn(3, 3)}, [3])
+        result = vmapped_module(td)
+
+        assert result["output"].shape == torch.Size([3, 4])
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_tdsequence_vmap(self):
+        """Test vmap on TDSequence."""
+        torch.manual_seed(0)
+
+        td_module1 = SafeModule(
+            module=nn.Linear(3, 4),
+            in_keys=["input"],
+            out_keys=["hidden"],
+        )
+
+        td_module2 = SafeModule(
+            module=nn.Linear(4, 2),
+            in_keys=["hidden"],
+            out_keys=["output"],
+        )
+
+        td_sequence = SafeSequential(td_module1, td_module2)
+
+        params = TensorDict.from_module(td_sequence)
+        params_expanded = params.expand(4, *params.shape)
+
+        td = TensorDict({"input": torch.randn(2, 3)}, [2])
+
+        def call(td, params):
+            with params.to_module(td_sequence):
+                return td_sequence(td.clone())
+
+        result = vmap(call, (None, 0))(td, params_expanded)
+
+        assert result.shape == torch.Size([4, 2])
+        assert "output" in result.keys()
+        assert result["output"].shape == torch.Size([4, 2, 2])
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_tdmodule_functional_to_module(self):
+        """Test TDModule functional params with to_module context manager."""
+        torch.manual_seed(0)
+
+        net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
+        td_module = TensorDictModule(
+            module=net,
+            in_keys=["input"],
+            out_keys=["output"],
+        )
+
+        params = TensorDict.from_module(td_module)
+
+        td = TensorDict({"input": torch.randn(3, 3)}, [3])
+        with params.to_module(td_module):
+            result = td_module(td.clone())
+
+        assert "output" in result.keys()
+        assert result["output"].shape == torch.Size([3, 2])
+
+    @pytest.mark.skipif(not _has_functorch, reason="functorch is required")
+    def test_nested_sequential_vmap(self):
+        """Test vmap on nested sequential modules."""
+        torch.manual_seed(0)
+
+        net1 = nn.Linear(3, 4)
+        net2 = nn.Linear(4, 2)
+        net3 = nn.Linear(2, 1)
+
+        td_seq1 = TensorDictSequential(
+            TensorDictModule(net1, in_keys=["a"], out_keys=["b"]),
+            TensorDictModule(net2, in_keys=["b"], out_keys=["c"]),
+        )
+
+        td_seq2 = TensorDictSequential(
+            TensorDictModule(net3, in_keys=["c"], out_keys=["d"]),
+        )
+
+        td_sequence = TensorDictSequential(td_seq1, td_seq2)
+
+        params = TensorDict.from_module(td_sequence)
+        params_expanded = params.expand(2, *params.shape)
+
+        td = TensorDict({"a": torch.randn(3, 3)}, [3])
+
+        def call(td, params):
+            with params.to_module(td_sequence):
+                return td_sequence(td.clone())
+
+        result = vmap(call, (None, 0))(td, params_expanded)
+
+        assert result.shape == torch.Size([2, 3])
+        assert "d" in result.keys()
+
+
+def test_safe_specs():
     out_key = ("a", "b")
     spec = Composite(Composite({out_key: Unbounded()}))
     original_spec = spec.clone()
@@ -1645,7 +1935,6 @@ def test_batched_actor_exceptions(self):
 
     @pytest.mark.parametrize("time_steps", [3, 5])
     def test_batched_actor_simple(self, time_steps):
-
         batch = 2
 
         # The second env has frequent resets, the first none
@@ -1686,7 +1975,6 @@ def test_batched_actor_simple(self, time_steps):
 
 
 def test_get_primers_from_module():
-
     # No primers in the model
     module = MLP(in_features=10, out_features=10, num_cells=[])
     transform = get_primers_from_module(module)