-
Notifications
You must be signed in to change notification settings - Fork 6.6k
perf: optimize CausalConv3d for wan autoencoders #12800
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
c8ef
wants to merge
2
commits into
huggingface:main
Choose a base branch
from
c8ef:wan
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
+10
−6
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Contributor
Author
|
testing script: import torch
import torch.nn as nn
import torch.nn.functional as F
import triton.testing
class CausalConv3d_A(nn.Conv3d):
"""
Implementation A: Fully explicit padding using F.pad
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (
self.padding[2],
self.padding[2],
self.padding[1],
self.padding[1],
2 * self.padding[0],
0,
)
self.padding = (0, 0, 0) # Reset internal padding to 0
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class CausalConv3d_B(nn.Conv3d):
"""
Implementation B: Explicit Temporal padding, Implicit Spatial padding
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.temporal_padding = 2 * self.padding[0]
# Keep spatial padding, remove temporal padding from conv layer
self.padding = (0, self.padding[1], self.padding[2])
def forward(self, x, cache_x=None):
b, c, t, h, w = x.size()
padding = self.temporal_padding
if cache_x is not None and self.temporal_padding > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding -= cache_x.shape[2]
# Manually pad time dimension
if padding > 0:
x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
return super().forward(x)
def setup_models(in_channels, out_channels, kernel_size, padding):
device = "cuda" if torch.cuda.is_available() else "cpu"
model_a = CausalConv3d_A(in_channels, out_channels, kernel_size, padding=padding).to(device)
model_b = CausalConv3d_B(in_channels, out_channels, kernel_size, padding=padding).to(device)
model_b.load_state_dict(model_a.state_dict())
return model_a, model_b
def test_correctness():
print("\n=== Running Correctness Test ===")
B, C, T, H, W = 2, 32, 16, 64, 64
out_C = 64
kernel = 3
pad_val = 1 # resulting in causal pad of 2*1=2
model_a, model_b = setup_models(C, out_C, kernel, padding=(pad_val, pad_val, pad_val))
model_a.eval()
model_b.eval()
x = torch.randn(B, C, T, H, W, device="cuda")
# 1. Test without cache
with torch.no_grad():
out_a = model_a(x)
out_b = model_b(x)
try:
torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5)
print("[Pass] Outputs are numerically identical (No Cache).")
except AssertionError as e:
print("[Fail] Outputs differ!")
print(e)
return
cache = torch.randn(B, C, 2, H, W, device="cuda")
with torch.no_grad():
out_a_cache = model_a(x, cache_x=cache)
out_b_cache = model_b(x, cache_x=cache)
try:
torch.testing.assert_close(out_a_cache, out_b_cache, rtol=1e-5, atol=1e-5)
print("[Pass] Outputs are numerically identical (With Cache).")
except AssertionError as e:
print("[Fail] Outputs differ with cache!")
def benchmark_performance():
print("\n=== Running Performance Benchmark ===")
if not torch.cuda.is_available():
print("Skipping benchmark (CUDA not available)")
return
B, C, T, H, W = 4, 64, 32, 128, 128
out_C = 64
kernel = 3
# Padding set to (1,1,1), so T gets padded by 2, H/W by 1
model_a, model_b = setup_models(C, out_C, kernel, padding=(1, 1, 1))
x = torch.randn(B, C, T, H, W, device="cuda")
def run_a():
return model_a(x)
def run_b():
return model_b(x)
ms_a = triton.testing.do_bench(run_a, rep=100)
ms_b = triton.testing.do_bench(run_b, rep=100)
print(f"Implementation A (F.pad): {ms_a:.3f} ms")
print(f"Implementation B (Impl.H/W): {ms_b:.3f} ms")
diff = (ms_a - ms_b) / ms_a * 100
print(f"Implementation B is {diff:.2f}% faster")
if __name__ == "__main__":
test_correctness()
benchmark_performance()result: |
Contributor
Author
|
@sayakpaul @yiyixuxu @DN6 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
What does this PR do?
By optimizing CausalConv3d, this patch improves the overall performance of wan autoencoders by 5-10%.
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.