-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathclip.py
More file actions
156 lines (135 loc) · 7.02 KB
/
clip.py
File metadata and controls
156 lines (135 loc) · 7.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from typing import Any, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.clip.modeling_clip import CLIPSdpaAttention
import numpy as np
import math
# Define a recursive monkey patching function
def replace_attn(module):
module_output = module
if isinstance(module, CLIPSdpaAttention):
module_output = CustomAttention(module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.out_proj)
for name, child in module.named_children():
module_output.add_module(name, replace_attn(child))
del module
return module_output
class DummyLayer(nn.Module):
def forward(self, x):
return x
class CustomAttention(nn.Module):
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`CLIPAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
def __init__(self, config, q, k, v, o):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.k_proj.load_state_dict(k.state_dict())
self.v_proj.load_state_dict(v.state_dict())
self.q_proj.load_state_dict(q.state_dict())
self.out_proj.load_state_dict(o.state_dict())
# Dummy layers for hooking
self.dummy_before_softmax = DummyLayer()
self.dummy_after_softmax = DummyLayer()
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# Adapted from CLIPAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"CLIPModel is using CLIPSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
"the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
'be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
if attention_mask is not None and causal_attention_mask is not None:
attn_mask = attention_mask + causal_attention_mask
elif causal_attention_mask is not None:
attn_mask = causal_attention_mask
else:
attn_mask = attention_mask
bsz, tgt_len, embed_dim = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
# if not is_torch_greater_or_equal_than_2_2 and query_states.device.type == "cuda" and attn_mask is not None:
# query_states = query_states.contiguous()
# key_states = key_states.contiguous()
# value_states = value_states.contiguous()
# CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
attn_output = self.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attn_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=self.scale,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None
def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias = attn_mask + attn_bias
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = self.dummy_before_softmax(attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = self.dummy_after_softmax(attn_weight)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value