System Info
torch 2.8.0
transformers==4.56.2 or transformers==4.57.3, both tested
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Please see the repro code below. In this reduced test case it runs out of VRAM within only a few steps; in the real use case it takes about 50 iterations. I assume both share the same root cause, but I'm not entirely sure.
import threading

import torch
from diffusers import ZImagePipeline

pipe = ZImagePipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    torch_dtype=torch.bfloat16,
)
text_encoder = pipe.text_encoder  # a Qwen3ForCausalLM under the hood
text_encoder.to('cuda')

def run():
    # Dummy prompt: 512 token slots, of which only the first 200 are attended to.
    tokens = torch.zeros((1, 512), device='cuda', dtype=torch.int64)
    tokens_attention_mask = torch.ones((1, 512), device='cuda')
    tokens_attention_mask[:, 200:] = 0.0
    i = 0
    while True:
        i += 1
        print(i)
        text_encoder_output = text_encoder(
            tokens,
            attention_mask=tokens_attention_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=False,
        )

thread1 = threading.Thread(target=run)
thread2 = threading.Thread(target=run)  # <--- comment this to see it working with no issues
thread1.start()
thread2.start()  # <--- comment this to see it working with no issues
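For reference, a small diagnostic helper I'd use to watch the allocator grow (a sketch; log_vram is my own name, and calling it once per loop iteration inside run() is an assumption about where the growth shows up):

import torch

def log_vram(step):
    # torch.cuda.memory_allocated: bytes currently held by live tensors
    # torch.cuda.memory_reserved: bytes held by the caching allocator
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"step {step}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")

With a single thread the allocated number should stay flat after the first iteration; with two threads it should climb every iteration until the OOM, if the leak is where I think it is.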
Expected behavior
No out-of-memory error: VRAM usage should stay flat across iterations even when two threads call the text encoder concurrently, just as it does with a single thread.