|
18 | 18 |
|
19 | 19 | # from torchrl.modules.tensordict_module.rnn import GRUCell |
20 | 20 | from torch.nn import GRUCell |
21 | | -from torchrl._utils import timeit |
| 21 | +from torchrl._utils import _maybe_record_function_decorator |
22 | 22 |
|
23 | 23 | from torchrl.modules.models.models import MLP |
24 | 24 |
|
25 | 25 | UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11") |
26 | 26 |
|
27 | 27 |
|
| 28 | +class _Contiguous(nn.Module): |
| 29 | + """Helper module that makes a tensor contiguous. |
| 30 | +
|
| 31 | + This is useful inside nn.Sequential for torch.compile inductor compatibility. |
| 32 | + Inductor sometimes needs explicit contiguous() calls after reshape operations |
| 33 | + for efficient memory layout. |
| 34 | + """ |
| 35 | + |
| 36 | + def forward(self, x): |
| 37 | + return x.contiguous() |
| 38 | + |
| 39 | + |
28 | 40 | class DreamerActor(nn.Module): |
29 | 41 | """Dreamer actor network. |
30 | 42 |
|
@@ -128,16 +140,16 @@ def __init__(self, channels=32, num_layers=4, in_channels=None, depth=None, devi |
128 | 140 | k = k * 2 |
129 | 141 | self.encoder = nn.Sequential(*layers) |
130 | 142 |
|
| 143 | + @_maybe_record_function_decorator("ObsEncoder.forward") |
def forward(self, observation):
    """Encode an image observation into a flat latent vector.

    Args:
        observation: tensor of shape ``(*batch_sizes, C, H, W)``; any
            number of leading batch dimensions (including none) is
            accepted.

    Returns:
        Tensor of shape ``(*batch_sizes, latent_dim)`` where
        ``latent_dim`` is the flattened size of the encoder's output
        feature map.
    """
    *batch_sizes, C, H, W = observation.shape
    # Collapse all leading batch dims into one so the conv encoder sees a
    # standard (N, C, H, W) input; flatten(0, 0) is a no-op for the
    # unbatched (C, H, W) case. contiguous() is kept for torch.compile
    # inductor compatibility.
    observation = observation.flatten(
        0, len(batch_sizes) - 1 if batch_sizes else 0
    ).contiguous()
    obs_encoded = self.encoder(observation)
    # A single reshape both restores the batch dims and flattens the
    # (C', H', W') feature map; the previous intermediate unflatten was
    # redundant since reshape(*batch_sizes, -1) already produces the
    # final layout.
    return obs_encoded.reshape(*batch_sizes, -1).contiguous()
141 | 153 |
|
142 | 154 |
|
143 | 155 | class ObsDecoder(nn.Module): |
@@ -213,14 +225,18 @@ def __init__(self, channels=32, num_layers=4, kernel_sizes=None, latent_dim=None |
213 | 225 | self.decoder = nn.Sequential(*layers) |
214 | 226 | self._depth = channels |
215 | 227 |
|
| 228 | + @_maybe_record_function_decorator("ObsDecoder.forward") |
def forward(self, state, rnn_hidden):
    """Decode a latent state (plus RNN hidden) back into an image.

    Args:
        state: tensor of shape ``(*batch_sizes, state_dim)``.
        rnn_hidden: tensor of shape ``(*batch_sizes, hidden_dim)``;
            concatenated with ``state`` along the last dim.

    Returns:
        Decoded observation of shape ``(*batch_sizes, C, H, W)``.
    """
    # Concatenate and project to the conv-transpose input space.
    latent = self.state_to_latent(torch.cat([state, rnn_hidden], dim=-1))
    *batch_sizes, D = latent.shape
    # Collapse all batch dims into one and add the two spatial dims the
    # deconv stack expects. Using reshape(-1, D, 1, 1) also covers the
    # unbatched case by inserting a singleton batch dim — the previous
    # flatten/unsqueeze variant produced a 3-D (D, 1, 1) tensor there,
    # which made the 4-way shape unpack below fail. contiguous() is kept
    # for torch.compile inductor compatibility.
    latent = latent.reshape(-1, D, 1, 1).contiguous()
    obs_decoded = self.decoder(latent)
    _, C, H, W = obs_decoded.shape
    # Restore the original batch dims (drops the singleton batch dim in
    # the unbatched case).
    return obs_decoded.reshape(*batch_sizes, C, H, W).contiguous()
224 | 240 |
|
225 | 241 |
|
226 | 242 | class RSSMRollout(TensorDictModuleBase): |
|
0 commit comments