
Commit ebaa4d7

Moved Rotary Embed inside
Signed-off-by: Dipankar Sarkar <quic_dipankar@quicinc.com>
1 parent c5d7fcd commit ebaa4d7

File tree: 7 files changed, 205 additions & 87 deletions

QEfficient/diffusers/models/attention_processor.py
Lines changed: 4 additions & 1 deletion

@@ -10,10 +10,13 @@
 import torch
 from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
 
+from QEfficient.diffusers.models.transformers.transformer_qwenimage import QEffQwenDoubleStreamAttnProcessor2_0
+
 
 class QEffAttention(Attention):
     def __qeff_init__(self):
-        processor = QEffJointAttnProcessor2_0()
+        # breakpoint()
+        processor = QEffQwenDoubleStreamAttnProcessor2_0()
         self.processor = processor
         processor.query_block_size = 64
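
For orientation, a minimal sketch (not part of this commit) of the state the patched __qeff_init__ hook is expected to leave behind; calling the hook directly is only illustrative, since QEfficient's pytorch transforms normally re-class the module and invoke it.

    from QEfficient.diffusers.models.attention_processor import QEffAttention

    def check_qeff_attention(attn: QEffAttention) -> None:
        # The hook installs the Qwen double-stream processor and sets query_block_size to 64.
        attn.__qeff_init__()
        assert type(attn.processor).__name__ == "QEffQwenDoubleStreamAttnProcessor2_0"
        assert attn.processor.query_block_size == 64
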
QEfficient/diffusers/models/pytorch_transforms.py
Lines changed: 2 additions & 4 deletions

@@ -17,7 +17,7 @@
     FluxTransformerBlock,
 )
 from diffusers.models.transformers.transformer_qwenimage import (
-    QwenEmbedRope,
+    QwenDoubleStreamAttnProcessor2_0,
     QwenImageTransformer2DModel,
 )
 from torch import nn
@@ -44,7 +44,6 @@
 )
 from QEfficient.diffusers.models.transformers.transformer_qwenimage import (
     QEffQwenDoubleStreamAttnProcessor2_0,
-    QEffQwenEmbedRope,
     QEffQwenImageTransformer2DModel,
 )
 
@@ -71,9 +70,8 @@ class AttentionTransform(ModuleMappingTransform):
         FluxTransformer2DModel: QEffFluxTransformer2DModel,
         FluxAttention: QEffFluxAttention,
         FluxAttnProcessor: QEffFluxAttnProcessor,
-        QwenEmbedRope: QEffQwenEmbedRope,
         QwenImageTransformer2DModel: QEffQwenImageTransformer2DModel,
-        QEffQwenDoubleStreamAttnProcessor2_0: QEffQwenDoubleStreamAttnProcessor2_0,
+        QwenDoubleStreamAttnProcessor2_0: QEffQwenDoubleStreamAttnProcessor2_0,
     }
 
     @classmethod
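
A rough sketch of how a class mapping like the one above is typically consumed; the apply loop below is an assumption about QEfficient's ModuleMappingTransform, not code from this commit. Note the processor entry is now keyed by the upstream QwenDoubleStreamAttnProcessor2_0 class (so lookups against a stock diffusers pipeline can match), and the separate QwenEmbedRope entry is gone because, per the commit title, the rotary embedding module is handled inside the QEff transformer itself.

    import torch.nn as nn

    def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
        # Re-class matching submodules in place so their weights are preserved,
        # then run the optional QEfficient init hook (e.g. QEffAttention.__qeff_init__).
        for module in model.modules():
            replacement = mapping.get(type(module))
            if replacement is not None:
                module.__class__ = replacement
                if hasattr(module, "__qeff_init__"):
                    module.__qeff_init__()
        return model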

QEfficient/diffusers/models/transformers/transformer_qwenimage.py
Lines changed: 184 additions & 49 deletions

@@ -1,13 +1,14 @@
+import functools
 import logging
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 from diffusers.models.attention_dispatch import dispatch_attention_fn
 from diffusers.models.attention_processor import Attention
 from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput
 from diffusers.models.transformers.transformer_qwenimage import (
     QwenDoubleStreamAttnProcessor2_0,
-    QwenEmbedRope,
     QwenImageTransformer2DModel,
 )
 from diffusers.utils.constants import USE_PEFT_BACKEND
@@ -16,12 +17,7 @@
 logger = logging.getLogger(__name__)
 
 
-def qeff_apply_rotary_emb_qwen(
-    x: torch.Tensor,
-    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-    use_real: bool = True,
-    use_real_unbind_dim: int = -1,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+def qeff_apply_rotary_emb_qwen(x, freqs_cos, freqs_sin):
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
@@ -36,36 +32,156 @@ def qeff_apply_rotary_emb_qwen(x, freqs_cos, freqs_sin):
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
-    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
-    x_rotated_new = x_.unbind(-1)
-    x_real = x_rotated_new[0]
-    x_imag = x_rotated_new[1]
-    freqs_cis = freqs_cis.reshape(freqs_cis.shape[0], -1)
-    freqs_cis = freqs_cis.view(freqs_cis.shape[0], freqs_cis.shape[-1] // 2, 2)
+    x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
+    x1 = x_reshaped[..., 0]  # [B, S, H, D//2]
+    x2 = x_reshaped[..., 1]  # [B, S, H, D//2]
+
+    # Reshape for broadcasting: [S, D//2] -> [1, S, 1, D//2]
+    freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2)
+    freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2)
+
+    # Apply rotation
+    x_out1 = x1 * freqs_cos - x2 * freqs_sin  # Real part
+    x_out2 = x1 * freqs_sin + x2 * freqs_cos  # Imaginary part
+
+    # Stack and reshape back
+    x_out = torch.stack([x_out1, x_out2], dim=-1)  # [B, S, H, D//2, 2]
+    x_out = x_out.flatten(-2)  # [B, S, H, D]
+    return x_out.type_as(x)
+
+
+class QEffQwenEmbedRope(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.scale_rope = scale_rope
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+
+        # Store cos and sin separately instead of complex numbers
+        pos_freqs_list = [
+            self.rope_params(pos_index, self.axes_dim[0], self.theta),
+            self.rope_params(pos_index, self.axes_dim[1], self.theta),
+            self.rope_params(pos_index, self.axes_dim[2], self.theta),
+        ]
+        self.pos_freqs_cos = torch.cat([f[0] for f in pos_freqs_list], dim=1)
+        self.pos_freqs_sin = torch.cat([f[1] for f in pos_freqs_list], dim=1)
+
+        neg_freqs_list = [
+            self.rope_params(neg_index, self.axes_dim[0], self.theta),
+            self.rope_params(neg_index, self.axes_dim[1], self.theta),
+            self.rope_params(neg_index, self.axes_dim[2], self.theta),
+        ]
+        self.neg_freqs_cos = torch.cat([f[0] for f in neg_freqs_list], dim=1)
+        self.neg_freqs_sin = torch.cat([f[1] for f in neg_freqs_list], dim=1)
+
+        self.rope_cache = {}
+
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width, idx=0):
+        seq_lens = frame * height * width
+        freqs_pos_cos = self.pos_freqs_cos.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_pos_sin = self.pos_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg_cos = self.neg_freqs_cos.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg_sin = self.neg_freqs_sin.split([x // 2 for x in self.axes_dim], dim=1)
+
+        # Frame dimension
+        freqs_frame_cos = freqs_pos_cos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        freqs_frame_sin = freqs_pos_sin[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+
+        if self.scale_rope:
+            freqs_height_cos = torch.cat(
+                [freqs_neg_cos[1][-(height - height // 2) :], freqs_pos_cos[1][: height // 2]], dim=0
+            )
+            freqs_height_sin = torch.cat(
+                [freqs_neg_sin[1][-(height - height // 2) :], freqs_pos_sin[1][: height // 2]], dim=0
+            )
+            freqs_height_cos = freqs_height_cos.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_height_sin = freqs_height_sin.view(1, height, 1, -1).expand(frame, height, width, -1)
+
+            freqs_width_cos = torch.cat(
+                [freqs_neg_cos[2][-(width - width // 2) :], freqs_pos_cos[2][: width // 2]], dim=0
+            )
+            freqs_width_sin = torch.cat(
+                [freqs_neg_sin[2][-(width - width // 2) :], freqs_pos_sin[2][: width // 2]], dim=0
+            )
+            freqs_width_cos = freqs_width_cos.view(1, 1, width, -1).expand(frame, height, width, -1)
+            freqs_width_sin = freqs_width_sin.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height_cos = freqs_pos_cos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_height_sin = freqs_pos_sin[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width_cos = freqs_pos_cos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+            freqs_width_sin = freqs_pos_sin[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs_cos = torch.cat([freqs_frame_cos, freqs_height_cos, freqs_width_cos], dim=-1).reshape(seq_lens, -1)
+        freqs_sin = torch.cat([freqs_frame_sin, freqs_height_sin, freqs_width_sin], dim=-1).reshape(seq_lens, -1)
 
-    freqs_cos = freqs_cis[..., 0].unsqueeze(1)  # real part
-    freqs_sin = freqs_cis[..., 1].unsqueeze(1)  # imag part
+        return freqs_cos.clone().contiguous(), freqs_sin.clone().contiguous()
 
-    rotated_real = x_real * freqs_cos - x_imag * freqs_sin
-    rotated_imag = x_real * freqs_sin + x_imag * freqs_cos
-    x_out = torch.stack((rotated_real, rotated_imag), dim=-1)
-    x_out = x_out.reshape(*x.shape)
-    return x_out
+    def forward(self, video_fhw, txt_seq_lens, device):
+        """
+        Args:
+            video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
+            txt_length: [bs] a list of 1 integers representing the length of the text
+        Returns:
+            Tuple of (vid_freqs_cos, vid_freqs_sin, txt_freqs_cos, txt_freqs_sin)
+        """
+        if self.pos_freqs_cos.device != device:
+            self.pos_freqs_cos = self.pos_freqs_cos.to(device)
+            self.pos_freqs_sin = self.pos_freqs_sin.to(device)
+            self.neg_freqs_cos = self.neg_freqs_cos.to(device)
+            self.neg_freqs_sin = self.neg_freqs_sin.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs_cos_list = []
+        vid_freqs_sin_list = []
+        max_vid_index = 0
+
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq_cos, video_freq_sin = self.rope_cache[rope_key]
+            else:
+                video_freq_cos, video_freq_sin = self._compute_video_freqs(frame, height, width, idx)
 
+            video_freq_cos = video_freq_cos.to(device)
+            video_freq_sin = video_freq_sin.to(device)
+            vid_freqs_cos_list.append(video_freq_cos)
+            vid_freqs_sin_list.append(video_freq_sin)
+
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
+
+        max_len = max(txt_seq_lens)
+        txt_freqs_cos = self.pos_freqs_cos[max_vid_index : max_vid_index + max_len, ...]
+        txt_freqs_sin = self.pos_freqs_sin[max_vid_index : max_vid_index + max_len, ...]
+
+        vid_freqs_cos = torch.cat(vid_freqs_cos_list, dim=0)
+        vid_freqs_sin = torch.cat(vid_freqs_sin_list, dim=0)
+
+        return vid_freqs_cos, vid_freqs_sin, txt_freqs_cos, txt_freqs_sin
 
-class QEffQwenEmbedRope(QwenEmbedRope):
     def rope_params(self, index, dim, theta=10000):
         """
         Args:
             index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
         """
         assert dim % 2 == 0
         freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
-
-        real_part = torch.ones_like(freqs) * torch.cos(freqs)
-        imag_part = torch.ones_like(freqs) * torch.sin(freqs)
-        freqs = torch.stack([real_part, imag_part], dim=-1)
-        return freqs  # [6032,64,2]
+        # Return cos and sin separately instead of complex tensor
+        freqs_cos = torch.cos(freqs)
+        freqs_sin = torch.sin(freqs)
+        return freqs_cos, freqs_sin
 
 
 class QEffQwenImageTransformer2DModel(QwenImageTransformer2DModel):
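
The rewritten helper above avoids complex-valued tensors (which do not export cleanly) by carrying cos and sin tables separately. A small self-contained check, with illustrative shapes, that this real-valued rotation matches the usual complex-multiplication form of RoPE:

    import torch

    from QEfficient.diffusers.models.transformers.transformer_qwenimage import qeff_apply_rotary_emb_qwen

    B, S, H, D = 1, 16, 2, 8  # batch, seq, heads, head_dim (illustrative only)
    x = torch.randn(B, S, H, D)
    pos = torch.arange(S, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
    freqs = torch.outer(pos, inv_freq)  # [S, D//2]

    out = qeff_apply_rotary_emb_qwen(x, torch.cos(freqs), torch.sin(freqs))

    # Reference: rotate the (even, odd) channel pairs by e^{i*freqs} in complex form.
    x_c = torch.view_as_complex(x.reshape(B, S, H, D // 2, 2))
    rot = torch.polar(torch.ones_like(freqs), freqs)[None, :, None, :]
    ref = torch.view_as_real(x_c * rot).flatten(-2)
    assert torch.allclose(out, ref, atol=1e-5)
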
@@ -78,10 +194,11 @@ def forward(
         encoder_hidden_states: torch.Tensor = None,
         encoder_hidden_states_mask: torch.Tensor = None,
         timestep: torch.LongTensor = None,
+        frame: torch.Tensor = None,
+        height: torch.Tensor = None,
+        width: torch.Tensor = None,
+        txt_seq_lens: torch.Tensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
-        img_rotary_emb: Optional[torch.Tensor] = None,
-        text_rotary_emb: Optional[torch.Tensor] = None,
         guidance: torch.Tensor = None,  # TODO: this should probably be removed
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
@@ -110,6 +227,22 @@
         If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
         `tuple` where the first element is the sample tensor.
         """
+        # breakpoint()
+        # Convert scalar tensors to Python integers and create img_shapes list
+        if isinstance(frame, torch.Tensor):
+            frame = frame.item() if frame.numel() == 1 else int(frame[0])
+        if isinstance(height, torch.Tensor):
+            height = height.item() if height.numel() == 1 else int(height[0])
+        if isinstance(width, torch.Tensor):
+            width = width.item() if width.numel() == 1 else int(width[0])
+
+        if not img_shapes:
+            img_shapes = [(frame, height, width)]
+
+        # Convert txt_seq_lens to list if it's a tensor
+        if isinstance(txt_seq_lens, torch.Tensor):
+            txt_seq_lens = txt_seq_lens.tolist()
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
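
The forward signature now takes frame, height, width and txt_seq_lens as tensors, presumably so they can be supplied as graph inputs at export time; the block above collapses them back to plain Python values before img_shapes is built. A toy illustration of that conversion (the values are made up):

    import torch

    frame, height, width = torch.tensor(1), torch.tensor(64), torch.tensor(64)
    txt_seq_lens = torch.tensor([77])

    # Mirrors the conversion at the top of forward():
    frame = frame.item() if frame.numel() == 1 else int(frame[0])
    height = height.item() if height.numel() == 1 else int(height[0])
    width = width.item() if width.numel() == 1 else int(width[0])
    img_shapes = [(frame, height, width)]  # [(1, 64, 64)]
    txt_seq_lens = txt_seq_lens.tolist()   # [77]
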
@@ -139,7 +272,6 @@
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
-        # image_rotary_emb = (img_rotary_emb, text_rotary_emb)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -192,23 +324,23 @@ def __call__(
         seq_txt = encoder_hidden_states.shape[1]
 
         # Compute QKV for image stream (sample projections)
-        img_query = attn.to_q(hidden_states) // 8
-        img_key = attn.to_k(hidden_states) // 8
-        img_value = attn.to_v(hidden_states) // 8
+        img_query = attn.to_q(hidden_states)
+        img_key = attn.to_k(hidden_states)
+        img_value = attn.to_v(hidden_states)
 
         # Compute QKV for text stream (context projections)
-        txt_query = attn.add_q_proj(encoder_hidden_states) // 8
-        txt_key = attn.add_k_proj(encoder_hidden_states) // 8
-        txt_value = attn.add_v_proj(encoder_hidden_states) // 8
+        txt_query = attn.add_q_proj(encoder_hidden_states)
+        txt_key = attn.add_k_proj(encoder_hidden_states)
+        txt_value = attn.add_v_proj(encoder_hidden_states)
 
         # Reshape for multi-head attention
-        img_query = img_query.unflatten(-1, (attn.heads, -1)) // 2
-        img_key = img_key.unflatten(-1, (attn.heads, -1)) // 2
-        img_value = img_value.unflatten(-1, (attn.heads, -1)) // 2
+        img_query = img_query.unflatten(-1, (attn.heads, -1))
+        img_key = img_key.unflatten(-1, (attn.heads, -1))
+        img_value = img_value.unflatten(-1, (attn.heads, -1))
 
-        txt_query = txt_query.unflatten(-1, (attn.heads, -1)) // 2
-        txt_key = txt_key.unflatten(-1, (attn.heads, -1)) // 2
-        txt_value = txt_value.unflatten(-1, (attn.heads, -1)) // 2
+        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+        txt_value = txt_value.unflatten(-1, (attn.heads, -1))
 
         # Apply QK normalization
         if attn.norm_q is not None:
@@ -222,11 +354,14 @@
 
         # Apply RoPE
         if image_rotary_emb is not None:
-            img_freqs, txt_freqs = image_rotary_emb
-            img_query = qeff_apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
-            img_key = qeff_apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
-            txt_query = qeff_apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
-            txt_key = qeff_apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+            # breakpoint()
+            # Unpack the 4 tensors (cos and sin for both img and txt)
+            img_freqs_cos, img_freqs_sin, txt_freqs_cos, txt_freqs_sin = image_rotary_emb
+
+            img_query = qeff_apply_rotary_emb_qwen(img_query, img_freqs_cos, img_freqs_sin)
+            img_key = qeff_apply_rotary_emb_qwen(img_key, img_freqs_cos, img_freqs_sin)
+            txt_query = qeff_apply_rotary_emb_qwen(txt_query, txt_freqs_cos, txt_freqs_sin)
+            txt_key = qeff_apply_rotary_emb_qwen(txt_key, txt_freqs_cos, txt_freqs_sin)
 
         # Concatenate for joint attention
         # Order: [text, image]
@@ -254,10 +389,10 @@
         img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
 
         # Apply output projections
-        img_attn_output = attn.to_out[0](img_attn_output)  # scale back
+        img_attn_output = attn.to_out[0](img_attn_output)
         if len(attn.to_out) > 1:
             img_attn_output = attn.to_out[1](img_attn_output)  # dropout
 
-        txt_attn_output = attn.to_add_out(txt_attn_output)  # scale back
+        txt_attn_output = attn.to_add_out(txt_attn_output)
 
         return img_attn_output, txt_attn_output
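
Putting the pieces together: QEffQwenEmbedRope now returns four real tensors (image cos/sin, text cos/sin), which the processor above unpacks and feeds to qeff_apply_rotary_emb_qwen. A hedged smoke test with made-up dimensions (Qwen-Image's actual theta and axes_dim come from the model config):

    import torch

    from QEfficient.diffusers.models.transformers.transformer_qwenimage import QEffQwenEmbedRope

    rope = QEffQwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
    vid_cos, vid_sin, txt_cos, txt_sin = rope(
        video_fhw=[(1, 32, 32)], txt_seq_lens=[77], device=torch.device("cpu")
    )
    # 1*32*32 = 1024 image tokens and 77 text tokens; sum(axes_dim) = 128 is the head
    # dim, so each table carries 128 // 2 = 64 frequencies per token.
    print(vid_cos.shape, txt_cos.shape)  # torch.Size([1024, 64]) torch.Size([77, 64])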
