Skip to content

Commit 15c8794

Browse files
authored
Merge pull request #1354 from calmdown539/dev-postgresql
Add the implementation of TransformerEncoderLayer
2 parents 8dcdcba + c9a06b0 commit 15c8794

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

  • examples/singa_peft/examples/model

examples/singa_peft/examples/model/trans.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,46 @@ def get_posi_angle_vec(position):
329329
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
330330
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
331331
return tensor.Tensor(data=sinusoid_table, requires_grad=False)
332+
333+
334+
class TransformerEncoderLayer(layer.Layer):
    """A single Transformer encoder block: multi-head self-attention
    followed by a position-wise feed-forward sub-layer.

    Args:
        d_model (int): model/embedding dimension. Default 512.
        n_head (int): number of attention heads. Default 8.
        dim_feedforward (int): hidden size of the feed-forward
            sub-layer. Default 2048.
    """

    def __init__(self, d_model=512, n_head=8, dim_feedforward=2048):
        super(TransformerEncoderLayer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward
        # Self-attention sub-layer: Q, K and V all come from the encoder input.
        self.enc_self_attn = MultiHeadAttention(d_model, n_head)
        # Position-wise feed-forward sub-layer (no bias terms).
        self.pos_ffn = PoswiseFeedForwardNet(d_model=d_model, dim_feedforward=dim_feedforward, bias=False)

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Run the encoder block.

        Args:
            enc_inputs: [batch_size, src_len, d_model]
            enc_self_attn_mask: [batch_size, src_len, src_len]

        Returns:
            enc_outputs: [batch_size, src_len, d_model]
            attn: [batch_size, n_heads, src_len, src_len]
        """
        # Feed the same tensor as query, key and value (self-attention).
        attended, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        return self.pos_ffn(attended), attn
357+
358+
359+
def matmul4d(x1, x2):
    """Batched matrix multiplication for two 4-D tensors.

    Emulates a 4-D matmul with the 2-D ``autograd.matmul`` primitive:
    iterates over the first two axes, multiplies each 2-D slice pair,
    then reassembles the results with ``unsqueeze``/``cat``.

    Args:
        x1: 4-D tensor, leading axes interpreted as (batch, heads).
        x2: 4-D tensor with matching leading axes.

    Returns:
        4-D tensor whose [b, h] slice is ``x1[b, h] @ x2[b, h]``.
    """
    n_batch, n_head = x1.shape[0], x1.shape[1]
    per_batch = []
    for bi in range(n_batch):
        # Drop the leading batch axis to get 3-D (head, m, k) views.
        a3 = autograd.squeeze(x1[bi])
        b3 = autograd.squeeze(x2[bi])
        # 2-D matmul per head, each re-expanded to rank 3 for concatenation.
        per_head = [
            autograd.unsqueeze(
                autograd.matmul(autograd.squeeze(a3[hi]), autograd.squeeze(b3[hi])),
                axis=[0],
            )
            for hi in range(n_head)
        ]
        # Stack heads, then restore the batch axis.
        stacked = autograd.unsqueeze(autograd.cat(per_head, axis=0), axis=[0])
        per_batch.append(stacked)
    return autograd.cat(per_batch, axis=0)

0 commit comments

Comments
 (0)