From ee54eacf0e6aa9331d6d0dae4e2f550dafe7716f Mon Sep 17 00:00:00 2001 From: WaterKnight1998 Date: Wed, 3 Dec 2025 20:33:30 +0100 Subject: [PATCH 1/2] Reimplement img2seq & seq2img in PRX to enable ONNX build without Col2Im (incompatible with TensorRT). --- .../models/transformers/transformer_prx.py | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index ccbc83ffca03..f39cfb106974 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -532,7 +532,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor: Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // patch_size)` is the number of patches. """ - return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) + b, c, h, w = img.shape + p = patch_size + + # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions + img = img.reshape(b, c, h // p, p, w // p, p) + + # Permute to (B, H//p, W//p, C, p, p) using einsum + # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width + img = torch.einsum("nchpwq->nhwcpq", img) + + # Flatten to (B, L, C * p * p) + img = img.reshape(b, -1, c * p * p) + return img def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor: @@ -554,12 +566,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te Reconstructed image tensor of shape `(B, C, H, W)`. """ if isinstance(shape, tuple): - shape = shape[-2:] + h, w = shape[-2:] elif isinstance(shape, torch.Tensor): - shape = (int(shape[0]), int(shape[1])) + h, w = (int(shape[0]), int(shape[1])) else: raise NotImplementedError(f"shape type {type(shape)} not supported") - return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) + + b, l, d = seq.shape + p = patch_size + c = d // (p * p) + + # Reshape back to grid structure: (B, H//p, W//p, C, p, p) + seq = seq.reshape(b, h // p, w // p, c, p, p) + + # Permute back to image layout: (B, C, H//p, p, W//p, p) + # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width + seq = torch.einsum("nhwcpq->nchpwq", seq) + + # Final reshape to (B, C, H, W) + seq = seq.reshape(b, c, h, w) + return seq class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): From 7d4911c0f3341738eaf6c8f3492c3885590190ee Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 4 Dec 2025 01:47:19 +0000 Subject: [PATCH 2/2] Apply style fixes --- .../models/transformers/transformer_prx.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py index f39cfb106974..63bfb4ea7d87 100644 --- a/src/diffusers/models/transformers/transformer_prx.py +++ b/src/diffusers/models/transformers/transformer_prx.py @@ -16,7 +16,6 @@ import torch from torch import nn -from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging @@ -534,14 +533,14 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor: """ b, c, h, w = img.shape p = patch_size - + # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions img = img.reshape(b, c, h // p, p, w // p, p) - + # Permute to (B, H//p, W//p, C, p, p) using einsum # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width img = torch.einsum("nchpwq->nhwcpq", img) - + # Flatten to (B, L, C * p * p) img = img.reshape(b, -1, c * p * p) return img @@ -571,18 +570,18 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te h, w = (int(shape[0]), int(shape[1])) else: raise NotImplementedError(f"shape type {type(shape)} not supported") - + b, l, d = seq.shape p = patch_size c = d // (p * p) - + # Reshape back to grid structure: (B, H//p, W//p, C, p, p) seq = seq.reshape(b, h // p, w // p, c, p, p) - + # Permute back to image layout: (B, C, H//p, p, W//p, p) # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width seq = torch.einsum("nhwcpq->nchpwq", seq) - + # Final reshape to (B, C, H, W) seq = seq.reshape(b, c, h, w) return seq