From 411cfee61955ef5eedbdd5afceb07e84e0f5f064 Mon Sep 17 00:00:00 2001
From: c8ef
Date: Sat, 6 Dec 2025 15:16:50 +0800
Subject: [PATCH 1/2] perf: optimize CausalConv3d for wan autoencoders

---
 src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 761dff2dc61a..289a4c728ab7 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -163,15 +163,19 @@ def __init__(
 
         # Set up causal padding
         self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
-        self.padding = (0, 0, 0)
+        # Keep spatial padding, remove temporal padding from conv layer
+        self.padding = (0, self.padding[1], self.padding[2])
 
     def forward(self, x, cache_x=None):
+        b, c, _, h, w = x.size()
         padding = list(self._padding)
         if cache_x is not None and self._padding[4] > 0:
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
             padding[4] -= cache_x.shape[2]
-        x = F.pad(x, padding)
+        # Manually pad time dimension
+        if padding[4] > 0:
+            x = torch.cat([x.new_zeros(b, c, padding[4], h, w), x], dim=2)
 
         return super().forward(x)
 

From a9cdb5f9009721459e642e75ebae8140a6fc880d Mon Sep 17 00:00:00 2001
From: c8ef
Date: Sat, 6 Dec 2025 19:06:06 +0800
Subject: [PATCH 2/2] revise

---
 .../models/autoencoders/autoencoder_kl_wan.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 289a4c728ab7..dae0750e9a85 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -162,20 +162,20 @@ def __init__(
         )
 
         # Set up causal padding
-        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+        self.temporal_padding = 2 * self.padding[0]
         # Keep spatial padding, remove temporal padding from conv layer
         self.padding = (0, self.padding[1], self.padding[2])
 
     def forward(self, x, cache_x=None):
         b, c, _, h, w = x.size()
-        padding = list(self._padding)
-        if cache_x is not None and self._padding[4] > 0:
+        padding = self.temporal_padding
+        if cache_x is not None and self.temporal_padding > 0:
             cache_x = cache_x.to(x.device)
             x = torch.cat([cache_x, x], dim=2)
-            padding[4] -= cache_x.shape[2]
+            padding -= cache_x.shape[2]
         # Manually pad time dimension
-        if padding[4] > 0:
-            x = torch.cat([x.new_zeros(b, c, padding[4], h, w), x], dim=2)
+        if padding > 0:
+            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
 
         return super().forward(x)
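
Note (not part of the patch): the change replaces the single F.pad call with Conv3d's built-in spatial zero padding plus an explicit zero-concat along the time axis, so only the causal temporal padding is materialized by hand. Below is a minimal standalone sketch of how the two paths can be checked for equivalence; the channel count, kernel size, and padding values are illustrative assumptions, not taken from the Wan model config.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative (assumed) sizes: 4 channels, 3x3x3 kernel, padding of 1 per dim.
t_pad, h_pad, w_pad = 1, 1, 1
conv = torch.nn.Conv3d(4, 4, kernel_size=3, padding=0)
x = torch.randn(2, 4, 5, 8, 8)  # (batch, channels, time, height, width)

# Old path: F.pad pads W/H on both sides and T on the front only (causal);
# the conv itself does no padding.
old = conv(F.pad(x, (w_pad, w_pad, h_pad, h_pad, 2 * t_pad, 0)))

# New path: let Conv3d handle the spatial zero padding, and prepend zeros
# along the time dimension (dim=2) by hand, as in the patched forward().
conv.padding = (0, h_pad, w_pad)
b, c, _, h, w = x.size()
new = conv(torch.cat([x.new_zeros(b, c, 2 * t_pad, h, w), x], dim=2))

print(torch.allclose(old, new, atol=1e-6))  # expected: True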