import logging
import torch
from torch import Tensor
import torch_npu
import torch.distributed as dist
import math
import os
from typing import Any

from yunchang import LongContextAttention
try:
    from yunchang.kernels import AttnType
except ImportError:
    raise ImportError("Please install yunchang 0.6.0 or later")

try:
    from mindiesd.layers.flash_attn.attention_forward import attention_forward
    MINDIE_SD_ATTENTION_FORWARD_AVAILABLE = True
except ImportError:
    MINDIE_SD_ATTENTION_FORWARD_AVAILABLE = False
    logging.info("MindIE-SD attention_forward is not available, falling back to torch_npu.npu_fusion_attention")
from ..distributed.parallel_mgr import get_sp_group
from ..distributed.comm import all_to_all_4D
logger = logging.getLogger(__name__)
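
# INT32 max, passed as pre_tockens/next_tockens to npu_fusion_attention so the
# fused kernel applies no pre/next token limit (full attention window).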
MAX_TOKEN = 2147483647


class xFuserLongContextAttention(LongContextAttention):
    ring_impl_type_supported_kv_cache = ["basic"]

    def __init__(
        self,
        args: Any = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        ring_impl_type: str = "basic",
        use_pack_qkv: bool = False,
        use_kv_cache: bool = False,
        attn_type: AttnType = AttnType.FA,
    ) -> None:
        """
        Arguments:
            scatter_idx: int = 2, the scatter dimension index for the Ulysses all-to-all
            gather_idx: int = 1, the gather dimension index for the Ulysses all-to-all
            ring_impl_type: str = "basic", the ring implementation type; currently only "basic" is supported
            use_pack_qkv: bool = False, whether the input uses packed qkv
            use_kv_cache: bool = False, whether to use a kv cache in the attention layer (applied in PipeFusion)
        """
        super().__init__(
            scatter_idx=scatter_idx,
            gather_idx=gather_idx,
            ring_impl_type=ring_impl_type,
            use_pack_qkv=use_pack_qkv,
            attn_type=attn_type,
        )
        self.use_kv_cache = use_kv_cache
        if (
            use_kv_cache
            and ring_impl_type not in self.ring_impl_type_supported_kv_cache
        ):
            raise RuntimeError(
                f"ring_impl_type: {ring_impl_type} does not support SP kv cache."
            )
        self.world_size = dist.get_world_size()
        self.args = args
        self.video_size = ['480*832', '832*480', '480*720', '720*480']
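        # ALGO selects the MindIE-SD flash-attention kernel used in forward():
        # 0 -> "fused_attn_score", 1 -> "ascend_laser_attention".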
        self.algo = int(os.getenv('ALGO', 0))
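        # Process groups for the two sequence-parallel dimensions:
        # Ulysses (all-to-all over heads) and ring (all-gather over the k/v sequence).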
        self.ulysses_pg = get_sp_group().ulysses_group
        self.ring_pg = get_sp_group().ring_group

    def forward(
        self,
        attn,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        *,
        joint_tensor_query=None,
        joint_tensor_key=None,
        joint_tensor_value=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
        return_attn_probs=False,
        joint_strategy="none",
        scale=None
    ) -> Tensor:
        """forward

        Arguments:
            attn (Attention): the attention module
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            joint_tensor_query: Tensor = None, a tensor replicated among processes, appended to the front or rear of query depending on joint_strategy
            joint_tensor_key: Tensor = None, a tensor replicated among processes, appended to the front or rear of key depending on joint_strategy
            joint_tensor_value: Tensor = None, a tensor replicated among processes, appended to the front or rear of value depending on joint_strategy
            *args: the same args as flash_attn_interface
            joint_strategy: str = "none", the joint strategy for joint attention; currently only "front" and "rear" are supported
        Returns:
            * output (Tensor): context output
        """
        query_layer = all_to_all_4D(input_=query, scatter_idx=2, gather_idx=1, group=self.ulysses_pg)
        key_layer = all_to_all_4D(input_=key, scatter_idx=2, gather_idx=1, group=self.ulysses_pg)
        value_layer = all_to_all_4D(input_=value, scatter_idx=2, gather_idx=1, group=self.ulysses_pg)
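        # With a ring group, all-gather key/value along the sequence dimension so
        # every rank attends over the full key/value sequence.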
        if get_sp_group().ring_world_size > 1:
            ring_size = get_sp_group().ring_world_size
            b, s, n, d = key_layer.shape
            k_full = torch.empty([ring_size, b, s, n, d], dtype=query_layer.dtype, device=query_layer.device)
            dist.all_gather_into_tensor(k_full, key_layer, group=self.ring_pg)
            key_layer = k_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d)
            v_full = torch.empty([ring_size, b, s, n, d], dtype=query_layer.dtype, device=query_layer.device)
            dist.all_gather_into_tensor(v_full, value_layer, group=self.ring_pg)
            value_layer = v_full.permute(1, 0, 2, 3, 4).reshape(b, -1, n, d)
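        # Kernel dispatch: fall back to torch_npu.npu_fusion_attention (BNSD layout)
        # when MindIE-SD attention_forward is unavailable; otherwise select the
        # MindIE-SD kernel according to self.algo.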
        if not MINDIE_SD_ATTENTION_FORWARD_AVAILABLE:
            head_num = query_layer.shape[-2]
            head_dim = query_layer.shape[-1]
            scale = head_dim ** -0.5
            query_layer = query_layer.transpose(1, 2)
            key_layer = key_layer.transpose(1, 2)
            value_layer = value_layer.transpose(1, 2)
            out = torch_npu.npu_fusion_attention(
                query_layer,
                key_layer,
                value_layer,
                atten_mask=None,
                input_layout="BNSD",
                scale=scale,
                pre_tockens=MAX_TOKEN,
                next_tockens=MAX_TOKEN,
                head_num=head_num)[0]
            out = out.transpose(1, 2)
        elif self.algo == 0:
            out = attention_forward(query_layer, key_layer, value_layer,
                                    opt_mode="manual", op_type="fused_attn_score", layout="BNSD")
        elif self.algo == 1:
            out = attention_forward(query_layer, key_layer, value_layer,
                                    opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD")
        else:
            raise ValueError(f"flash attention algorithm selection only supports 0 or 1, but got {self.algo}")
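        # Some kernels return a tuple (context, aux, aux); keep only the attention output.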
        if isinstance(out, tuple):
            context_layer, _, _ = out
        else:
            context_layer = out
        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = all_to_all_4D(input_=context_layer, scatter_idx=1, gather_idx=2, group=self.ulysses_pg)
        return output
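

# Minimal usage sketch (hypothetical; assumes torch.distributed and the
# sequence-parallel groups in parallel_mgr are already initialised, and that
# query/key/value are laid out as (bs, seq_len/N, head_cnt, head_size)):
#
#   sp_attn = xFuserLongContextAttention(attn_type=AttnType.FA)
#   out = sp_attn(attn_module, query, key, value)  # same layout as the inputs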