Move RNN to layers.py and make it stateless. #97
base: master
Changes from all commits: 311e211, 5f6e4af, 853d4f0, f7774f4, efcb605
@@ -22,7 +22,11 @@
 import objax
 from objax.functional import one_hot
-from objax.zoo.rnn import RNN
+<<<<<<< HEAD:examples/text_generation/shakespeare_rnn.py
+from objax.nn import SimpleRNN
+=======
+from objax.nn import RNN
+>>>>>>> 2c04d4e (Move RNN to layers.py and make it stateless.):examples/rnn/shakespeare.py


 def tokenize(lines, token_type='word'):
|
@@ -137,32 +141,27 @@ def load_shakespeare(batch_size, num_steps, token_type):
 num_hiddens = 256
 lr = 0.0001
 theta = 1
 print(jax.local_devices())

 train_iter, vocab = load_shakespeare(batch_size, num_steps, 'char')
 vocab_size = len(vocab)

-model = RNN(num_hiddens, vocab_size, vocab_size)
+model = SimpleRNN(num_hiddens, vocab_size, vocab_size)
 model_vars = model.vars()
-model.init_state(batch_size)

 # Sample call for forward pass
 X = jn.arange(batch_size * num_steps).reshape(batch_size, num_steps).T
 X_one_hot = one_hot(X, vocab_size)
-Z = model(X_one_hot)
-print("X_one_hot.shape:", X_one_hot.shape)
-print("Z.shape:", Z.shape)
+Z, _ = model(X_one_hot)


 def predict_char(prefix, num_predicts, model, vocab):
-    model.init_state(batch_size=1)
     outputs = [vocab[prefix[0]]]
     get_input = lambda: one_hot(jn.array([outputs[-1]]).reshape(1, 1), len(vocab))
     for y in prefix[1:]:  # Warmup state with prefix
         model(get_input())
         outputs.append(vocab[y])
     for _ in range(num_predicts):  # Predict num_predicts steps
-        Y = model(get_input())
+        Y, _ = model(get_input())
         outc = int(Y.argmax(axis=1).reshape(1))
         outputs.append(outc)
     return ''.join([vocab.idx_to_token[i] for i in outputs])
|
|
@@ -172,12 +171,11 @@ def predict_char(prefix, num_predicts, model, vocab):
 opt = objax.optimizer.Adam(model_vars)
 ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999)
 predict = ema.replace_vars(objax.Jit(lambda x: objax.functional.softmax(model(x)), model_vars))


 def loss(x, label):  # sum(label * log(softmax(logit)))
-    logit = model(x)
-    return objax.functional.loss.cross_entropy_logits(logit, label).mean()
+    logits, _ = model(x)
+    return objax.functional.loss.cross_entropy_logits(logits, label).mean()


 gv = objax.GradValues(loss, model.vars())
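For context, gv is typically wired into a jitted training step together with the opt and ema objects defined above; a rough sketch following the standard objax pattern (an assumption on my part, the PR may wire this differently):

def train_op(x, label):
    grads, values = gv(x, label)  # gradients and loss value
    opt(lr, grads)                # Adam update
    ema()                         # refresh the moving average
    return values


train_op = objax.Jit(train_op, gv.vars() + opt.vars() + ema.vars())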
objax/nn/layers.py

@@ -14,7 +14,7 @@
 __all__ = ['BatchNorm', 'BatchNorm0D', 'BatchNorm1D', 'BatchNorm2D',
            'Conv2D', 'ConvTranspose2D', 'Dropout', 'Linear',
-           'MovingAverage', 'ExponentialMovingAverage', 'Sequential',
+           'MovingAverage', 'ExponentialMovingAverage', 'SimpleRNN', 'Sequential',
            'SyncedBatchNorm', 'SyncedBatchNorm0D', 'SyncedBatchNorm1D', 'SyncedBatchNorm2D']

 from typing import Callable, Iterable, Tuple, Optional, Union, List
|
|
@@ -332,6 +332,76 @@ def __call__(self, x: JaxArray) -> JaxArray:
         return self.avg.value


+class SimpleRNN(Module):
+    """Simple Recurrent Neural Network (RNN) block."""
+
+    def __init__(self,
+                 nstate: int,
+                 nin: int,
+                 nout: int,
+                 activation: Callable = jn.tanh,
+                 w_init: Callable = kaiming_normal):
+        """Creates an RNN instance.
+
+        Args:
+            nstate: number of hidden units.
+            nin: number of input units.
+            nout: number of output units.
+            activation: activation function for hidden layer.
+            w_init: weight initializer for RNN model weights.
+        """
+        assert nin > 0, 'nin should be larger than zero'
+        assert nout > 0, 'nout should be larger than zero'
+        assert nstate > 0, 'nstate should be larger than zero'
+        self.num_inputs = nin
+        self.num_outputs = nout
+        self.nstate = nstate
+        self.activation = activation
+
+        # Hidden layer parameters
+        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
+        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
+        self.b_h = TrainVar(jn.zeros(self.nstate))
+
+        self.output_layer = Linear(self.nstate, self.num_outputs, w_init=w_init)
+
+    def __call__(self, inputs: JaxArray, initial_state: JaxArray = None,
+                 only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
Contributor (on lines +368 to +369): One argument per line if they don't all fit on one line.
| """Forward pass through RNN. | ||
|
|
||
| Args: | ||
| inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, nout``. | ||
| only_return_final: return only the last output if ``True``, or all output otherwise.` | ||
|
|
||
| Returns: | ||
| Tuple with two elements: | ||
| First, output tensor with dimensions ``N * batch_size, nout``. | ||
| N = 1 if ``only_return_final`` is ``True`` and ``num_steps`` otherwise. | ||
| Second, state with dimensions ``batch_size, nstate``. | ||
| """ | ||
+        outputs = []
+
+        if initial_state is None:
+            state = jn.zeros((inputs.shape[0], self.nstate))
+        else:
+            state = initial_state
+
+        for x in inputs:
+            state = self.activation(
+                jn.dot(x, self.w_xh.value)
+                + jn.dot(state, self.w_hh.value)
Comment on lines +391 to +392: Not sure if we shall use two weight matrices or one that acts on the concatenated input and state.
Comment: Typically it's more efficient to act on one concatenated matrix.
Contributor: Another nit, use …
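A minimal sketch of the fused alternative discussed here (hypothetical, not part of the diff): a single weight of shape (nin + nstate, nstate) acting on the concatenated input and state computes the same update with one matmul per step.

import jax.numpy as jn
from objax import TrainVar
from objax.nn.init import kaiming_normal

nin, nstate = 3, 4  # illustrative sizes
w_h = TrainVar(kaiming_normal((nin + nstate, nstate)))  # replaces w_xh and w_hh
b_h = TrainVar(jn.zeros(nstate))

def step(x, state):
    xh = jn.concatenate([x, state], axis=1)  # (batch_size, nin + nstate)
    return jn.tanh(jn.dot(xh, w_h.value) + b_h.value)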
+                + self.b_h.value
+            )
+            y = self.output_layer(state)
Comment: Do we need the output layer?
Contributor (author): I opted for having an …
Comment: Question why: this is something the user can do themselves after, right? So is there any purpose to add an output_layer?
Contributor: I would drop the output layer, that's forcing a decision on the user about what type of output they'd want.
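What the suggested design would look like from the caller's side (a hypothetical sketch: it assumes a version of the layer that returns raw hidden states, with the user attaching their own head):

import objax

nstate, nout = 256, 65  # illustrative sizes
head = objax.nn.Linear(nstate, nout)  # user-chosen head: Linear, Dropout + Linear, etc.

def project(states):
    # states: (num_steps * batch_size, nstate) hidden states from the RNN
    return head(states)  # logits: (num_steps * batch_size, nout)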
+            if not only_return_final:
+                outputs.append(y)
+
+        if only_return_final:
+            return y, state
+        else:

Contributor: No need for else after the return.
+            return jn.concatenate(outputs, axis=0), state

Contributor: Should it be …?
+
+
 class Sequential(ModuleList):
     """Executes modules in the order they were passed to the constructor."""
This file was deleted.
@@ -0,0 +1,50 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unittests for Simple RNN."""
+
+import unittest
+
+import jax.numpy as jn
+
+from objax.nn.layers import SimpleRNN
+from objax.functional import one_hot
+from objax.functional.core.activation import relu
+from objax.nn.init import identity
+
+
+class TestSimpleRNN(unittest.TestCase):
+
+    def test_simple_rnn(self):
+        nin = nout = 3
+        batch_size = 1
+        num_hiddens = 1
+        model = SimpleRNN(num_hiddens, nin, nout, activation=relu, w_init=identity)
+
+        X = jn.arange(batch_size)
+        X_one_hot = one_hot(X, nin)
+
+        Z, _ = model(X_one_hot)
+
+        self.assertEqual(Z.shape, (batch_size, nout))
+        self.assertTrue(jn.array_equal(Z, X_one_hot))
+
+        # Test passing in an explicit initial state
+        state = jn.array([[2.]])
+        Z, _ = model(X_one_hot, state)
+        self.assertTrue(jn.array_equal(Z, jn.array([[3., 0., 0.]])))
+
+
+if __name__ == '__main__':
+    unittest.main()
Comment: Your commit contains an unresolved merge.
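Presumably the resolution keeps the HEAD side of the conflict, which matches the SimpleRNN usage throughout the rest of the diff:

from objax.nn import SimpleRNN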