Skip to content

Commit ca2ca88

Browse files
committed
[Feature] Add _Contiguous module and reshape improvements to encoders/decoders
- Add _Contiguous helper module for torch.compile inductor compatibility - Refactor ObsEncoder.forward and ObsDecoder.forward to use flatten/unflatten with contiguous() - Add _maybe_record_function_decorator for profiling ghstack-source-id: 86ef6a9 Pull-Request: #3306
1 parent 5a09289 commit ca2ca88

1 file changed

Lines changed: 27 additions & 11 deletions

File tree

torchrl/modules/models/model_based.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,25 @@
1818

1919
# from torchrl.modules.tensordict_module.rnn import GRUCell
2020
from torch.nn import GRUCell
21-
from torchrl._utils import timeit
21+
from torchrl._utils import _maybe_record_function_decorator
2222

2323
from torchrl.modules.models.models import MLP
2424

2525
UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11")
2626

2727

28+
class _Contiguous(nn.Module):
29+
"""Helper module that makes a tensor contiguous.
30+
31+
This is useful inside nn.Sequential for torch.compile inductor compatibility.
32+
Inductor sometimes needs explicit contiguous() calls after reshape operations
33+
for efficient memory layout.
34+
"""
35+
36+
def forward(self, x):
37+
return x.contiguous()
38+
39+
2840
class DreamerActor(nn.Module):
2941
"""Dreamer actor network.
3042
@@ -128,16 +140,16 @@ def __init__(self, channels=32, num_layers=4, in_channels=None, depth=None, devi
128140
k = k * 2
129141
self.encoder = nn.Sequential(*layers)
130142

143+
@_maybe_record_function_decorator("ObsEncoder.forward")
def forward(self, observation):
    """Encode an image observation into a flat latent vector.

    Args:
        observation: tensor of shape ``(*batch, C, H, W)``; ``*batch`` may be
            empty (a single unbatched image) or span several dimensions.

    Returns:
        Tensor of shape ``(*batch, latent_dim)`` holding the flattened
        encoder features.
    """
    *batch_sizes, C, H, W = observation.shape
    # Collapse all leading batch dims into one so the conv stack sees
    # (N, C, H, W); flatten(0, 0) is a no-op for unbatched input.
    # contiguous() is kept for torch.compile inductor compatibility.
    flat_obs = observation.flatten(0, len(batch_sizes) - 1 if batch_sizes else 0).contiguous()
    obs_encoded = self.encoder(flat_obs)
    # A single reshape both restores the original batch dims and flattens the
    # conv features — the previous intermediate unflatten(0, batch_sizes) was
    # redundant work that this reshape immediately undid.
    return obs_encoded.reshape(*batch_sizes, -1).contiguous()
141153

142154

143155
class ObsDecoder(nn.Module):
@@ -213,14 +225,18 @@ def __init__(self, channels=32, num_layers=4, kernel_sizes=None, latent_dim=None
213225
self.decoder = nn.Sequential(*layers)
214226
self._depth = channels
215227

228+
@_maybe_record_function_decorator("ObsDecoder.forward")
def forward(self, state, rnn_hidden):
    """Decode state + RNN hidden into an image observation.

    Args:
        state: tensor of shape ``(*batch, state_dim)``.
        rnn_hidden: tensor of shape ``(*batch, hidden_dim)``.

    Returns:
        Tensor of shape ``(*batch, C, H, W)``; unbatched inputs yield an
        unbatched ``(C, H, W)`` image.
    """
    # Concatenate and project to the latent space.
    latent = self.state_to_latent(torch.cat([state, rnn_hidden], dim=-1))
    *batch_sizes, D = latent.shape
    # Collapse all leading batch dims — inserting a singleton batch when there
    # are none — so the deconv stack always receives a 4-D (N, D, 1, 1) input.
    # (flatten(0, 0) alone would leave unbatched input 3-D and break the
    # 4-way shape unpack below.)  contiguous() kept for inductor compatibility.
    latent = latent.reshape(-1, D, 1, 1).contiguous()
    obs_decoded = self.decoder(latent)
    _, C, H, W = obs_decoded.shape
    # Restore the original batch dims; this also drops the singleton batch
    # for unbatched input, matching the encoder's convention.
    return obs_decoded.reshape(*batch_sizes, C, H, W).contiguous()
224240

225241

226242
class RSSMRollout(TensorDictModuleBase):

0 commit comments

Comments
 (0)