Skip to content

Commit dca431f

Browse files
fix(AINode): [Issue-17301] Add missing IBM deps (basic, normalization, tools) and fix forward() argument mismatch
1 parent 2e4a5e3 commit dca431f

File tree

4 files changed

+636
-16
lines changed

4 files changed

+636
-16
lines changed
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
from typing import Optional, Type
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
7+
8+
def make_attn_mask(query_pad: torch.Tensor, key_pad: torch.Tensor) -> torch.Tensor:
    """
    Combine query/key padding masks into an additive attention mask.

    Args:
        query_pad: (B, Q) bool or 0/1 tensor; truthy marks a padded query slot.
        key_pad: (B, K) bool or 0/1 tensor; truthy marks a padded key slot.

    Returns:
        attn_mask: (B, Q, K) float32 tensor with -inf wherever either the
        query or the key position is padding, and 0.0 elsewhere — the
        additive-mask convention expected by scaled_dot_product_attention.
    """
    # (B, Q, 1) OR (B, 1, K) broadcasts to (B, Q, K): a logit is blocked
    # if *either* side of the pair is a padded position.
    blocked = query_pad.bool().unsqueeze(-1) | key_pad.bool().unsqueeze(-2)

    # Additive mask: 0.0 keeps a logit, -inf removes it from the softmax.
    return torch.where(
        blocked,
        torch.full_like(blocked, float("-inf"), dtype=torch.float32),
        torch.zeros_like(blocked, dtype=torch.float32),
    )
34+
35+
36+
class MLP(nn.Module):
    """
    Plain feed-forward stack:
        Linear -> act -> [Dropout -> (Norm) -> Linear -> act] * (num_hidden_layers - 1)
        -> Dropout -> (Norm) -> Linear -> output_act

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        hidden_dim: width of the hidden layers.
        num_hidden_layers: number of hidden Linear layers (>= 1).
        dropout: dropout probability applied before each hidden/output Linear.
        norm: if True, insert ``norm_layer(hidden_dim)`` before hidden/output Linears.
        activation: hidden activation module; defaults to tanh-approximated GELU.
        output_activation: module applied to the final output; defaults to Identity.
        norm_layer: normalization class used when ``norm`` is True.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        hidden_dim=256,
        num_hidden_layers=1,
        dropout=0,
        norm=False,
        activation=None,
        output_activation=None,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        # Fix the mutable-default pitfall: the previous signature instantiated
        # nn.GELU()/nn.Identity() once at class-definition time, so every
        # default-constructed MLP shared the exact same module objects.
        # Harmless for stateless activations, but wrong in general (an
        # activation with parameters would silently share weights).
        if activation is None:
            activation = nn.GELU(approximate="tanh")
        if output_activation is None:
            output_activation = nn.Identity()

        layers = [nn.Linear(in_dim, hidden_dim), activation]
        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Dropout(dropout))
            layers.append(norm_layer(hidden_dim) if norm else nn.Identity())
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation)
        layers.append(nn.Dropout(dropout))
        layers.append(norm_layer(hidden_dim) if norm else nn.Identity())
        layers.append(nn.Linear(hidden_dim, out_dim))
        layers.append(output_activation)
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stack to ``x`` along its last dimension."""
        return self.layers(x)
68+
69+
70+
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block: ``fc3(fc1(x) * SiLU(fc2(x)))`` with dropout.

    The requested ``hidden_dim`` is scaled down to 2/3 so the gated variant
    keeps roughly the same parameter budget as a plain two-layer MLP of the
    same nominal hidden width.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=384, dropout=0):
        super().__init__()
        width = round(hidden_dim * 2 / 3)
        self.fc1 = nn.Linear(in_dim, width)
        self.fc2 = nn.Linear(in_dim, width)
        self.fc3 = nn.Linear(width, out_dim)
        self.activation = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Gated projection, then the output linear, then dropout."""
        gated = self.fc1(x) * self.activation(self.fc2(x))
        return self.dropout(self.fc3(gated))
83+
84+
85+
class Attention(nn.Module):
86+
def __init__(
87+
self,
88+
dim: int,
89+
num_heads: int = 8,
90+
qkv_bias: bool = False,
91+
qk_norm: bool = False,
92+
proj_bias: bool = True,
93+
attn_drop: float = 0.0,
94+
proj_drop: float = 0.0,
95+
norm_layer: Type[nn.Module] = nn.LayerNorm,
96+
) -> None:
97+
super().__init__()
98+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
99+
self.num_heads = num_heads
100+
self.head_dim = dim // num_heads
101+
self.scale = self.head_dim**-0.5
102+
103+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
104+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
105+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
106+
self.attn_drop = nn.Dropout(attn_drop)
107+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
108+
self.proj_drop = nn.Dropout(proj_drop)
109+
110+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
111+
if x.ndim == 3:
112+
B, N, C = x.shape
113+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
114+
q, k, v = qkv.unbind(0) # (B, num_heads, N, head_dim)
115+
q, k = self.q_norm(q), self.k_norm(k)
116+
x = F.scaled_dot_product_attention(
117+
q,
118+
k,
119+
v,
120+
dropout_p=self.attn_drop.p if self.training else 0.0,
121+
attn_mask=attn_mask,
122+
)
123+
x = x.transpose(1, 2).reshape(B, N, C)
124+
elif x.ndim == 4:
125+
B, M, N, C = x.shape
126+
qkv = self.qkv(x).reshape(B, M, N, 3, self.num_heads, self.head_dim).permute(3, 0, 4, 1, 2, 5)
127+
q, k, v = qkv.unbind(0) # (B, num_heads, M, N, head_dim)
128+
q, k = self.q_norm(q), self.k_norm(k)
129+
# print('q', q.shape, 'k', k.shape, 'v', v.shape, 'attn_mask', attn_mask.shape if attn_mask is not None else "None")
130+
x = F.scaled_dot_product_attention(
131+
q,
132+
k,
133+
v,
134+
dropout_p=self.attn_drop.p if self.training else 0.0,
135+
attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
136+
)
137+
x = x.permute(0, 2, 3, 1, 4).reshape(B, M, N, C)
138+
else:
139+
raise ValueError(f"Unsupported input dimension: {x.ndim}")
140+
x = self.proj(x)
141+
x = self.proj_drop(x)
142+
return x
143+
144+
145+
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention: queries come from ``x``, keys/values from the
    memory ``m``. ``x`` may be ``(B, Nq, q_dim)`` or ``(B, M, Nq, q_dim)``;
    ``m`` is ``(B, Nk, kv_dim)``. In the 4-D case the same memory is attended
    to by every one of the M query groups (k/v broadcast over that axis).
    """

    def __init__(
        self,
        q_dim: int,  # dim of x
        kv_dim: Optional[int] = None,  # dim of m (defaults to q_dim)
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        if kv_dim is None:
            kv_dim = q_dim
        assert q_dim % num_heads == 0, "q_dim must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = q_dim // num_heads

        self.q = nn.Linear(q_dim, q_dim, bias=qkv_bias)
        # k and v are projected into the SAME per-head width as q.
        self.kv = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(q_dim, q_dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor,  # (B, Nq, q_dim) or (B, M, Nq, q_dim)
        m: torch.Tensor,  # (B, Nk, kv_dim)
        attn_mask: Optional[torch.Tensor] = None,  # broadcastable to (B, H, Nq, Nk) or (Nq, Nk)
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Attend from ``x`` to ``m``; output shape matches ``x``."""
        if x.ndim not in (3, 4):
            raise ValueError(f"Unsupported input dimension: {x.ndim}")

        drop_p = self.attn_drop.p if self.training else 0.0
        B, Nk = m.shape[0], m.shape[1]
        # Memory projection is shared by both input ranks: (2, B, H, Nk, Hd).
        kv = self.kv(m).reshape(B, Nk, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        k = self.k_norm(k)

        if x.ndim == 3:
            _, Nq, Cq = x.shape
            q = self.q(x).reshape(B, Nq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)  # (B, H, Nq, Hd)
            out = F.scaled_dot_product_attention(
                self.q_norm(q),
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=drop_p,
                is_causal=is_causal,
            )  # (B, H, Nq, Hd)
            out = out.transpose(1, 2).reshape(B, Nq, Cq)
        else:
            _, M, Nq, Cq = x.shape
            q = self.q(x).reshape(B, M, Nq, self.num_heads, self.head_dim).permute(0, 3, 1, 2, 4)  # (B, H, M, Nq, Hd)
            out = F.scaled_dot_product_attention(
                self.q_norm(q),
                k.unsqueeze(2),  # broadcast memory across the M query groups
                v.unsqueeze(2),
                attn_mask=attn_mask.unsqueeze(1) if attn_mask is not None else None,
                dropout_p=drop_p,
                is_causal=is_causal,
            )  # (B, H, M, Nq, Hd)
            out = out.permute(0, 2, 3, 1, 4).reshape(B, M, Nq, Cq)

        return self.proj_drop(self.proj(out))
217+
218+
219+
class TransformerBlock(nn.Module):
    """
    Standard Transformer encoder block: self-attention + feed-forward.

    With ``norm_first=True`` (pre-norm, default):
        x = x + Attn(LN(x)); x = x + Drop(FF(LN(x)))
    with ``norm_first=False`` (post-norm):
        x = LN(x + Attn(x)); x = LN(x + Drop(FF(x)))
    """

    def __init__(
        self,
        d_model,
        num_heads,
        mlp_ratio=4.0,
        dropout=0.1,
        norm_first=True,
        norm_layer=nn.LayerNorm,
        mlp_type="mlp",
    ):
        super().__init__()
        self.norm_first = norm_first
        self.norm1 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        self.attn = Attention(d_model, num_heads, qkv_bias=True, attn_drop=dropout, proj_drop=dropout)
        self.norm2 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        hidden = int(mlp_ratio * d_model)
        if mlp_type == "swiglu":
            self.mlp = SwiGLU(d_model, d_model, hidden_dim=hidden, dropout=dropout)
        elif mlp_type == "mlp":
            self.mlp = MLP(
                in_dim=d_model,
                out_dim=d_model,
                hidden_dim=hidden,
                dropout=dropout,
            )
        else:
            raise ValueError(f"Unsupported MLP type: {mlp_type}")
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """One encoder step; ``attn_mask`` is forwarded to self-attention."""
        if not self.norm_first:
            # Post-norm: normalize after each residual sum.
            x = self.norm1(x + self.attn(x, attn_mask))
            return self.norm2(x + self.dropout(self.mlp(x)))
        # Pre-norm: normalize the input of each sublayer.
        x = x + self.attn(self.norm1(x), attn_mask)
        return x + self.dropout(self.mlp(self.norm2(x)))
260+
261+
262+
class TransformerBlockCrossAttention(nn.Module):
    """
    Transformer block whose attention sublayer is cross-attention: queries
    come from ``x``, keys/values from the conditioning sequence ``m``.
    Supports the same pre-/post-norm layouts as ``TransformerBlock``.
    """

    def __init__(
        self,
        d_model,
        num_heads,
        d_cond=None,
        mlp_ratio=4.0,
        dropout=0.1,
        norm_first=True,
        norm_layer=nn.LayerNorm,
        mlp_type="mlp",
    ):
        super().__init__()
        if d_cond is None:
            # Conditioning width defaults to the model width.
            d_cond = d_model
        self.norm_first = norm_first
        self.norm1 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        self.attn = CrossAttention(
            d_model,
            d_cond,
            num_heads,
            qkv_bias=True,
            attn_drop=dropout,
            proj_drop=dropout,
        )
        self.norm2 = norm_layer(d_model, elementwise_affine=True, eps=1e-6)
        hidden = int(mlp_ratio * d_model)
        if mlp_type == "swiglu":
            self.mlp = SwiGLU(d_model, d_model, hidden_dim=hidden, dropout=dropout)
        elif mlp_type == "mlp":
            self.mlp = MLP(
                in_dim=d_model,
                out_dim=d_model,
                hidden_dim=hidden,
                dropout=dropout,
            )
        else:
            raise ValueError(f"Unsupported MLP type: {mlp_type}")
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, m, attn_mask=None):
        """Cross-attend ``x`` to memory ``m``, then apply the feed-forward."""
        if not self.norm_first:
            # Post-norm: normalize after each residual sum.
            x = self.norm1(x + self.attn(x, m, attn_mask))
            return self.norm2(x + self.dropout(self.mlp(x)))
        # Pre-norm: normalize the input of each sublayer.
        x = x + self.attn(self.norm1(x), m, attn_mask)
        return x + self.dropout(self.mlp(self.norm2(x)))

0 commit comments

Comments
 (0)