-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoder.py
More file actions
38 lines (30 loc) · 1.37 KB
/
decoder.py
File metadata and controls
38 lines (30 loc) · 1.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
from bahdanau import Bahdanau
# Initialiser for the GRU's recurrent (hidden-to-hidden) weight matrix;
# passed to tf.keras.layers.GRU below.
RECURRENT_INITIALISER = 'glorot_uniform'
class Decoder(tf.keras.Model):
    """Attention-based GRU decoder for a sequence-to-sequence model.

    Each call decodes one timestep: Bahdanau attention is computed over
    the encoder output, the resulting context vector is concatenated with
    the embedded input token, and the combination is run through a GRU
    followed by a dense projection onto the vocabulary.
    """

    # Historical default embedding size; kept for backward compatibility
    # with any external code that reads Decoder.dimension.
    dimension = 256

    def __init__(self, vocabulary, dimension, decoding, batchSize):
        """Build the decoder layers.

        Args:
            vocabulary: size of the target vocabulary (embedding rows and
                output logits).
            dimension: embedding dimension for target tokens.
            decoding: number of GRU units (also the attention size).
            batchSize: batch size, stored for use by callers.
        """
        super(Decoder, self).__init__()
        self.decoding = decoding
        # GRU - Gated Recurrent Unit is an RNN
        self.gru = tf.keras.layers.GRU(self.decoding,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer=RECURRENT_INITIALISER)
        self.fc = tf.keras.layers.Dense(vocabulary)
        self.batchSize = batchSize
        # FIX: the embedding size was hard-coded to 256, silently ignoring
        # the `dimension` argument; it now honours the caller's value
        # (identical behaviour for callers that passed 256).
        self.embedding = tf.keras.layers.Embedding(vocabulary, dimension)
        # Bahdanau provides attention for decoder
        self.attention = Bahdanau(self.decoding)

    def call(self, x, hidden, enc_output):
        """Decode one timestep.

        Args:
            x: current target token ids, shape (batch, 1) — TODO confirm
                against caller; attention/concat below assume a single step.
            hidden: previous decoder hidden state (query for attention).
            enc_output: encoder outputs attended over.

        Returns:
            Tuple of (vocabulary logits flattened to (batch, vocabulary),
            new GRU state, attention weights).
        """
        decodeVector, bahdanauWeights = self.attention(hidden, enc_output)
        x = self.embedding(x)
        # Prepend the attention context vector to the embedded token along
        # the feature axis.
        x = tf.concat([tf.expand_dims(decodeVector, 1), x], axis=-1)
        output, stateVector = self.gru(x)
        # Collapse (batch, time, units) -> (batch * time, units) before the
        # dense projection.
        shape = output.shape[2]
        output = tf.reshape(output, (-1, shape))
        x = self.fc(output)
        return x, stateVector, bahdanauWeights
# Inspired and modified from TensorFlow example
# TensorFlow Addons Networks : Sequence-to-Sequence NMT with Attention Mechanism
# 2021