
Commit 8911654

Lingjun Liu authored and committed
add attention visualisation
1 parent 005ab91 commit 8911654

File tree

3 files changed: +317 -282 lines changed


tensorlayer/models/transformer/attention_layer.py

Lines changed: 10 additions & 15 deletions
@@ -56,16 +56,16 @@ def get_config(self):
     def build(self, inputs_shape):
         # Transformation for linearly projecting the queries, keys, and values.
         self.q_transformation = self._get_weights(
-            "q_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "q_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )
         self.v_transformation = self._get_weights(
-            "v_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "v_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )
         self.k_transformation = self._get_weights(
-            "k_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "k_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )
         self.out_transformation = self._get_weights(
-            "out_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "out_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )

     def split_heads(self, x):
@@ -108,7 +108,7 @@ def combine_heads(self, x):
         x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
         return tf.reshape(x, [batch_size, length, self.hidden_size])

-    def forward(self, inputs, mask, cache=None):
+    def forward(self, x, y, mask, cache=None):
         """Apply attention mechanism to x and y.

         Args:
@@ -130,14 +130,8 @@ def forward(self, inputs, mask, cache=None):
         # multiple heads. Multi-head attention uses multiple queries, keys, and
         # values rather than regular attention (which uses a single q, k, v).

-        if (len(inputs) == 2):
-            q = inputs[0]
-            k = v = inputs[1]
-
-        if (len(inputs) == 3):
-            q = inputs[0]
-            k = inputs[1]
-            v = inputs[2]
+        v = k = y
+        q = x

         q = tf.tensordot(q, self.q_transformation, axes=[[2], [0]])
         k = tf.tensordot(k, self.k_transformation, axes=[[2], [0]])
@@ -166,6 +160,7 @@ def forward(self, inputs, mask, cache=None):
         logits = tf.matmul(q, k, transpose_b=True)  # (Batch, num_head, length_q, length_k)
         logits += mask
         weights = tf.nn.softmax(logits, name="attention_weights")  # (Batch, num_head, length_q, length_k)
+        weights_store = weights
         if self.is_train:
             weights = tf.nn.dropout(weights, rate=self.attention_dropout)

@@ -176,11 +171,11 @@ def forward(self, inputs, mask, cache=None):

         # Run the combined outputs through another linear projection layer.
         attention_output = tf.tensordot(attention_output, self.out_transformation, axes=[[2], [0]])
-        return attention_output
+        return attention_output, weights_store


 class SelfAttentionLayer(MultiHeadAttentionLayer):
     """Multiheaded self-attention layer."""

     def forward(self, inputs, mask, cache=None):
-        return super(SelfAttentionLayer, self).forward(inputs=[inputs, inputs], mask=mask, cache=cache)
+        return super(SelfAttentionLayer, self).forward(x=inputs, y=inputs, mask=mask, cache=cache)
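
The new second return value is what enables the attention visualisation named in the commit message. A minimal plotting sketch, assuming attn is a built MultiHeadAttentionLayer and that x, y, and mask already have the shapes forward() expects; these variable names and the matplotlib usage are illustrative, not part of the commit:

    import matplotlib.pyplot as plt

    # Sketch only: attn, x, y, and mask are assumed to exist with the shapes
    # documented in forward(); weights has shape (Batch, num_head, length_q, length_k).
    output, weights = attn.forward(x=x, y=y, mask=mask)
    w = weights[0].numpy()                      # attention maps for the first example
    num_heads = w.shape[0]
    plt.figure(figsize=(3 * num_heads, 3))
    for head in range(num_heads):
        plt.subplot(1, num_heads, head + 1)     # one heat map per attention head
        plt.imshow(w[head], cmap="viridis")
        plt.title("head %d" % head)
        plt.xlabel("key position")
        plt.ylabel("query position")
    plt.tight_layout()
    plt.show()

Note that weights_store is captured before the training-time dropout branch, so the returned maps are the raw softmax weights rather than the dropped-out ones.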

tensorlayer/models/transformer/feedforward_layer.py

Lines changed: 2 additions & 2 deletions
@@ -38,11 +38,11 @@ def __init__(self, hidden_size, filter_size, keep_prob):
         self.filter_size = filter_size
         self.relu_dropout = 1 - keep_prob
         self.filter_dense_layer = tl.layers.Dense(
-            self.filter_size, in_channels=self.hidden_size, W_init=tf.keras.initializers.get('glorot_uniform'),
+            self.filter_size, in_channels=self.hidden_size, W_init=tf.initializers.get('glorot_uniform'),
             name="input_layer"
         )
         self.output_dense_layer = tl.layers.Dense(
-            self.hidden_size, in_channels=self.filter_size, W_init=tf.keras.initializers.get('glorot_uniform'),
+            self.hidden_size, in_channels=self.filter_size, W_init=tf.initializers.get('glorot_uniform'),
             name="output_layer"
         )
         self.build(None)
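
The initializer change here (and in attention_layer.py) only swaps the lookup path: in TF 2.x, tf.initializers is an alias of tf.keras.initializers, so both spellings should resolve to the same GlorotUniform initializer. A quick check, assuming a TF 2.x install:

    import tensorflow as tf

    # Both lookups should go through the Keras initializer registry in TF 2.x.
    a = tf.initializers.get('glorot_uniform')
    b = tf.keras.initializers.get('glorot_uniform')
    print(type(a).__name__, type(b).__name__)   # expected: GlorotUniform GlorotUniform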
