+import functools
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
+import torch.nn as nn
from diffusers.models.attention_dispatch import dispatch_attention_fn
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
from diffusers.models.transformers.transformer_qwenimage import (
    QwenDoubleStreamAttnProcessor2_0,
-    QwenEmbedRope,
    QwenImageTransformer2DModel,
)
from diffusers.utils.constants import USE_PEFT_BACKEND

logger = logging.getLogger(__name__)

-def qeff_apply_rotary_emb_qwen(
-    x: torch.Tensor,
-    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-    use_real: bool = True,
-    use_real_unbind_dim: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def qeff_apply_rotary_emb_qwen(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary embeddings to the given query or key tensor 'x' using the precomputed cosine and sine
    frequency tensors 'freqs_cos' and 'freqs_sin'. The input tensors are
@@ -36,36 +32,156 @@ def qeff_apply_rotary_emb_qwen(
    Returns:
        torch.Tensor: The query or key tensor with rotary embeddings applied, cast back to the dtype of `x`.
    """
-    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
-    x_rotated_new = x_.unbind(-1)
-    x_real = x_rotated_new[0]
-    x_imag = x_rotated_new[1]
-    freqs_cis = freqs_cis.reshape(freqs_cis.shape[0], -1)
-    freqs_cis = freqs_cis.view(freqs_cis.shape[0], freqs_cis.shape[-1] // 2, 2)
+    x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
+    x1 = x_reshaped[..., 0]  # [B, S, H, D//2]
+    x2 = x_reshaped[..., 1]  # [B, S, H, D//2]
+
+    # Reshape for broadcasting: [S, D//2] -> [1, S, 1, D//2]
+    freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2)
+    freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2)
+
+    # Apply the rotation
+    x_out1 = x1 * freqs_cos - x2 * freqs_sin  # real part
+    x_out2 = x1 * freqs_sin + x2 * freqs_cos  # imaginary part
+
+    # Stack the rotated pairs and flatten back to the original head dimension
+    x_out = torch.stack([x_out1, x_out2], dim=-1)  # [B, S, H, D//2, 2]
+    x_out = x_out.flatten(-2)  # [B, S, H, D]
+    return x_out.type_as(x)
+
+
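A quick shape and invariance check for the helper above; this is an illustrative sketch with made-up sizes (it assumes the function above is in scope), not part of the commit:

# Hypothetical smoke test for qeff_apply_rotary_emb_qwen; all sizes are illustrative.
B, S, H, D = 1, 6, 2, 8
x = torch.randn(B, S, H, D)
angles = torch.outer(torch.arange(S, dtype=torch.float32), torch.ones(D // 2))
out = qeff_apply_rotary_emb_qwen(x, torch.cos(angles), torch.sin(angles))
assert out.shape == x.shape and out.dtype == x.dtype
# A rotation preserves the norm of each (x1, x2) pair, so per-token norms are unchanged.
assert torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5)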
+class QEffQwenEmbedRope(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.scale_rope = scale_rope
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+
+        # Store cos and sin separately instead of complex numbers
+        pos_freqs_list = [
+            self.rope_params(pos_index, self.axes_dim[0], self.theta),
+            self.rope_params(pos_index, self.axes_dim[1], self.theta),
+            self.rope_params(pos_index, self.axes_dim[2], self.theta),
+        ]
+        self.pos_freqs_cos = torch.cat([f[0] for f in pos_freqs_list], dim=1)
+        self.pos_freqs_sin = torch.cat([f[1] for f in pos_freqs_list], dim=1)
+
+        neg_freqs_list = [
+            self.rope_params(neg_index, self.axes_dim[0], self.theta),
+            self.rope_params(neg_index, self.axes_dim[1], self.theta),
+            self.rope_params(neg_index, self.axes_dim[2], self.theta),
+        ]
+        self.neg_freqs_cos = torch.cat([f[0] for f in neg_freqs_list], dim=1)
+        self.neg_freqs_sin = torch.cat([f[1] for f in neg_freqs_list], dim=1)
+
+        self.rope_cache = {}
+
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width, idx=0):
+        seq_lens = frame * height * width
+        freqs_pos_cos = self.pos_freqs_cos.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_pos_sin = self.pos_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg_cos = self.neg_freqs_cos.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg_sin = self.neg_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1)
+
+        # Frame dimension
+        freqs_frame_cos = freqs_pos_cos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        freqs_frame_sin = freqs_pos_sin[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+
+        if self.scale_rope:
+            freqs_height_cos = torch.cat(
+                [freqs_neg_cos[1][-(height - height // 2) :], freqs_pos_cos[1][: height // 2]], dim=0
+            )
+            freqs_height_sin = torch.cat(
+                [freqs_neg_sin[1][-(height - height // 2) :], freqs_pos_sin[1][: height // 2]], dim=0
+            )
+            freqs_height_cos = freqs_height_cos.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_height_sin = freqs_height_sin.view(1, height, 1, -1).expand(frame, height, width, -1)
+
+            freqs_width_cos = torch.cat(
+                [freqs_neg_cos[2][-(width - width // 2) :], freqs_pos_cos[2][: width // 2]], dim=0
+            )
+            freqs_width_sin = torch.cat(
+                [freqs_neg_sin[2][-(width - width // 2) :], freqs_pos_sin[2][: width // 2]], dim=0
+            )
+            freqs_width_cos = freqs_width_cos.view(1, 1, width, -1).expand(frame, height, width, -1)
+            freqs_width_sin = freqs_width_sin.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height_cos = freqs_pos_cos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_height_sin = freqs_pos_sin[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width_cos = freqs_pos_cos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+            freqs_width_sin = freqs_pos_sin[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs_cos = torch.cat([freqs_frame_cos, freqs_height_cos, freqs_width_cos], dim=-1).reshape(seq_lens, -1)
+        freqs_sin = torch.cat([freqs_frame_sin, freqs_height_sin, freqs_width_sin], dim=-1).reshape(seq_lens, -1)

-    freqs_cos = freqs_cis[..., 0].unsqueeze(1)  # real part
-    freqs_sin = freqs_cis[..., 1].unsqueeze(1)  # imag part
+        return freqs_cos.clone().contiguous(), freqs_sin.clone().contiguous()

-    rotated_real = x_real * freqs_cos - x_imag * freqs_sin
-    rotated_imag = x_real * freqs_sin + x_imag * freqs_cos
-    x_out = torch.stack((rotated_real, rotated_imag), dim=-1)
-    x_out = x_out.reshape(*x.shape)
-    return x_out
+    def forward(self, video_fhw, txt_seq_lens, device):
+        """
+        Args:
+            video_fhw: a (frame, height, width) tuple describing the video latent shape, or a list of such tuples
+            txt_seq_lens: a list of per-batch text sequence lengths
+        Returns:
+            Tuple of (vid_freqs_cos, vid_freqs_sin, txt_freqs_cos, txt_freqs_sin)
+        """
+        if self.pos_freqs_cos.device != device:
+            self.pos_freqs_cos = self.pos_freqs_cos.to(device)
+            self.pos_freqs_sin = self.pos_freqs_sin.to(device)
+            self.neg_freqs_cos = self.neg_freqs_cos.to(device)
+            self.neg_freqs_sin = self.neg_freqs_sin.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs_cos_list = []
+        vid_freqs_sin_list = []
+        max_vid_index = 0
+
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq_cos, video_freq_sin = self.rope_cache[rope_key]
+            else:
+                video_freq_cos, video_freq_sin = self._compute_video_freqs(frame, height, width, idx)

+            video_freq_cos = video_freq_cos.to(device)
+            video_freq_sin = video_freq_sin.to(device)
+            vid_freqs_cos_list.append(video_freq_cos)
+            vid_freqs_sin_list.append(video_freq_sin)
+
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
+
+        max_len = max(txt_seq_lens)
+        txt_freqs_cos = self.pos_freqs_cos[max_vid_index : max_vid_index + max_len, ...]
+        txt_freqs_sin = self.pos_freqs_sin[max_vid_index : max_vid_index + max_len, ...]
+
+        vid_freqs_cos = torch.cat(vid_freqs_cos_list, dim=0)
+        vid_freqs_sin = torch.cat(vid_freqs_sin_list, dim=0)
+
+        return vid_freqs_cos, vid_freqs_sin, txt_freqs_cos, txt_freqs_sin

-class QEffQwenEmbedRope(QwenEmbedRope):
    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
-
-        real_part = torch.ones_like(freqs) * torch.cos(freqs)
-        imag_part = torch.ones_like(freqs) * torch.sin(freqs)
-        freqs = torch.stack([real_part, imag_part], dim=-1)
-        return freqs  # [6032, 64, 2]
+        # Return cos and sin separately instead of a stacked real/imag tensor
+        freqs_cos = torch.cos(freqs)
+        freqs_sin = torch.sin(freqs)
+        return freqs_cos, freqs_sin


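A rough usage sketch of the embedding module above; the theta and axes_dim values are assumptions chosen for illustration, not taken from this commit:

# Hypothetical usage of QEffQwenEmbedRope; theta/axes_dim and all shapes are illustrative.
rope = QEffQwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
vid_cos, vid_sin, txt_cos, txt_sin = rope(
    video_fhw=[(1, 32, 32)],  # one latent frame of 32x32 patches
    txt_seq_lens=[77],  # one prompt of 77 tokens
    device=torch.device("cpu"),
)
assert vid_cos.shape[0] == 1 * 32 * 32  # one row of frequencies per video token
assert txt_cos.shape[0] == 77  # one row per text token, offset past the video positions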
class QEffQwenImageTransformer2DModel(QwenImageTransformer2DModel):
@@ -78,10 +194,11 @@ def forward(
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
+        frame: torch.Tensor = None,
+        height: torch.Tensor = None,
+        width: torch.Tensor = None,
+        txt_seq_lens: torch.Tensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
-        img_rotary_emb: Optional[torch.Tensor] = None,
-        text_rotary_emb: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
@@ -110,6 +227,22 @@ def forward(
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
        """
+        # Convert scalar tensors to Python integers and build the img_shapes list
+        if isinstance(frame, torch.Tensor):
+            frame = frame.item() if frame.numel() == 1 else int(frame[0])
+        if isinstance(height, torch.Tensor):
+            height = height.item() if height.numel() == 1 else int(height[0])
+        if isinstance(width, torch.Tensor):
+            width = width.item() if width.numel() == 1 else int(width[0])
+
+        if not img_shapes:
+            img_shapes = [(frame, height, width)]
+
+        # Convert txt_seq_lens to a Python list if it arrives as a tensor
+        if isinstance(txt_seq_lens, torch.Tensor):
+            txt_seq_lens = txt_seq_lens.tolist()
+
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -139,7 +272,6 @@ def forward(
            else self.time_text_embed(timestep, guidance, hidden_states)
        )
        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
-        # image_rotary_emb = (img_rotary_emb, text_rotary_emb)

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
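Since the patched forward accepts frame, height, width, and txt_seq_lens as tensors and normalizes them to Python scalars and lists internally, an export-friendly call can look roughly like this; a sketch with assumed shapes and an assumed model instance, not code from the commit:

# Hypothetical invocation of the patched forward; every shape here is illustrative.
out = model(
    hidden_states=torch.randn(1, 1024, 64),  # packed image latents
    encoder_hidden_states=torch.randn(1, 77, 3584),  # text embeddings
    encoder_hidden_states_mask=torch.ones(1, 77, dtype=torch.long),
    timestep=torch.tensor([1000]),
    frame=torch.tensor(1),  # scalar inputs may arrive as 0-d tensors
    height=torch.tensor(32),
    width=torch.tensor(32),
    txt_seq_lens=torch.tensor([77]),  # converted to a Python list internally
    return_dict=False,
)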
@@ -192,23 +324,23 @@ def __call__(
        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
-        img_query = attn.to_q(hidden_states) // 8
-        img_key = attn.to_k(hidden_states) // 8
-        img_value = attn.to_v(hidden_states) // 8
+        img_query = attn.to_q(hidden_states)
+        img_key = attn.to_k(hidden_states)
+        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
-        txt_query = attn.add_q_proj(encoder_hidden_states) // 8
-        txt_key = attn.add_k_proj(encoder_hidden_states) // 8
-        txt_value = attn.add_v_proj(encoder_hidden_states) // 8
+        txt_query = attn.add_q_proj(encoder_hidden_states)
+        txt_key = attn.add_k_proj(encoder_hidden_states)
+        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention
-        img_query = img_query.unflatten(-1, (attn.heads, -1)) // 2
-        img_key = img_key.unflatten(-1, (attn.heads, -1)) // 2
-        img_value = img_value.unflatten(-1, (attn.heads, -1)) // 2
+        img_query = img_query.unflatten(-1, (attn.heads, -1))
+        img_key = img_key.unflatten(-1, (attn.heads, -1))
+        img_value = img_value.unflatten(-1, (attn.heads, -1))

-        txt_query = txt_query.unflatten(-1, (attn.heads, -1)) // 2
-        txt_key = txt_key.unflatten(-1, (attn.heads, -1)) // 2
-        txt_value = txt_value.unflatten(-1, (attn.heads, -1)) // 2
+        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
@@ -222,11 +354,14 @@ def __call__(

        # Apply RoPE
        if image_rotary_emb is not None:
-            img_freqs, txt_freqs = image_rotary_emb
-            img_query = qeff_apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
-            img_key = qeff_apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
-            txt_query = qeff_apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
-            txt_key = qeff_apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+            # Unpack the four tensors (cos and sin for both the image and text streams)
+            img_freqs_cos, img_freqs_sin, txt_freqs_cos, txt_freqs_sin = image_rotary_emb
+
+            img_query = qeff_apply_rotary_emb_qwen(img_query, img_freqs_cos, img_freqs_sin)
+            img_key = qeff_apply_rotary_emb_qwen(img_key, img_freqs_cos, img_freqs_sin)
+            txt_query = qeff_apply_rotary_emb_qwen(txt_query, txt_freqs_cos, txt_freqs_sin)
+            txt_key = qeff_apply_rotary_emb_qwen(txt_key, txt_freqs_cos, txt_freqs_sin)

        # Concatenate for joint attention
        # Order: [text, image]
@@ -254,10 +389,10 @@ def __call__(
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
-        img_attn_output = attn.to_out[0](img_attn_output)  # scale back
+        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout

-        txt_attn_output = attn.to_add_out(txt_attn_output)  # scale back
+        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output
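The joint text+image attention layout the processor implements can be summarized in a small standalone sketch; dispatch_attention_fn, masks, RoPE, and QK normalization are omitted, and the function name is illustrative:

# Minimal sketch of the [text, image] joint attention pattern used above.
import torch
import torch.nn.functional as F

def joint_attention_sketch(txt_q, txt_k, txt_v, img_q, img_k, img_v):
    # Concatenate text first, then image, along the sequence dimension.
    q = torch.cat([txt_q, img_q], dim=1)  # [B, S_txt + S_img, H, D]
    k = torch.cat([txt_k, img_k], dim=1)
    v = torch.cat([txt_v, img_v], dim=1)
    # SDPA expects [B, H, S, D], so swap the sequence and head dimensions.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    seq_txt = txt_q.shape[1]
    return out[:, :seq_txt], out[:, seq_txt:]  # split back into text and image parts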