From 311e211cdf2e19f4b2c6610bfeb6345ee74a10e7 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Mon, 12 Oct 2020 15:59:48 -0700
Subject: [PATCH 1/5] Move RNN to layers.py and make it stateless.

---
 examples/text_generation/shakespeare_rnn.py | 12 +--
 objax/nn/layers.py                          | 59 ++++++++++++++-
 objax/zoo/rnn.py                            | 81 ---------------------
 3 files changed, 61 insertions(+), 91 deletions(-)
 delete mode 100644 objax/zoo/rnn.py

diff --git a/examples/text_generation/shakespeare_rnn.py b/examples/text_generation/shakespeare_rnn.py
index 7c07d1d..541ed0c 100644
--- a/examples/text_generation/shakespeare_rnn.py
+++ b/examples/text_generation/shakespeare_rnn.py
@@ -22,7 +22,7 @@ import objax

 from objax.functional import one_hot
-from objax.zoo.rnn import RNN
+from objax.nn import RNN


 def tokenize(lines, token_type='word'):
@@ -137,25 +137,20 @@ def load_shakespeare(batch_size, num_steps, token_type):
 num_hiddens = 256
 lr = 0.0001
 theta = 1
-print(jax.local_devices())

 train_iter, vocab = load_shakespeare(batch_size, num_steps, 'char')
 vocab_size = len(vocab)
 model = RNN(num_hiddens, vocab_size, vocab_size)
 model_vars = model.vars()
-model.init_state(batch_size)

 # Sample call for forward pass
 X = jn.arange(batch_size * num_steps).reshape(batch_size, num_steps).T
 X_one_hot = one_hot(X, vocab_size)
 Z = model(X_one_hot)
-print("X_one_hot.shape:", X_one_hot.shape)
-print("Z.shape:", Z.shape)


 def predict_char(prefix, num_predicts, model, vocab):
-    model.init_state(batch_size=1)
     outputs = [vocab[prefix[0]]]
     get_input = lambda: one_hot(jn.array([outputs[-1]]).reshape(1, 1), len(vocab))
     for y in prefix[1:]:  # Warmup state with prefix
@@ -172,12 +167,11 @@ def predict_char(prefix, num_predicts, model, vocab):

 opt = objax.optimizer.Adam(model_vars)
 ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999)
-predict = ema.replace_vars(objax.Jit(lambda x: objax.functional.softmax(model(x)), model_vars))


 def loss(x, label):  # sum(label * log(softmax(logit)))
-    logit = model(x)
-    return objax.functional.loss.cross_entropy_logits(logit, label).mean()
+    logits = model(x)
+    return objax.functional.loss.cross_entropy_logits(logits, label).mean()


 gv = objax.GradValues(loss, model.vars())
diff --git a/objax/nn/layers.py b/objax/nn/layers.py
index 404cde5..92caacc 100644
--- a/objax/nn/layers.py
+++ b/objax/nn/layers.py
@@ -14,7 +14,7 @@

 __all__ = ['BatchNorm', 'BatchNorm0D', 'BatchNorm1D', 'BatchNorm2D', 'Conv2D', 'ConvTranspose2D', 'Dropout', 'Linear',
-           'MovingAverage', 'ExponentialMovingAverage', 'Sequential',
+           'MovingAverage', 'ExponentialMovingAverage', 'RNN', 'Sequential',
            'SyncedBatchNorm', 'SyncedBatchNorm0D', 'SyncedBatchNorm1D', 'SyncedBatchNorm2D']

 from typing import Callable, Iterable, Tuple, Optional, Union, List
@@ -331,6 +331,63 @@ def __call__(self, x: JaxArray) -> JaxArray:
         self.avg.value += (self.avg.value - x) * (self.momentum - 1)
         return self.avg.value

+class RNN(Module):
+    """ Recurrent Neural Network (RNN) block."""
+
+    def __init__(self,
+                 nstate: int,
+                 nin: int,
+                 nout: int,
+                 activation: Callable = jn.tanh,
+                 w_init: Callable = kaiming_normal):
+        """Creates an RNN instance.
+
+        Args:
+            nstate: number of hidden units.
+            nin: number of input units.
+            nout: number of output units.
+            activation: activation function for hidden layer.
+            w_init: weight initializer for RNN model weights.
+        """
+        self.num_inputs = nin
+        self.num_outputs = nout
+        self.nstate = nstate
+        self.activation = activation
+
+        # Hidden layer parameters
+        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
+        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
+        self.b_h = TrainVar(jn.zeros(self.nstate))
+
+        self.output_layer = Linear(self.nstate, self.num_outputs)
+
+    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
+        """Forward pass through RNN.
+
+        Args:
+            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``.
+            only_return_final: return only the last output if ``True``, or all outputs otherwise.
+
+        Returns:
+            Output tensor with dimensions ``N * batch_size, vocabulary_size``.
+            N = 1 if ``only_return_final`` is ``True``, and N = ``num_steps`` otherwise.
+        """
+        outputs = []
+        state = jn.zeros((inputs.shape[1], self.nstate))
+        for x in inputs:
+            state = self.activation(
+                jn.dot(x, self.w_xh.value)
+                + jn.dot(state, self.w_hh.value)
+                + self.b_h.value
+            )
+            y = self.output_layer(state)
+            if not only_return_final:
+                outputs.append(y)
+
+        if only_return_final:
+            return y
+        else:
+            return jn.concatenate(outputs, axis=0)

 class Sequential(ModuleList):
     """Executes modules in the order they were passed to the constructor."""
diff --git a/objax/zoo/rnn.py b/objax/zoo/rnn.py
deleted file mode 100644
index ddc3c15..0000000
--- a/objax/zoo/rnn.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Callable
-
-import jax.numpy as jn
-
-from objax import Module
-from objax.nn import Linear
-from objax.nn.init import kaiming_normal
-from objax.typing import JaxArray
-from objax.variable import TrainVar, StateVar
-
-
-class RNN(Module):
-    """ Recurrent Neural Network (RNN) block."""
-
-    def __init__(self,
-                 nstate: int,
-                 nin: int,
-                 nout: int,
-                 activation: Callable = jn.tanh,
-                 w_init: Callable = kaiming_normal):
-        """Creates an RNN instance.
-
-        Args:
-            nstate: number of hidden units.
-            nin: number of input units.
-            nout: number of output units.
-            activation: actication function for hidden layer.
-            w_init: weight initializer for RNN model weights.
-        """
-        self.num_inputs = nin
-        self.num_outputs = nout
-        self.nstate = nstate
-        self.activation = activation
-
-        # Hidden layer parameters
-        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
-        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
-        self.b_h = TrainVar(jn.zeros(self.nstate))
-
-        self.output_layer = Linear(self.nstate, self.num_outputs)
-
-    def init_state(self, batch_size):
-        """Initialize hidden state for input batch of size ``batch_size``."""
-        self.state = StateVar(jn.zeros((batch_size, self.nstate)))
-
-    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
-        """Forward pass through RNN.
-
-        Args:
-            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``.
-            only_return_final: return only the last output if ``True``, or all output otherwise.`
-
-        Returns:
-            Output tensor with dimensions ``num_steps * batch_size, vocabulary_size``.
-        """
-        # Dimensions: num_steps, batch_size, vocab_size
-        outputs = []
-        for x in inputs:
-            self.state.value = self.activation(
-                jn.dot(x, self.w_xh.value)
-                + jn.dot(self.state.value, self.w_hh.value)
-                + self.b_h.value)
-            y = self.output_layer(self.state.value)
-            outputs.append(y)
-        if only_return_final:
-            return outputs[-1]
-        return jn.concatenate(outputs, axis=0)

From 5f6e4afecb814f9e3221892e77709df7fd513ddf Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Mon, 26 Oct 2020 09:08:31 -0700
Subject: [PATCH 2/5] Rename RNN to SimpleRNN, add the ability to pass
 initial_state to __call__, and return the RNN state when __call__ returns.

---
 examples/text_generation/shakespeare_rnn.py | 10 ++++----
 objax/nn/layers.py                          | 27 ++++++++++++++-------
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/examples/text_generation/shakespeare_rnn.py b/examples/text_generation/shakespeare_rnn.py
index 541ed0c..7f8c5c4 100644
--- a/examples/text_generation/shakespeare_rnn.py
+++ b/examples/text_generation/shakespeare_rnn.py
@@ -22,7 +22,7 @@ import objax

 from objax.functional import one_hot
-from objax.nn import RNN
+from objax.nn import SimpleRNN


 def tokenize(lines, token_type='word'):
@@ -141,13 +141,13 @@ def load_shakespeare(batch_size, num_steps, token_type):

 train_iter, vocab = load_shakespeare(batch_size, num_steps, 'char')
 vocab_size = len(vocab)
-model = RNN(num_hiddens, vocab_size, vocab_size)
+model = SimpleRNN(num_hiddens, vocab_size, vocab_size)
 model_vars = model.vars()

 # Sample call for forward pass
 X = jn.arange(batch_size * num_steps).reshape(batch_size, num_steps).T
 X_one_hot = one_hot(X, vocab_size)
-Z = model(X_one_hot)
+Z, _ = model(X_one_hot)


 def predict_char(prefix, num_predicts, model, vocab):
@@ -157,7 +157,7 @@ def predict_char(prefix, num_predicts, model, vocab):
         model(get_input())
         outputs.append(vocab[y])
     for _ in range(num_predicts):  # Predict num_predicts steps
-        Y = model(get_input())
+        Y, _ = model(get_input())
         outc = int(Y.argmax(axis=1).reshape(1))
         outputs.append(outc)
     return ''.join([vocab.idx_to_token[i] for i in outputs])
@@ -170,7 +170,7 @@ def predict_char(prefix, num_predicts, model, vocab):


 def loss(x, label):  # sum(label * log(softmax(logit)))
-    logits = model(x)
+    logits, _ = model(x)
     return objax.functional.loss.cross_entropy_logits(logits, label).mean()


diff --git a/objax/nn/layers.py b/objax/nn/layers.py
index 92caacc..620defb 100644
--- a/objax/nn/layers.py
+++ b/objax/nn/layers.py
@@ -14,7 +14,7 @@

 __all__ = ['BatchNorm', 'BatchNorm0D', 'BatchNorm1D', 'BatchNorm2D', 'Conv2D', 'ConvTranspose2D', 'Dropout', 'Linear',
-           'MovingAverage', 'ExponentialMovingAverage', 'RNN', 'Sequential',
+           'MovingAverage', 'ExponentialMovingAverage', 'SimpleRNN', 'Sequential',
            'SyncedBatchNorm', 'SyncedBatchNorm0D', 'SyncedBatchNorm1D', 'SyncedBatchNorm2D']

 from typing import Callable, Iterable, Tuple, Optional, Union, List
@@ -331,8 +331,8 @@ def __call__(self, x: JaxArray) -> JaxArray:
         self.avg.value += (self.avg.value - x) * (self.momentum - 1)
         return self.avg.value

-class RNN(Module):
-    """ Recurrent Neural Network (RNN) block."""
+class SimpleRNN(Module):
+    """Simple Recurrent Neural Network (RNN) block."""

     def __init__(self,
                  nstate: int,
@@ -361,19 +361,28 @@ def __init__(self,

         self.output_layer = Linear(self.nstate, self.num_outputs)

-    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
+    def __call__(self, inputs: JaxArray,
+                 initial_state: JaxArray = None,
+                 only_return_final = False) -> Tuple[JaxArray, JaxArray]:
         """Forward pass through RNN.

         Args:
-            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``.
+            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, nin``.
             only_return_final: return only the last output if ``True``, or all outputs otherwise.

         Returns:
-            Output tensor with dimensions ``N * batch_size, vocabulary_size``.
+            Tuple with two elements:
+            First, output tensor with dimensions ``N * batch_size, nout``.
             N = 1 if ``only_return_final`` is ``True``, and N = ``num_steps`` otherwise.
+            Second, state with dimensions ``batch_size, nstate``.
         """
         outputs = []
-        state = jn.zeros((inputs.shape[1], self.nstate))
+
+        if initial_state == None:
+            state = jn.zeros((inputs.shape[1], self.nstate))
+        else:
+            state = initial_state
+
         for x in inputs:
             state = self.activation(
                 jn.dot(x, self.w_xh.value)
@@ -385,9 +394,9 @@ def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
                 outputs.append(y)

         if only_return_final:
-            return y
+            return y, state
         else:
-            return jn.concatenate(outputs, axis=0)
+            return jn.concatenate(outputs, axis=0), state

 class Sequential(ModuleList):
     """Executes modules in the order they were passed to the constructor."""

From 853d4f02235ef47127f61d5b86ae8b664bd102eb Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Wed, 28 Oct 2020 11:48:43 -0700
Subject: [PATCH 3/5] Add SimpleRNN test and input validation.

---
 objax/nn/layers.py  |  7 ++++--
 tests/simple_rnn.py | 52 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+), 2 deletions(-)
 create mode 100644 tests/simple_rnn.py

diff --git a/objax/nn/layers.py b/objax/nn/layers.py
index 620defb..c29c534c0 100644
--- a/objax/nn/layers.py
+++ b/objax/nn/layers.py
@@ -349,6 +349,9 @@ def __init__(self,
             activation: activation function for hidden layer.
             w_init: weight initializer for RNN model weights.
         """
+        assert nin > 0, 'nin should be greater than zero'
+        assert nout > 0, 'nout should be greater than zero'
+        assert nstate > 0, 'nstate should be greater than zero'
         self.num_inputs = nin
         self.num_outputs = nout
         self.nstate = nstate
@@ -359,7 +362,7 @@ def __init__(self,
         self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
         self.b_h = TrainVar(jn.zeros(self.nstate))

-        self.output_layer = Linear(self.nstate, self.num_outputs)
+        self.output_layer = Linear(self.nstate, self.num_outputs, w_init = w_init)

     def __call__(self, inputs: JaxArray,
                  initial_state: JaxArray = None,
@@ -379,7 +382,7 @@ def __call__(self, inputs: JaxArray,
         outputs = []

         if initial_state == None:
-            state = jn.zeros((inputs.shape[1], self.nstate))
+            state = jn.zeros((inputs.shape[0], self.nstate))
         else:
             state = initial_state

diff --git a/tests/simple_rnn.py b/tests/simple_rnn.py
new file mode 100644
index 0000000..4c6cd10
--- /dev/null
+++ b/tests/simple_rnn.py
@@ -0,0 +1,52 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittests for SimpleRNN."""
+
+import unittest
+
+import numpy as np
+import jax.numpy as jn
+import tensorflow as tf
+
+import objax
+from objax.nn.layers import SimpleRNN
+from objax.functional import one_hot
+from objax.functional.core.activation import relu
+from objax.nn.init import identity
+from objax.zoo.resnet_v2 import convert_keras_model, load_pretrained_weights_from_keras
+
+class TestSimpleRNN(unittest.TestCase):
+
+    def test_simple_rnn(self):
+        nin = nout = 3
+        batch_size = 1
+        num_hiddens = 1
+        model = SimpleRNN(num_hiddens, nin, nout, activation = relu, w_init = identity)
+
+        X = jn.arange(batch_size)
+        X_one_hot = one_hot(X, nin)
+
+        Z, _ = model(X_one_hot)
+
+        self.assertEqual(Z.shape, (batch_size, nout))
+        self.assertTrue(jn.array_equal(Z, X_one_hot))
+
+        # Test passing in an explicit initial state
+        state = jn.array([[2.]])
+        Z, _ = model(X_one_hot, state)
+        self.assertTrue(jn.array_equal(Z, jn.array([[3., 0., 0.]])))
+
+if __name__ == '__main__':
+    unittest.main()

From f7774f464c168530f49b374b2d0636288606fdca Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Mon, 12 Oct 2020 15:59:48 -0700
Subject: [PATCH 4/5] Fix style issues and remove unused test imports.

---
 objax/nn/layers.py  | 8 +++++---
 tests/simple_rnn.py | 8 +++-----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/objax/nn/layers.py b/objax/nn/layers.py
index c29c534c0..1aa43f4 100644
--- a/objax/nn/layers.py
+++ b/objax/nn/layers.py
@@ -331,6 +331,7 @@ def __call__(self, x: JaxArray) -> JaxArray:
         self.avg.value += (self.avg.value - x) * (self.momentum - 1)
         return self.avg.value

+
 class SimpleRNN(Module):
     """Simple Recurrent Neural Network (RNN) block."""

@@ -362,11 +363,11 @@ def __init__(self,
         self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
         self.b_h = TrainVar(jn.zeros(self.nstate))

-        self.output_layer = Linear(self.nstate, self.num_outputs, w_init = w_init)
+        self.output_layer = Linear(self.nstate, self.num_outputs, w_init=w_init)

     def __call__(self, inputs: JaxArray,
                  initial_state: JaxArray = None,
-                 only_return_final = False) -> Tuple[JaxArray, JaxArray]:
+                 only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
         """Forward pass through RNN.
Args: @@ -381,7 +382,7 @@ def __call__(self, inputs: JaxArray, """ outputs = [] - if initial_state == None: + if initial_state is None: state = jn.zeros((inputs.shape[0], self.nstate)) else: state = initial_state @@ -401,6 +402,7 @@ def __call__(self, inputs: JaxArray, else: return jn.concatenate(outputs, axis=0), state + class Sequential(ModuleList): """Executes modules in the order they were passed to the constructor.""" diff --git a/tests/simple_rnn.py b/tests/simple_rnn.py index 4c6cd10..9ed1269 100644 --- a/tests/simple_rnn.py +++ b/tests/simple_rnn.py @@ -16,16 +16,13 @@ import unittest -import numpy as np import jax.numpy as jn -import tensorflow as tf -import objax from objax.nn.layers import SimpleRNN from objax.functional import one_hot from objax.functional.core.activation import relu from objax.nn.init import identity -from objax.zoo.resnet_v2 import convert_keras_model, load_pretrained_weights_from_keras + class TestSimpleRNN(unittest.TestCase): @@ -33,7 +30,7 @@ def test_simple_rnn(self): nin = nout = 3 batch_size = 1 num_hiddens = 1 - model = SimpleRNN(num_hiddens, nin, nout, activation = relu, w_init = identity) + model = SimpleRNN(num_hiddens, nin, nout, activation=relu, w_init=identity) X = jn.arange(batch_size) X_one_hot = one_hot(X, nin) @@ -48,5 +45,6 @@ def test_simple_rnn(self): Z, _ = model(X_one_hot, state) self.assertTrue(jn.array_equal(Z, jn.array([[3., 0., 0.]]))) + if __name__ == '__main__': unittest.main() From efcb605fe581afb07b70d599f5e05137c9746801 Mon Sep 17 00:00:00 2001 From: Andreas Terzis Date: Wed, 28 Oct 2020 14:57:49 -0700 Subject: [PATCH 5/5] Fix linter errors. --- objax/nn/layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/objax/nn/layers.py b/objax/nn/layers.py index 1aa43f4..e6e6df3 100644 --- a/objax/nn/layers.py +++ b/objax/nn/layers.py @@ -365,8 +365,7 @@ def __init__(self, self.output_layer = Linear(self.nstate, self.num_outputs, w_init=w_init) - def __call__(self, inputs: JaxArray, - initial_state: JaxArray = None, + def __call__(self, inputs: JaxArray, initial_state: JaxArray = None, only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]: """Forward pass through RNN.
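
For reference, a minimal usage sketch of the SimpleRNN API as it stands at the end of this series. The sizes below are illustrative assumptions rather than values taken from the patches, and the initial state is created explicitly by the caller, which is the pattern the stateless design is built around:

    import jax.numpy as jn
    import objax
    from objax.nn import SimpleRNN

    # Hypothetical sizes; any values work as long as the shapes below agree.
    num_steps, batch_size, vocab_size, num_hiddens = 4, 2, 10, 16
    model = SimpleRNN(num_hiddens, vocab_size, vocab_size)

    # One-hot inputs with dimensions (num_steps, batch_size, nin).
    tokens = jn.zeros((num_steps, batch_size), dtype=jn.int32)
    x = objax.functional.one_hot(tokens, vocab_size)

    # The caller owns the hidden state: start from zeros and thread the
    # returned state into the next call instead of storing it in the module.
    state = jn.zeros((batch_size, num_hiddens))
    logits, state = model(x, state)  # logits: (num_steps * batch_size, vocab_size)
    logits, state = model(x, state)  # continues from the previous segment

    # only_return_final=True keeps just the last step's output,
    # e.g. for warming up on a prefix before prediction.
    last, state = model(x, state, only_return_final=True)  # (batch_size, vocab_size)

Because the state is threaded by the caller rather than held in a StateVar, the old init_state(batch_size) call disappears, and the same module instance can serve batched training and single-character prediction simply by passing state arrays of different batch sizes.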