Conversation

c8ef (Contributor) commented Dec 6, 2025

What does this PR do?

By optimizing CausalConv3d, this patch improves the overall performance of the Wan autoencoders by 5-10%: spatial padding is left to nn.Conv3d itself instead of being materialized up front with F.pad, and only the causal temporal padding is applied by hand.
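
The two approaches compute the same result; the difference is that the conv backend applies its own spatial padding internally, without writing out a padded copy of the activation. A minimal sketch of the equivalence, using the functional API with illustrative shapes (the actual module-level change is in the testing script below):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 5, 8, 8)  # (B, C, T, H, W)
w = torch.randn(6, 4, 3, 3, 3)  # out_C=6, 3x3x3 kernel

# Current behavior: everything padded up front with F.pad
# (causally: two frames in front of the time dimension, none behind).
explicit = F.conv3d(F.pad(x, (1, 1, 1, 1, 2, 0)), w)

# Optimized: pad only time by hand, let conv3d pad H/W implicitly.
implicit = F.conv3d(F.pad(x, (0, 0, 0, 0, 2, 0)), w, padding=(0, 1, 1))

torch.testing.assert_close(explicit, implicit, rtol=1e-5, atol=1e-5)
```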

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

c8ef (Contributor, Author) commented Dec 6, 2025

Testing script:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import triton.testing


class CausalConv3d_A(nn.Conv3d):
    """
    Implementation A: Fully explicit padding using F.pad
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
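        # F.pad takes pads for the last dims first:
        # (W_left, W_right, H_top, H_bottom, T_front, T_back).
        # Causality: all 2 * pad_t frames of temporal padding go in front.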
        self._padding = (
            self.padding[2],
            self.padding[2],
            self.padding[1],
            self.padding[1],
            2 * self.padding[0],
            0,
        )
        self.padding = (0, 0, 0)  # Reset internal padding to 0

    def forward(self, x, cache_x=None):
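        # cache_x holds trailing frames from the previous chunk; concatenating
        # them in front replaces part of the causal padding during chunked decoding.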
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)


class CausalConv3d_B(nn.Conv3d):
    """
    Implementation B: Explicit Temporal padding, Implicit Spatial padding
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.temporal_padding = 2 * self.padding[0]
        # Keep spatial padding, remove temporal padding from conv layer
        self.padding = (0, self.padding[1], self.padding[2])
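        # The conv backend applies this spatial padding inside the conv op itself,
        # so no separately padded copy of the input is materialized.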

    def forward(self, x, cache_x=None):
        b, c, t, h, w = x.size()
        padding = self.temporal_padding
        if cache_x is not None and self.temporal_padding > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding -= cache_x.shape[2]

        # Manually pad time dimension
        if padding > 0:
            x = torch.cat([x.new_zeros(b, c, padding, h, w), x], dim=2)
        return super().forward(x)


def setup_models(in_channels, out_channels, kernel_size, padding):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_a = CausalConv3d_A(in_channels, out_channels, kernel_size, padding=padding).to(device)
    model_b = CausalConv3d_B(in_channels, out_channels, kernel_size, padding=padding).to(device)

    # Copy A's weights into B so both convs use identical parameters.
    model_b.load_state_dict(model_a.state_dict())

    return model_a, model_b


def test_correctness():
    print("\n=== Running Correctness Test ===")
    B, C, T, H, W = 2, 32, 16, 64, 64
    out_C = 64
    kernel = 3
    pad_val = 1  # resulting in causal pad of 2*1=2

    model_a, model_b = setup_models(C, out_C, kernel, padding=(pad_val, pad_val, pad_val))
    model_a.eval()
    model_b.eval()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(B, C, T, H, W, device=device)

    # 1. Test without cache
    with torch.no_grad():
        out_a = model_a(x)
        out_b = model_b(x)

    try:
        torch.testing.assert_close(out_a, out_b, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (No Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ!")
        print(e)
        return

    # 2. Test with a temporal cache (two frames from a previous chunk)
    cache = torch.randn(B, C, 2, H, W, device=device)
    with torch.no_grad():
        out_a_cache = model_a(x, cache_x=cache)
        out_b_cache = model_b(x, cache_x=cache)

    try:
        torch.testing.assert_close(out_a_cache, out_b_cache, rtol=1e-5, atol=1e-5)
        print("[Pass] Outputs are numerically identical (With Cache).")
    except AssertionError as e:
        print("[Fail] Outputs differ with cache!")
        print(e)


def benchmark_performance():
    print("\n=== Running Performance Benchmark ===")
    if not torch.cuda.is_available():
        print("Skipping benchmark (CUDA not available)")
        return

    B, C, T, H, W = 4, 64, 32, 128, 128
    out_C = 64
    kernel = 3
    # Padding set to (1,1,1), so T gets padded by 2, H/W by 1
    model_a, model_b = setup_models(C, out_C, kernel, padding=(1, 1, 1))

    x = torch.randn(B, C, T, H, W, device="cuda")

    def run_a():
        return model_a(x)

    def run_b():
        return model_b(x)

    ms_a = triton.testing.do_bench(run_a, rep=100)
    ms_b = triton.testing.do_bench(run_b, rep=100)

    print(f"Implementation A (F.pad):   {ms_a:.3f} ms")
    print(f"Implementation B (Impl.H/W): {ms_b:.3f} ms")

    diff = (ms_a - ms_b) / ms_a * 100
    print(f"Implementation B is {diff:.2f}% faster")


if __name__ == "__main__":
    test_correctness()
    benchmark_performance()
```

Result:

```
=== Running Correctness Test ===
[Pass] Outputs are numerically identical (No Cache).
[Pass] Outputs are numerically identical (With Cache).

=== Running Performance Benchmark ===
Implementation A (F.pad):   44.787 ms
Implementation B (implicit H/W): 42.507 ms
Implementation B is 5.09% faster
```
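
If Triton is not available, a rough CUDA-event timer can stand in for triton.testing.do_bench (a minimal sketch; the bench helper and its warmup/iters arguments are illustrative, not part of this PR):

```python
import torch

def bench(fn, warmup=10, iters=100):
    # Warm-up runs exclude one-time costs (cuDNN autotuning, allocator growth).
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call
```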

c8ef (Contributor, Author) commented Dec 6, 2025

@sayakpaul @yiyixuxu @DN6
Please take a look, thanks!

c8ef changed the title from "perf: optimize CasualConv3d for wan autoencoders" to "perf: optimize CausalConv3d for wan autoencoders" on Dec 6, 2025