From 78b2ad76ad72dd169bf9381d6c5c154239931cf1 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Wed, 9 Apr 2025 10:13:52 +0000 Subject: [PATCH 1/9] [cogview4] Enhance attention mechanism with variable-length support and batch packing Add support for variable-length attention between text and vision tokens while maintaining the original attention pattern. Implement batch packing capability to improve computational efficiency during inference and training. --- .../transformers/transformer_cogview4.py | 214 +++++++++++++++--- 1 file changed, 188 insertions(+), 26 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 41c4cbbf97c7..67ba7ad76aee 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -73,8 +73,9 @@ def __init__(self, embedding_dim: int, dim: int) -> None: def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states = self.norm(hidden_states) - norm_encoder_hidden_states = self.norm_context(encoder_hidden_states) + dtype = hidden_states.dtype + norm_hidden_states = self.norm(hidden_states).to(dtype=dtype) + norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype) emb = self.linear(temb) ( @@ -124,14 +125,138 @@ def __call__( attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[Dict[str, torch.Tensor]] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """ + Args: + attn (`Attention`): + The attention module. + hidden_states (`torch.Tensor`): + The input hidden states. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states for cross-attention. + attention_mask (`Dict[str, torch.Tensor]`, *optional*): + Dictionary containing mask configurations: + - `batch_flag` (`torch.Tensor`, *optional*): + Values from 0 to n-1 indicating which samples belong to the same batch. + Samples with the same batch_flag are packed together. + Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form batch1, and samples 3-4 form batch2. + If None, no packing is used. + - `text_embedding_attn_mask` (`torch.Tensor`, *optional*): + Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. + If None, full attention is used for all text tokens. + - `latent_embedding_attn_mask` (`torch.Tensor`, *optional*): + Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. + If None, full attention is used for all latent tokens. + Note: the shape of latent_embedding_attn_mask is (batch_size, num_latent_tokens). + image_rotary_emb (`torch.Tensor`, *optional*): + The rotary embedding for the image/latent part of the input. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams. + """ + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape batch_size, image_seq_length, embed_dim = hidden_states.shape - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + dtype = encoder_hidden_states.dtype + device = encoder_hidden_states.device + latent_hidden_states = hidden_states + mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1) - # 1. QKV projections + text_attention_mask, latent_attention_mask, batch_flag = None, None, None + + # 1. Construct attention mask and maybe packing input + if attention_mask is not None: + text_attention_mask = attention_mask.get("text_embedding_attn_mask", None) + latent_attention_mask = attention_mask.get("latent_embedding_attn_mask", None) + batch_flag = attention_mask.get("batch_flag", None) + + if text_attention_mask is None: + text_attention_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device) + if latent_attention_mask is None: + latent_attention_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device) + + assert text_attention_mask.dim() == 2, "the shape of text_attention_mask should be (batch_size, text_seq_length)" + assert text_attention_mask.dtype == torch.int32, "the dtype of text_attention_mask should be torch.int32" + assert latent_attention_mask.dim() == 2, "the shape of latent_attention_mask should be (batch_size, num_latent_tokens)" + assert latent_attention_mask.dtype == torch.int32, "the dtype of latent_attention_mask should be torch.int32" + + mixed_attention_mask = torch.ones( + (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device + ) + mixed_attention_mask[:, :text_seq_length] = text_attention_mask + mixed_attention_mask[:, text_seq_length:] = latent_attention_mask + + # Create attention mask matrix + mixed_attention_mask_input = mixed_attention_mask.unsqueeze(2).to(dtype=dtype) + attention_mask_matrix = mixed_attention_mask_input @ mixed_attention_mask_input.transpose(1, 2) + + # Apply batch packing if provided + if batch_flag is not None: + assert batch_flag.dim() == 1 + packing_batch_size = torch.max(batch_flag).item() + 1 + + text_seq_length = torch.sum(text_attention_mask, dim=1) + latent_seq_length = torch.sum(latent_attention_mask, dim=1) + mixed_seq_length = text_seq_length + latent_seq_length + + text_seq_length_packed = [ + torch.sum(text_attention_mask[batch_flag == batch_idx]).item() + for batch_idx in range(packing_batch_size) + ] + latent_seq_length_packed = [ + torch.sum(latent_attention_mask[batch_flag == batch_idx]).item() + for batch_idx in range(packing_batch_size) + ] + mixed_seq_length_packed = [ + torch.sum(mixed_attention_mask[batch_flag == batch_idx]).item() + for batch_idx in range(packing_batch_size) + ] + assert torch.equal( + torch.tensor(mixed_seq_length_packed), + torch.tensor(text_seq_length_packed) + torch.tensor(latent_seq_length_packed), + ) + assert len(mixed_seq_length_packed) == packing_batch_size + + mixed_attention_mask_flatten = mixed_attention_mask.flatten(0, 1) + mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1) + mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attention_mask_flatten == 1] + assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0] + + mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed) + mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence( + mixed_hidden_states_packed, + batch_first=True, + padding_value=0.0, + padding_side="right", + ) + + mixed_attention_mask_packed = [ + torch.zeros(mixed_seq_length_packed[i], dtype=dtype, device=device) + for i in range(packing_batch_size) + ] + mixed_attention_mask_packed_padded = torch.nn.utils.rnn.pad_sequence( + mixed_attention_mask_packed, batch_first=True, padding_value=0, padding_side="right" + ) + attention_mask_matrix = mixed_attention_mask_packed_padded.unsqueeze(2) @ mixed_attention_mask_packed_padded.unsqueeze(1) + for idx, mask in enumerate(attention_mask_matrix): + seq_lengths = mixed_seq_length[batch_flag == idx] + offset = 0 + for length in seq_lengths: + mask[offset : offset + length, offset : offset + length] = 1 + offset += length + + attention_mask_matrix = attention_mask_matrix.to(dtype=torch.bool) + attention_mask_matrix = attention_mask_matrix.unsqueeze(1) # Add attention head dim + attention_mask = attention_mask_matrix + + if batch_flag is None: + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + else: + hidden_states = mixed_hidden_states_packed_padded + + # 2. QKV projections query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) @@ -140,32 +265,46 @@ def __call__( key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - # 2. QK normalization + # 3. QK normalization if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) - # 3. Rotational positional embeddings applied to latent stream + # 4. Rotational positional embeddings applied to latent stream if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb - query[:, :, text_seq_length:, :] = apply_rotary_emb( - query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 - ) - key[:, :, text_seq_length:, :] = apply_rotary_emb( - key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 - ) - - # 4. Attention - if attention_mask is not None: - text_attention_mask = attention_mask.float().to(query.device) - actual_text_seq_length = text_attention_mask.size(1) - new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device) - new_attention_mask[:, :actual_text_seq_length] = text_attention_mask - new_attention_mask = new_attention_mask.unsqueeze(2) - attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2) - attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype) + if batch_flag is None: + query[:, :, text_seq_length:, :] = apply_rotary_emb( + query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + key[:, :, text_seq_length:, :] = apply_rotary_emb( + key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + else: + assert query.shape[0] == packing_batch_size + assert key.shape[0] == packing_batch_size + + for idx in range(packing_batch_size): + offset = 0 + text_seq_length_bi = text_seq_length[batch_flag == idx] + latent_seq_length_bi = latent_seq_length[batch_flag == idx] + + for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi): + mlen = tlen + llen + image_rotary_emb_slice = (image_rotary_emb[0][:llen], image_rotary_emb[1][:llen]) + query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb( + query[idx, :, offset + tlen : offset + mlen, :], + image_rotary_emb_slice, + use_real_unbind_dim=-2, + ) + key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb( + key[idx, :, offset + tlen : offset + mlen, :], + image_rotary_emb_slice, + use_real_unbind_dim=-2, + ) + offset += mlen hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False @@ -177,9 +316,32 @@ def __call__( hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) + if batch_flag is None: + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + else: + hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence( + hidden_states, + lengths=torch.tensor(mixed_seq_length_packed), + batch_first=True, + ) + hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0) + hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist()) + assert len(hidden_states_unpack) == batch_size + hidden_states_unpack = [ + torch.split(h, [tlen, llen]) + for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length) + ] + encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack] + hidden_states_unpad = [h[1] for h in hidden_states_unpack] + + for idx in range(batch_size): + encoder_hidden_states[idx][text_attention_mask[idx] == 1] = encoder_hidden_states_unpad[idx] + latent_hidden_states[idx][latent_attention_mask[idx] == 1] = hidden_states_unpad[idx] + + hidden_states = latent_hidden_states + return hidden_states, encoder_hidden_states From 1fae35e84ace19f5f4ce92463d63f1e00db3459b Mon Sep 17 00:00:00 2001 From: OleehyO Date: Sun, 13 Apr 2025 13:59:05 +0000 Subject: [PATCH 2/9] [cogview4] Fix rope --- .../transformers/transformer_cogview4.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 67ba7ad76aee..ca07ab8f32e8 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -126,7 +126,7 @@ def __call__( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[Dict[str, torch.Tensor]] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]] = None, ) -> torch.Tensor: """ Args: @@ -150,8 +150,8 @@ def __call__( Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention is used for all latent tokens. Note: the shape of latent_embedding_attn_mask is (batch_size, num_latent_tokens). - image_rotary_emb (`torch.Tensor`, *optional*): - The rotary embedding for the image/latent part of the input. + image_rotary_emb (`torch.Tensor` or `list[torch.Tensor]`, *optional*): + The rotary embedding for the image part of the input. Returns: `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams. @@ -225,6 +225,7 @@ def __call__( assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0] mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed) + mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence( mixed_hidden_states_packed, batch_first=True, @@ -232,14 +233,12 @@ def __call__( padding_side="right", ) - mixed_attention_mask_packed = [ - torch.zeros(mixed_seq_length_packed[i], dtype=dtype, device=device) - for i in range(packing_batch_size) - ] - mixed_attention_mask_packed_padded = torch.nn.utils.rnn.pad_sequence( - mixed_attention_mask_packed, batch_first=True, padding_value=0, padding_side="right" + l = mixed_hidden_states_packed_padded.shape[1] + attention_mask_matrix = torch.zeros( + (packing_batch_size, l, l), + dtype=dtype, + device=device, ) - attention_mask_matrix = mixed_attention_mask_packed_padded.unsqueeze(2) @ mixed_attention_mask_packed_padded.unsqueeze(1) for idx, mask in enumerate(attention_mask_matrix): seq_lengths = mixed_seq_length[batch_flag == idx] offset = 0 @@ -285,7 +284,9 @@ def __call__( else: assert query.shape[0] == packing_batch_size assert key.shape[0] == packing_batch_size + assert len(image_rotary_emb) == batch_size + rope_idx = 0 for idx in range(packing_batch_size): offset = 0 text_seq_length_bi = text_seq_length[batch_flag == idx] @@ -293,25 +294,27 @@ def __call__( for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi): mlen = tlen + llen - image_rotary_emb_slice = (image_rotary_emb[0][:llen], image_rotary_emb[1][:llen]) query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb( query[idx, :, offset + tlen : offset + mlen, :], - image_rotary_emb_slice, + image_rotary_emb[rope_idx], use_real_unbind_dim=-2, ) key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb( key[idx, :, offset + tlen : offset + mlen, :], - image_rotary_emb_slice, + image_rotary_emb[rope_idx], use_real_unbind_dim=-2, ) offset += mlen + rope_idx += 1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) + # 5. Output projection hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) @@ -564,6 +567,7 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]] = None, **kwargs, ) -> Union[torch.Tensor, Transformer2DModelOutput]: if attention_kwargs is not None: @@ -584,7 +588,10 @@ def forward( batch_size, num_channels, height, width = hidden_states.shape # 1. RoPE - image_rotary_emb = self.rope(hidden_states) + if image_rotary_emb is None: + image_rotary_emb = self.rope(hidden_states) + else: + image_rotary_emb = image_rotary_emb # 2. Patch & Timestep embeddings p = self.config.patch_size From 255cb5af181b503c6c6dad730c004b15eebd9eb2 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Tue, 15 Apr 2025 09:40:35 +0000 Subject: [PATCH 3/9] [cogview4] Fix tensor type after qk norm --- src/diffusers/models/transformers/transformer_cogview4.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index ca07ab8f32e8..256caef09a2c 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -266,9 +266,9 @@ def __call__( # 3. QK normalization if attn.norm_q is not None: - query = attn.norm_q(query) + query = attn.norm_q(query).to(dtype=dtype) if attn.norm_k is not None: - key = attn.norm_k(key) + key = attn.norm_k(key).to(dtype=dtype) # 4. Rotational positional embeddings applied to latent stream if image_rotary_emb is not None: From 1a48dcdd57db862878aa8f532af93e07c2c4e0bf Mon Sep 17 00:00:00 2001 From: OleehyO Date: Thu, 17 Apr 2025 04:48:05 +0000 Subject: [PATCH 4/9] [cogview4] Add docs for attn processor --- .../transformers/transformer_cogview4.py | 59 +++++++++++++++---- 1 file changed, 48 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 256caef09a2c..a357d369966c 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -157,50 +157,60 @@ def __call__( `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams. """ + # Get dimensions and device info batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape batch_size, image_seq_length, embed_dim = hidden_states.shape dtype = encoder_hidden_states.dtype device = encoder_hidden_states.device latent_hidden_states = hidden_states + # Combine text and image streams for joint processing mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1) + # Initialize mask variables text_attention_mask, latent_attention_mask, batch_flag = None, None, None # 1. Construct attention mask and maybe packing input if attention_mask is not None: + # Extract mask components from the dictionary text_attention_mask = attention_mask.get("text_embedding_attn_mask", None) latent_attention_mask = attention_mask.get("latent_embedding_attn_mask", None) batch_flag = attention_mask.get("batch_flag", None) + # Create default masks if not provided if text_attention_mask is None: text_attention_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device) if latent_attention_mask is None: latent_attention_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device) + # Validate mask shapes and types assert text_attention_mask.dim() == 2, "the shape of text_attention_mask should be (batch_size, text_seq_length)" assert text_attention_mask.dtype == torch.int32, "the dtype of text_attention_mask should be torch.int32" assert latent_attention_mask.dim() == 2, "the shape of latent_attention_mask should be (batch_size, num_latent_tokens)" assert latent_attention_mask.dtype == torch.int32, "the dtype of latent_attention_mask should be torch.int32" + # Create combined mask for text and image tokens mixed_attention_mask = torch.ones( (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device ) mixed_attention_mask[:, :text_seq_length] = text_attention_mask mixed_attention_mask[:, text_seq_length:] = latent_attention_mask - # Create attention mask matrix + # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend) mixed_attention_mask_input = mixed_attention_mask.unsqueeze(2).to(dtype=dtype) attention_mask_matrix = mixed_attention_mask_input @ mixed_attention_mask_input.transpose(1, 2) - # Apply batch packing if provided + # Handle batch packing if enabled if batch_flag is not None: assert batch_flag.dim() == 1 + # Determine packed batch size based on batch_flag packing_batch_size = torch.max(batch_flag).item() + 1 + # Calculate actual sequence lengths for each sample based on masks text_seq_length = torch.sum(text_attention_mask, dim=1) latent_seq_length = torch.sum(latent_attention_mask, dim=1) mixed_seq_length = text_seq_length + latent_seq_length + # Calculate packed sequence lengths for each packed batch text_seq_length_packed = [ torch.sum(text_attention_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size) @@ -213,19 +223,19 @@ def __call__( torch.sum(mixed_attention_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size) ] - assert torch.equal( - torch.tensor(mixed_seq_length_packed), - torch.tensor(text_seq_length_packed) + torch.tensor(latent_seq_length_packed), - ) + assert len(mixed_seq_length_packed) == packing_batch_size + # Pack sequences by removing padding tokens mixed_attention_mask_flatten = mixed_attention_mask.flatten(0, 1) mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1) mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attention_mask_flatten == 1] assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0] + # Split the unpadded sequence into packed batches mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed) + # Re-pad to create packed batches with right-side padding mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence( mixed_hidden_states_packed, batch_first=True, @@ -233,16 +243,21 @@ def __call__( padding_side="right", ) + # Create attention mask for packed batches l = mixed_hidden_states_packed_padded.shape[1] attention_mask_matrix = torch.zeros( (packing_batch_size, l, l), dtype=dtype, device=device, ) + + # Fill attention mask with block diagonal matrices + # This ensures that tokens can only attend to other tokens within the same original sample for idx, mask in enumerate(attention_mask_matrix): seq_lengths = mixed_seq_length[batch_flag == idx] offset = 0 for length in seq_lengths: + # Create a block of 1s for each sample in the packed batch mask[offset : offset + length, offset : offset + length] = 1 offset += length @@ -250,31 +265,36 @@ def __call__( attention_mask_matrix = attention_mask_matrix.unsqueeze(1) # Add attention head dim attention_mask = attention_mask_matrix + # Prepare hidden states for attention computation if batch_flag is None: + # If no packing, just combine text and image tokens hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) else: + # If packing, use the packed sequence hidden_states = mixed_hidden_states_packed_padded - # 2. QKV projections + # 2. QKV projections - convert hidden states to query, key, value query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) + # Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim] query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - # 3. QK normalization + # 3. QK normalization - apply layer norm to queries and keys if configured if attn.norm_q is not None: query = attn.norm_q(query).to(dtype=dtype) if attn.norm_k is not None: key = attn.norm_k(key).to(dtype=dtype) - # 4. Rotational positional embeddings applied to latent stream + # 4. Apply rotary positional embeddings to image tokens only if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb if batch_flag is None: + # Apply RoPE only to image tokens (after text tokens) query[:, :, text_seq_length:, :] = apply_rotary_emb( query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 ) @@ -282,6 +302,7 @@ def __call__( key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 ) else: + # For packed batches, need to carefully apply RoPE to appropriate tokens assert query.shape[0] == packing_batch_size assert key.shape[0] == packing_batch_size assert len(image_rotary_emb) == batch_size @@ -289,11 +310,14 @@ def __call__( rope_idx = 0 for idx in range(packing_batch_size): offset = 0 + # Get text and image sequence lengths for samples in this packed batch text_seq_length_bi = text_seq_length[batch_flag == idx] latent_seq_length_bi = latent_seq_length[batch_flag == idx] + # Apply RoPE to each image segment in the packed sequence for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi): mlen = tlen + llen + # Apply RoPE only to image tokens (after text tokens) query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb( query[idx, :, offset + tlen : offset + mlen, :], image_rotary_emb[rope_idx], @@ -311,38 +335,51 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + # Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim] hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) - - # 5. Output projection + # 5. Output projection - project attention output to model dimension hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) + # Split the output back into text and image streams if batch_flag is None: + # Simple split for non-packed case encoder_hidden_states, hidden_states = hidden_states.split( [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 ) else: + # For packed case: need to unpack, split text/image, then restore to original shapes + # First, unpad the sequence based on the packed sequence lengths hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence( hidden_states, lengths=torch.tensor(mixed_seq_length_packed), batch_first=True, ) + # Concatenate all unpadded sequences hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0) + # Split by original sample sequence lengths hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist()) assert len(hidden_states_unpack) == batch_size + + # Further split each sample's sequence into text and image parts hidden_states_unpack = [ torch.split(h, [tlen, llen]) for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length) ] + # Separate text and image sequences encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack] hidden_states_unpad = [h[1] for h in hidden_states_unpack] + # Update the original tensors with the processed values, respecting the attention masks for idx in range(batch_size): + # Place unpacked text tokens back in the encoder_hidden_states tensor encoder_hidden_states[idx][text_attention_mask[idx] == 1] = encoder_hidden_states_unpad[idx] + # Place unpacked image tokens back in the latent_hidden_states tensor latent_hidden_states[idx][latent_attention_mask[idx] == 1] = hidden_states_unpad[idx] + # Update the output hidden states hidden_states = latent_hidden_states return hidden_states, encoder_hidden_states From f2a6e5ceeaeaceb6e283e2c2fc6a5c90be539382 Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 18 Apr 2025 05:49:58 +0000 Subject: [PATCH 5/9] [chore] Change type hint --- src/diffusers/models/transformers/transformer_cogview4.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index a357d369966c..698145a2dbad 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union, List import torch import torch.nn as nn @@ -126,7 +126,7 @@ def __call__( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[Dict[str, torch.Tensor]] = None, - image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor] | List[Tuple[torch.Tensor, torch.Tensor]]] = None, ) -> torch.Tensor: """ Args: @@ -604,7 +604,7 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor] | List[Tuple[torch.Tensor, torch.Tensor]]] = None, **kwargs, ) -> Union[torch.Tensor, Transformer2DModelOutput]: if attention_kwargs is not None: From ccf675235df739c51c91c95485d9b72162479ccc Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 18 Apr 2025 09:07:59 +0000 Subject: [PATCH 6/9] Rename as CogView4TrainingAttnProcessor --- .../transformers/transformer_cogview4.py | 171 +++++++++++++----- 1 file changed, 130 insertions(+), 41 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 698145a2dbad..7f01f8e1d8c2 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -112,8 +112,93 @@ def forward( class CogView4AttnProcessor: """ - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on query and key vectors, but does not include spatial normalization. + + The processor supports passing an attention mask for text tokens. The attention mask should have shape + (batch_size, text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[Dict[str, torch.Tensor]] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + dtype = encoder_hidden_states.dtype + device = encoder_hidden_states.device + + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query).to(dtype=dtype) + if attn.norm_k is not None: + key = attn.norm_k(key).to(dtype=dtype) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query[:, :, text_seq_length:, :] = apply_rotary_emb( + query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + key[:, :, text_seq_length:, :] = apply_rotary_emb( + key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + + # 4. Attention + if attention_mask is not None: + text_attn_mask = attention_mask["text_attn_mask"] + assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" + text_attn_mask = text_attn_mask.float().to(query.device) + mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device) + mix_attn_mask[:, :text_seq_length] = text_attn_mask + mix_attn_mask = mix_attn_mask.unsqueeze(2) + attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2) + attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # 5. Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +class CogView4TrainingAttnProcessor: + """ + Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + + This processor differs from CogView4AttnProcessor in several important ways: + 1. It supports attention masking with variable sequence lengths for multi-resolution training + 2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is provided """ def __init__(self): @@ -143,13 +228,13 @@ def __call__( Samples with the same batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form batch1, and samples 3-4 form batch2. If None, no packing is used. - - `text_embedding_attn_mask` (`torch.Tensor`, *optional*): + - `text_attn_mask` (`torch.Tensor`, *optional*): Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention is used for all text tokens. - - `latent_embedding_attn_mask` (`torch.Tensor`, *optional*): + - `latent_attn_mask` (`torch.Tensor`, *optional*): Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention is used for all latent tokens. - Note: the shape of latent_embedding_attn_mask is (batch_size, num_latent_tokens). + Note: the shape of latent_attn_mask is (batch_size, num_latent_tokens). image_rotary_emb (`torch.Tensor` or `list[torch.Tensor]`, *optional*): The rotary embedding for the image part of the input. @@ -167,37 +252,37 @@ def __call__( mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1) # Initialize mask variables - text_attention_mask, latent_attention_mask, batch_flag = None, None, None + text_attn_mask, latent_attn_mask, batch_flag = None, None, None # 1. Construct attention mask and maybe packing input if attention_mask is not None: # Extract mask components from the dictionary - text_attention_mask = attention_mask.get("text_embedding_attn_mask", None) - latent_attention_mask = attention_mask.get("latent_embedding_attn_mask", None) + text_attn_mask = attention_mask.get("text_attn_mask", None) + latent_attn_mask = attention_mask.get("latent_attn_mask", None) batch_flag = attention_mask.get("batch_flag", None) # Create default masks if not provided - if text_attention_mask is None: - text_attention_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device) - if latent_attention_mask is None: - latent_attention_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device) + if text_attn_mask is None: + text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device) + if latent_attn_mask is None: + latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device) # Validate mask shapes and types - assert text_attention_mask.dim() == 2, "the shape of text_attention_mask should be (batch_size, text_seq_length)" - assert text_attention_mask.dtype == torch.int32, "the dtype of text_attention_mask should be torch.int32" - assert latent_attention_mask.dim() == 2, "the shape of latent_attention_mask should be (batch_size, num_latent_tokens)" - assert latent_attention_mask.dtype == torch.int32, "the dtype of latent_attention_mask should be torch.int32" + assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" + assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32" + assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)" + assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32" # Create combined mask for text and image tokens - mixed_attention_mask = torch.ones( + mixed_attn_mask = torch.ones( (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device ) - mixed_attention_mask[:, :text_seq_length] = text_attention_mask - mixed_attention_mask[:, text_seq_length:] = latent_attention_mask + mixed_attn_mask[:, :text_seq_length] = text_attn_mask + mixed_attn_mask[:, text_seq_length:] = latent_attn_mask # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend) - mixed_attention_mask_input = mixed_attention_mask.unsqueeze(2).to(dtype=dtype) - attention_mask_matrix = mixed_attention_mask_input @ mixed_attention_mask_input.transpose(1, 2) + mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype) + attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2) # Handle batch packing if enabled if batch_flag is not None: @@ -206,30 +291,30 @@ def __call__( packing_batch_size = torch.max(batch_flag).item() + 1 # Calculate actual sequence lengths for each sample based on masks - text_seq_length = torch.sum(text_attention_mask, dim=1) - latent_seq_length = torch.sum(latent_attention_mask, dim=1) + text_seq_length = torch.sum(text_attn_mask, dim=1) + latent_seq_length = torch.sum(latent_attn_mask, dim=1) mixed_seq_length = text_seq_length + latent_seq_length # Calculate packed sequence lengths for each packed batch text_seq_length_packed = [ - torch.sum(text_attention_mask[batch_flag == batch_idx]).item() + torch.sum(text_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size) ] latent_seq_length_packed = [ - torch.sum(latent_attention_mask[batch_flag == batch_idx]).item() + torch.sum(latent_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size) ] mixed_seq_length_packed = [ - torch.sum(mixed_attention_mask[batch_flag == batch_idx]).item() + torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size) ] assert len(mixed_seq_length_packed) == packing_batch_size # Pack sequences by removing padding tokens - mixed_attention_mask_flatten = mixed_attention_mask.flatten(0, 1) + mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1) mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1) - mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attention_mask_flatten == 1] + mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1] assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0] # Split the unpadded sequence into packed batches @@ -245,7 +330,7 @@ def __call__( # Create attention mask for packed batches l = mixed_hidden_states_packed_padded.shape[1] - attention_mask_matrix = torch.zeros( + attn_mask_matrix = torch.zeros( (packing_batch_size, l, l), dtype=dtype, device=device, @@ -253,7 +338,7 @@ def __call__( # Fill attention mask with block diagonal matrices # This ensures that tokens can only attend to other tokens within the same original sample - for idx, mask in enumerate(attention_mask_matrix): + for idx, mask in enumerate(attn_mask_matrix): seq_lengths = mixed_seq_length[batch_flag == idx] offset = 0 for length in seq_lengths: @@ -261,9 +346,9 @@ def __call__( mask[offset : offset + length, offset : offset + length] = 1 offset += length - attention_mask_matrix = attention_mask_matrix.to(dtype=torch.bool) - attention_mask_matrix = attention_mask_matrix.unsqueeze(1) # Add attention head dim - attention_mask = attention_mask_matrix + attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool) + attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim + attention_mask = attn_mask_matrix # Prepare hidden states for attention computation if batch_flag is None: @@ -375,9 +460,9 @@ def __call__( # Update the original tensors with the processed values, respecting the attention masks for idx in range(batch_size): # Place unpacked text tokens back in the encoder_hidden_states tensor - encoder_hidden_states[idx][text_attention_mask[idx] == 1] = encoder_hidden_states_unpad[idx] + encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx] # Place unpacked image tokens back in the latent_hidden_states tensor - latent_hidden_states[idx][latent_attention_mask[idx] == 1] = hidden_states_unpad[idx] + latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx] # Update the output hidden states hidden_states = latent_hidden_states @@ -387,7 +472,12 @@ def __call__( class CogView4TransformerBlock(nn.Module): def __init__( - self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512 + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + attn_processor: Union[CogView4AttnProcessor, CogView4TrainingAttnProcessor] = CogView4AttnProcessor(), ) -> None: super().__init__() @@ -402,7 +492,7 @@ def __init__( qk_norm="layer_norm", elementwise_affine=False, eps=1e-5, - processor=CogView4AttnProcessor(), + processor=attn_processor, ) # 2. Feedforward @@ -416,7 +506,7 @@ def forward( encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[Dict[str, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: # 1. Timestep conditioning @@ -557,6 +647,7 @@ def __init__( pos_embed_max_size: int = 128, sample_size: int = 128, rope_axes_dim: Tuple[int, int] = (256, 256), + attn_processor: Union[CogView4AttnProcessor, CogView4TrainingAttnProcessor] = CogView4AttnProcessor(), ): super().__init__() @@ -582,7 +673,7 @@ def __init__( # 3. Transformer blocks self.transformer_blocks = nn.ModuleList( [ - CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim, attn_processor) for _ in range(num_layers) ] ) @@ -603,7 +694,7 @@ def forward( crop_coords: torch.Tensor, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[Dict[str, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor] | List[Tuple[torch.Tensor, torch.Tensor]]] = None, **kwargs, ) -> Union[torch.Tensor, Transformer2DModelOutput]: @@ -627,8 +718,6 @@ def forward( # 1. RoPE if image_rotary_emb is None: image_rotary_emb = self.rope(hidden_states) - else: - image_rotary_emb = image_rotary_emb # 2. Patch & Timestep embeddings p = self.config.patch_size From fe0c30b5f4afa3baa6888346dc653dde04468c2d Mon Sep 17 00:00:00 2001 From: OleehyO Date: Fri, 18 Apr 2025 15:39:21 +0000 Subject: [PATCH 7/9] [refactor] Back to original signature, using `attention_kwargs` instead --- .../transformers/transformer_cogview4.py | 251 +++++++++--------- 1 file changed, 122 insertions(+), 129 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 7f01f8e1d8c2..b68f46deafb2 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -128,9 +128,9 @@ def __call__( attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: Optional[Dict[str, torch.Tensor]] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: dtype = encoder_hidden_states.dtype device = encoder_hidden_states.device @@ -166,7 +166,7 @@ def __call__( # 4. Attention if attention_mask is not None: - text_attn_mask = attention_mask["text_attn_mask"] + text_attn_mask = attention_mask assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" text_attn_mask = text_attn_mask.float().to(query.device) mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device) @@ -210,9 +210,12 @@ def __call__( attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - attention_mask: Optional[Dict[str, torch.Tensor]] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor] | List[Tuple[torch.Tensor, torch.Tensor]]] = None, - ) -> torch.Tensor: + latent_attn_mask: Optional[torch.Tensor] = None, + text_attn_mask: Optional[torch.Tensor] = None, + batch_flag: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: attn (`Attention`): @@ -221,23 +224,20 @@ def __call__( The input hidden states. encoder_hidden_states (`torch.Tensor`): The encoder hidden states for cross-attention. - attention_mask (`Dict[str, torch.Tensor]`, *optional*): - Dictionary containing mask configurations: - - `batch_flag` (`torch.Tensor`, *optional*): - Values from 0 to n-1 indicating which samples belong to the same batch. - Samples with the same batch_flag are packed together. - Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form batch1, and samples 3-4 form batch2. - If None, no packing is used. - - `text_attn_mask` (`torch.Tensor`, *optional*): - Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. - If None, full attention is used for all text tokens. - - `latent_attn_mask` (`torch.Tensor`, *optional*): - Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. - If None, full attention is used for all latent tokens. - Note: the shape of latent_attn_mask is (batch_size, num_latent_tokens). - image_rotary_emb (`torch.Tensor` or `list[torch.Tensor]`, *optional*): + latent_attn_mask (`torch.Tensor`, *optional*): + Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. + If None, full attention is used for all latent tokens. + Note: the shape of latent_attn_mask is (batch_size, num_latent_tokens). + text_attn_mask (`torch.Tensor`, *optional*): + Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. + If None, full attention is used for all text tokens. + batch_flag (`torch.Tensor`, *optional*): + Values from 0 to n-1 indicating which samples belong to the same batch. + Samples with the same batch_flag are packed together. + Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form batch1, and samples 3-4 form batch2. + If None, no packing is used. + image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*): The rotary embedding for the image part of the input. - Returns: `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams. """ @@ -251,104 +251,87 @@ def __call__( # Combine text and image streams for joint processing mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1) - # Initialize mask variables - text_attn_mask, latent_attn_mask, batch_flag = None, None, None - # 1. Construct attention mask and maybe packing input - if attention_mask is not None: - # Extract mask components from the dictionary - text_attn_mask = attention_mask.get("text_attn_mask", None) - latent_attn_mask = attention_mask.get("latent_attn_mask", None) - batch_flag = attention_mask.get("batch_flag", None) - - # Create default masks if not provided - if text_attn_mask is None: - text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device) - if latent_attn_mask is None: - latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device) - - # Validate mask shapes and types - assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" - assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32" - assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)" - assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32" + # Create default masks if not provided + if text_attn_mask is None: + text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device) + if latent_attn_mask is None: + latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device) + + # Validate mask shapes and types + assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" + assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32" + assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)" + assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32" + + # Create combined mask for text and image tokens + mixed_attn_mask = torch.ones( + (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device + ) + mixed_attn_mask[:, :text_seq_length] = text_attn_mask + mixed_attn_mask[:, text_seq_length:] = latent_attn_mask + + # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend) + mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype) + attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2) + + # Handle batch packing if enabled + if batch_flag is not None: + assert batch_flag.dim() == 1 + # Determine packed batch size based on batch_flag + packing_batch_size = torch.max(batch_flag).item() + 1 + + # Calculate actual sequence lengths for each sample based on masks + text_seq_length = torch.sum(text_attn_mask, dim=1) + latent_seq_length = torch.sum(latent_attn_mask, dim=1) + mixed_seq_length = text_seq_length + latent_seq_length + + # Calculate packed sequence lengths for each packed batch + mixed_seq_length_packed = [ + torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() + for batch_idx in range(packing_batch_size) + ] + + assert len(mixed_seq_length_packed) == packing_batch_size - # Create combined mask for text and image tokens - mixed_attn_mask = torch.ones( - (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device - ) - mixed_attn_mask[:, :text_seq_length] = text_attn_mask - mixed_attn_mask[:, text_seq_length:] = latent_attn_mask - - # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend) - mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype) - attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2) - - # Handle batch packing if enabled - if batch_flag is not None: - assert batch_flag.dim() == 1 - # Determine packed batch size based on batch_flag - packing_batch_size = torch.max(batch_flag).item() + 1 - - # Calculate actual sequence lengths for each sample based on masks - text_seq_length = torch.sum(text_attn_mask, dim=1) - latent_seq_length = torch.sum(latent_attn_mask, dim=1) - mixed_seq_length = text_seq_length + latent_seq_length - - # Calculate packed sequence lengths for each packed batch - text_seq_length_packed = [ - torch.sum(text_attn_mask[batch_flag == batch_idx]).item() - for batch_idx in range(packing_batch_size) - ] - latent_seq_length_packed = [ - torch.sum(latent_attn_mask[batch_flag == batch_idx]).item() - for batch_idx in range(packing_batch_size) - ] - mixed_seq_length_packed = [ - torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() - for batch_idx in range(packing_batch_size) - ] - - assert len(mixed_seq_length_packed) == packing_batch_size - - # Pack sequences by removing padding tokens - mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1) - mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1) - mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1] - assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0] - - # Split the unpadded sequence into packed batches - mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed) - - # Re-pad to create packed batches with right-side padding - mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence( - mixed_hidden_states_packed, - batch_first=True, - padding_value=0.0, - padding_side="right", - ) + # Pack sequences by removing padding tokens + mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1) + mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1) + mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1] + assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0] - # Create attention mask for packed batches - l = mixed_hidden_states_packed_padded.shape[1] - attn_mask_matrix = torch.zeros( - (packing_batch_size, l, l), - dtype=dtype, - device=device, - ) - - # Fill attention mask with block diagonal matrices - # This ensures that tokens can only attend to other tokens within the same original sample - for idx, mask in enumerate(attn_mask_matrix): - seq_lengths = mixed_seq_length[batch_flag == idx] - offset = 0 - for length in seq_lengths: - # Create a block of 1s for each sample in the packed batch - mask[offset : offset + length, offset : offset + length] = 1 - offset += length + # Split the unpadded sequence into packed batches + mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed) - attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool) - attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim - attention_mask = attn_mask_matrix + # Re-pad to create packed batches with right-side padding + mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence( + mixed_hidden_states_packed, + batch_first=True, + padding_value=0.0, + padding_side="right", + ) + + # Create attention mask for packed batches + l = mixed_hidden_states_packed_padded.shape[1] + attn_mask_matrix = torch.zeros( + (packing_batch_size, l, l), + dtype=dtype, + device=device, + ) + + # Fill attention mask with block diagonal matrices + # This ensures that tokens can only attend to other tokens within the same original sample + for idx, mask in enumerate(attn_mask_matrix): + seq_lengths = mixed_seq_length[batch_flag == idx] + offset = 0 + for length in seq_lengths: + # Create a block of 1s for each sample in the packed batch + mask[offset : offset + length, offset : offset + length] = 1 + offset += length + + attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool) + attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim + attention_mask = attn_mask_matrix # Prepare hidden states for attention computation if batch_flag is None: @@ -477,7 +460,6 @@ def __init__( num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512, - attn_processor: Union[CogView4AttnProcessor, CogView4TrainingAttnProcessor] = CogView4AttnProcessor(), ) -> None: super().__init__() @@ -492,7 +474,7 @@ def __init__( qk_norm="layer_norm", elementwise_affine=False, eps=1e-5, - processor=attn_processor, + processor=CogView4AttnProcessor(), ) # 2. Feedforward @@ -505,9 +487,9 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]] = None, attention_mask: Optional[Dict[str, torch.Tensor]] = None, - **kwargs, + attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: # 1. Timestep conditioning ( @@ -524,12 +506,14 @@ def forward( ) = self.norm1(hidden_states, encoder_hidden_states, temb) # 2. Attention + if attention_kwargs is None: + attention_kwargs = {} attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, - **kwargs, + **attention_kwargs, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) @@ -647,7 +631,6 @@ def __init__( pos_embed_max_size: int = 128, sample_size: int = 128, rope_axes_dim: Tuple[int, int] = (256, 256), - attn_processor: Union[CogView4AttnProcessor, CogView4TrainingAttnProcessor] = CogView4AttnProcessor(), ): super().__init__() @@ -673,7 +656,7 @@ def __init__( # 3. Transformer blocks self.transformer_blocks = nn.ModuleList( [ - CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim, attn_processor) + CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) for _ in range(num_layers) ] ) @@ -694,9 +677,8 @@ def forward( crop_coords: torch.Tensor, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, - attention_mask: Optional[Dict[str, torch.Tensor]] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor] | List[Tuple[torch.Tensor, torch.Tensor]]] = None, - **kwargs, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]] = None, ) -> Union[torch.Tensor, Transformer2DModelOutput]: if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -733,11 +715,22 @@ def forward( for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, ) else: hidden_states, encoder_hidden_states = block( - hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, ) # 4. Output norm & projection From b70f208062639544dedf4dafefad427806c7a9b8 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 21 Apr 2025 08:35:44 -1000 Subject: [PATCH 8/9] Update src/diffusers/models/transformers/transformer_cogview4.py --- src/diffusers/models/transformers/transformer_cogview4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index b68f46deafb2..02d395100024 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -132,7 +132,6 @@ def __call__( image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: dtype = encoder_hidden_states.dtype - device = encoder_hidden_states.device batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape batch_size, image_seq_length, embed_dim = hidden_states.shape From 1f657ac698fcc98ffc52fc6e8351da146c7f443e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 21 Apr 2025 18:36:53 +0000 Subject: [PATCH 9/9] Apply style fixes --- src/diffusers/models/__init__.py | 2 +- .../transformers/transformer_cogview4.py | 53 ++++++++++--------- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 99a2f871c837..218394af2843 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,6 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] + _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] @@ -41,7 +42,6 @@ _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] - _import_structure["auto_model"] = ["AutoModel"] _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 02d395100024..aef368f91ac0 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union, List +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -115,8 +115,8 @@ class CogView4AttnProcessor: Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on query and key vectors, but does not include spatial normalization. - The processor supports passing an attention mask for text tokens. The attention mask should have shape - (batch_size, text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. + The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size, + text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. """ def __init__(self): @@ -192,12 +192,13 @@ def __call__( class CogView4TrainingAttnProcessor: """ - Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. - + Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary + embedding on query and key vectors, but does not include spatial normalization. + This processor differs from CogView4AttnProcessor in several important ways: 1. It supports attention masking with variable sequence lengths for multi-resolution training - 2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is provided + 2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is + provided """ def __init__(self): @@ -212,7 +213,9 @@ def __call__( latent_attn_mask: Optional[torch.Tensor] = None, text_attn_mask: Optional[torch.Tensor] = None, batch_flag: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -224,17 +227,16 @@ def __call__( encoder_hidden_states (`torch.Tensor`): The encoder hidden states for cross-attention. latent_attn_mask (`torch.Tensor`, *optional*): - Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. - If None, full attention is used for all latent tokens. - Note: the shape of latent_attn_mask is (batch_size, num_latent_tokens). + Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full + attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size, + num_latent_tokens). text_attn_mask (`torch.Tensor`, *optional*): - Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. - If None, full attention is used for all text tokens. + Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention + is used for all text tokens. batch_flag (`torch.Tensor`, *optional*): - Values from 0 to n-1 indicating which samples belong to the same batch. - Samples with the same batch_flag are packed together. - Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form batch1, and samples 3-4 form batch2. - If None, no packing is used. + Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same + batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form + batch1, and samples 3-4 form batch2. If None, no packing is used. image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*): The rotary embedding for the image part of the input. Returns: @@ -287,10 +289,9 @@ def __call__( # Calculate packed sequence lengths for each packed batch mixed_seq_length_packed = [ - torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() - for batch_idx in range(packing_batch_size) + torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size) ] - + assert len(mixed_seq_length_packed) == packing_batch_size # Pack sequences by removing padding tokens @@ -317,7 +318,7 @@ def __call__( dtype=dtype, device=device, ) - + # Fill attention mask with block diagonal matrices # This ensures that tokens can only attend to other tokens within the same original sample for idx, mask in enumerate(attn_mask_matrix): @@ -429,7 +430,7 @@ def __call__( # Split by original sample sequence lengths hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist()) assert len(hidden_states_unpack) == batch_size - + # Further split each sample's sequence into text and image parts hidden_states_unpack = [ torch.split(h, [tlen, llen]) @@ -486,7 +487,9 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, attention_mask: Optional[Dict[str, torch.Tensor]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: @@ -677,7 +680,9 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, ) -> Union[torch.Tensor, Transformer2DModelOutput]: if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy()