@@ -56,16 +56,16 @@ def get_config(self):
     def build(self, inputs_shape):
         # Transformation for linearly projecting the queries, keys, and values.
         self.q_transformation = self._get_weights(
-            "q_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "q_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )
         self.v_transformation = self._get_weights(
-            "v_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "v_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )
         self.k_transformation = self._get_weights(
-            "k_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "k_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )
         self.out_transformation = self._get_weights(
-            "out_project", shape=(self.hidden_size, self.hidden_size), init=tf.keras.initializers.get('glorot_uniform')
+            "out_project", shape=(self.hidden_size, self.hidden_size), init=tf.initializers.get('glorot_uniform')
         )

     def split_heads(self, x):
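The only change in this hunk is the initializer lookup: in TF 2.x, tf.initializers is an alias of tf.keras.initializers, so tf.initializers.get('glorot_uniform') resolves to the same Glorot-uniform initializer. A minimal sketch to check that the shorter spelling works (assumes TensorFlow 2.x is installed):

import tensorflow as tf

# tf.initializers aliases tf.keras.initializers in TF 2.x, so both lookups
# should return a GlorotUniform instance.
init_short = tf.initializers.get('glorot_uniform')
init_keras = tf.keras.initializers.get('glorot_uniform')
print(type(init_short).__name__, type(init_keras).__name__)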
@@ -108,7 +108,7 @@ def combine_heads(self, x):
         x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
         return tf.reshape(x, [batch_size, length, self.hidden_size])

-    def forward(self, inputs, mask, cache=None):
+    def forward(self, x, y, mask, cache=None):
         """Apply attention mechanism to x and y.

         Args:
@@ -130,14 +130,8 @@ def forward(self, inputs, mask, cache=None):
         # multiple heads. Multi-head attention uses multiple queries, keys, and
         # values rather than regular attention (which uses a single q, k, v).

-        if (len(inputs) == 2):
-            q = inputs[0]
-            k = v = inputs[1]
-
-        if (len(inputs) == 3):
-            q = inputs[0]
-            k = inputs[1]
-            v = inputs[2]
+        v = k = y
+        q = x

         q = tf.tensordot(q, self.q_transformation, axes=[[2], [0]])
         k = tf.tensordot(k, self.k_transformation, axes=[[2], [0]])
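With the list argument gone, the query source and the key/value source are passed explicitly, so encoder-decoder attention reads the same as self-attention. A hypothetical call site (attn, decoder_states, encoder_outputs, and enc_dec_mask are placeholder names, not from this repository; the layer now returns a tuple, per the return change further down):

# Queries come from the decoder; keys and values come from the encoder.
context, attn_weights = attn.forward(
    x=decoder_states,    # [batch, length_q, hidden_size]
    y=encoder_outputs,   # [batch, length_k, hidden_size]
    mask=enc_dec_mask,   # additive mask, broadcastable to [batch, num_heads, length_q, length_k]
    cache=None,
)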
@@ -166,6 +160,7 @@ def forward(self, inputs, mask, cache=None):
         logits = tf.matmul(q, k, transpose_b=True)  # (Batch, num_head, length_q, length_k)
         logits += mask
         weights = tf.nn.softmax(logits, name="attention_weights")  # (Batch, num_head, length_q, length_k)
+        weights_store = weights
         if self.is_train:
             weights = tf.nn.dropout(weights, rate=self.attention_dropout)

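Note that the weights are captured before the training-time dropout, so what gets returned is the full softmax distribution rather than the randomly zeroed copy used to weight the values, presumably so the attention map can be inspected or visualized.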
@@ -176,11 +171,11 @@ def forward(self, inputs, mask, cache=None):

         # Run the combined outputs through another linear projection layer.
         attention_output = tf.tensordot(attention_output, self.out_transformation, axes=[[2], [0]])
-        return attention_output
+        return attention_output, weights_store


 class SelfAttentionLayer(MultiHeadAttentionLayer):
     """Multiheaded self-attention layer."""

     def forward(self, inputs, mask, cache=None):
-        return super(SelfAttentionLayer, self).forward(inputs=[inputs, inputs], mask=mask, cache=cache)
+        return super(SelfAttentionLayer, self).forward(x=inputs, y=inputs, mask=mask, cache=cache)
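Both layers now return the attention weights alongside the projected output. A hypothetical self-attention usage sketch (the constructor arguments and tensor names below are placeholders, not taken from this diff):

self_attn = SelfAttentionLayer(num_heads=8, hidden_size=512, attention_dropout=0.1)
outputs, attn_weights = self_attn.forward(inputs=hidden_states, mask=padding_mask)
# outputs:      [batch, length, hidden_size]
# attn_weights: [batch, num_heads, length_q, length_k], pre-dropout softmax scores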