Skip to content

Conversation

@WaterKnight1998
Copy link

@WaterKnight1998 WaterKnight1998 commented Dec 3, 2025

What does this PR do?

This PR reimplements the img2seq and seq2img utilities in the PRX model to enable successful ONNX export and TensorRT inference.
Previously, these utilities depended on torch.nn.functional.fold and torch.nn.functional.unfold. During ONNX export, this resulted in the use of the Col2Im ONNX operator, which is not supported by the TensorRT ONNX operator set, causing engine building to fail.

The new implementation removes the dependency on Col2Im while preserving the expected functionality and output shapes, ensuring compatibility across ONNX runtimes and TensorRT. The approach follows the unpatchify logic used elsewhere in the library, for example in dit_transformer_2d.py.

Fixes #12786

Code to help you verify

Export model to onnx

import torch
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel


# Load model
model = PRXTransformer2DModel.from_pretrained("Photoroom/prx-1024-t2i-beta", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")

batch_size = 2 # for guidance
device = torch.device("cuda")

# Prepare input

dummy_inputs = {
    "hidden_states": torch.randn(batch_size, model.config.in_channels, 128, 128, dtype=torch.bfloat16, device=device), # Input latent image tensor of shape `(B, C, H, W)`.
    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float32, device=device), #  Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
    "encoder_hidden_states": torch.randn(batch_size, 256, model.config.context_in_dim, dtype=torch.bfloat16, device=device), # Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
    "attention_mask": torch.randint(0, 2, (batch_size, 256), dtype=torch.int64, device=device), #  Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
}

output_names = ["latent"]

batch_dim = torch.export.Dim("batch_size", min=1, max=16) 

dynamic_shapes = {
    "hidden_states": {0: batch_dim},
    "timestep": {0: batch_dim},
    "encoder_hidden_states": {0: batch_dim},
    "attention_mask": {0: batch_dim} 
}

# Export it
torch.onnx.export(
    model,
    tuple(dummy_inputs.values()), 
    "prx.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_shapes=dynamic_shapes,
    opset_version=20,
    dynamo=True,
    do_constant_folding=False,
    verbose=True
)

Convert model to tensorrt

docker run --rm --gpus all -it \
    -v "$(pwd):/workspace" \
    -w /workspace \
    nvcr.io/nvidia/tensorrt:25.11-py3 \
    trtexec \
    --onnx=prx.onnx \
    --saveEngine=prx_engine.plan \
    --fp16 \
    --memPoolSize=workspace:16384 \
    --minShapes=hidden_states:2x16x128x128,timestep:2,encoder_hidden_states:2x256x2304,attention_mask:2x256 \
    --optShapes=hidden_states:2x16x128x128,timestep:2,encoder_hidden_states:2x256x2304,attention_mask:2x256 \
    --maxShapes=hidden_states:2x16x128x128,timestep:2,encoder_hidden_states:2x256x2304,attention_mask:2x256 \
    --verbose

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @DavidBert @qgallouedec

@sayakpaul
Copy link
Member

@DavidBert could you review this?

@WaterKnight1998 I am not sure why you tagged Quentin on this PR.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes seem safe for me to merge!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Dec 4, 2025

Style bot fixed some files and pushed the changes.

Copy link
Contributor

@DavidBert DavidBert left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[PRX] Improving export to Tensorrt

4 participants