diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 761dff2dc61a..dae0750e9a85 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -162,16 +162,20 @@ def __init__(
         )
 
         # Set up causal padding
-        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
-        self.padding = (0, 0, 0)
+        self.temporal_padding = 2 * self.padding[0]
+        # Keep spatial padding, remove temporal padding from conv layer
+        self.padding = (0, self.padding[1], self.padding[2])
 
     def forward(self, x, cache_x=None):
-        padding = list(self._padding)
-        if cache_x is not None and self._padding[4] > 0:
+        b, c, _, h, w = x.size()
+        padding = self.temporal_padding
+        if cache_x is not None and self.temporal_padding > 0:
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
-            padding[4] -= cache_x.shape[2]
-        x = F.pad(x, padding)
+            padding -= cache_x.shape[2]
+        # Manually pad time dimension
+        if padding > 0:
+            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
         return super().forward(x)
 
 