16 changes: 10 additions & 6 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -162,16 +162,20 @@ def __init__(
         )

         # Set up causal padding
-        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
-        self.padding = (0, 0, 0)
+        self.temporal_padding = 2 * self.padding[0]
+        # Keep spatial padding, remove temporal padding from conv layer
+        self.padding = (0, self.padding[1], self.padding[2])
Comment on lines +165 to +167

Member
This bit feels a little confusing to me, TBH.

We're assigning a scalar to temporal_padding and then a 3-member tuple to padding. I would have expected temporal_padding to be a 3-member tuple, unless I am missing something obvious.

Perhaps you could help me understand this?

Contributor Author

Although the standard Conv module includes a padding attribute, it does not support causal padding in the temporal dimension. Previously, we removed all of the module's internal padding and relied entirely on F.pad. In this implementation, we apply manual padding only to the temporal dimension when necessary, while retaining the module's native padding for the spatial (H/W) dimensions. This could be faster because the conv's internal padding may benefit from better-optimized code paths.
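To illustrate the scalar-vs-tuple split, here is a minimal standalone sketch of the idea described above (not the exact diffusers code; the names and shapes are illustrative). Causal padding only ever prepends frames on the past side of the time axis, so a single count is enough there, while H/W keep the usual symmetric tuple handled by the conv itself:

import torch
import torch.nn as nn

conv = nn.Conv3d(4, 4, kernel_size=3, padding=1)        # conv.padding is normalized to (1, 1, 1)

temporal_padding = 2 * conv.padding[0]                  # scalar: zeros to prepend on the time axis
conv.padding = (0, conv.padding[1], conv.padding[2])    # tuple: conv keeps only its symmetric H/W padding

x = torch.randn(1, 4, 5, 16, 16)                        # (B, C, T, H, W)
x = torch.cat([x.new_zeros(1, 4, temporal_padding, 16, 16), x], dim=2)  # pad past frames manually
y = conv(x)
print(y.shape)                                          # torch.Size([1, 4, 5, 16, 16]): causal, same T/H/W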


     def forward(self, x, cache_x=None):
-        padding = list(self._padding)
-        if cache_x is not None and self._padding[4] > 0:
+        b, c, _, h, w = x.size()
+        padding = self.temporal_padding
+        if cache_x is not None and self.temporal_padding > 0:
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
-            padding[4] -= cache_x.shape[2]
-        x = F.pad(x, padding)
+            padding -= cache_x.shape[2]
+        # Manually pad time dimension
+        if padding > 0:
+            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
         return super().forward(x)
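As a quick sanity check of the refactor (a hypothetical standalone snippet, not part of the PR), the new path of concatenating zeros on the time axis while leaving spatial padding to the conv should produce the same padded tensor as the previous single F.pad call when no cache is passed:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 5, 8, 8)           # (B, C, T, H, W)
pad = (1, 1, 1)                          # the conv's original (T, H, W) padding

# Old path: one F.pad call pads W and H symmetrically and T on the past side only.
old = F.pad(x, (pad[2], pad[2], pad[1], pad[1], 2 * pad[0], 0))

# New path: temporal zeros are prepended manually; spatial padding stays inside the conv
# (emulated here with a spatial-only F.pad so the two tensors are comparable).
new = torch.cat([x.new_zeros(1, 4, 2 * pad[0], 8, 8), x], dim=2)
new = F.pad(new, (pad[2], pad[2], pad[1], pad[1], 0, 0))

assert torch.equal(old, new)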

