This repository was archived by the owner on Mar 31, 2025. It is now read-only.
22 changes: 10 additions & 12 deletions examples/text_generation/shakespeare_rnn.py
@@ -22,7 +22,11 @@

import objax
from objax.functional import one_hot
from objax.zoo.rnn import RNN
<<<<<<< HEAD:examples/text_generation/shakespeare_rnn.py
from objax.nn import SimpleRNN
=======
from objax.nn import RNN
>>>>>>> 2c04d4e (Move RNN to layers.py and make it stateless.):examples/rnn/shakespeare.py
Comment on lines +25 to +29 (Contributor):
Your commit contains an unresolved merge.


def tokenize(lines, token_type='word'):
@@ -137,32 +141,27 @@ def load_shakespeare(batch_size, num_steps, token_type):
num_hiddens = 256
lr = 0.0001
theta = 1
print(jax.local_devices())

train_iter, vocab = load_shakespeare(batch_size, num_steps, 'char')
vocab_size = len(vocab)

model = RNN(num_hiddens, vocab_size, vocab_size)
model = SimpleRNN(num_hiddens, vocab_size, vocab_size)
model_vars = model.vars()
model.init_state(batch_size)

# Sample call for forward pass
X = jn.arange(batch_size * num_steps).reshape(batch_size, num_steps).T
X_one_hot = one_hot(X, vocab_size)
Z = model(X_one_hot)
print("X_one_hot.shape:", X_one_hot.shape)
print("Z.shape:", Z.shape)
Z, _ = model(X_one_hot)


def predict_char(prefix, num_predicts, model, vocab):
model.init_state(batch_size=1)
outputs = [vocab[prefix[0]]]
get_input = lambda: one_hot(jn.array([outputs[-1]]).reshape(1, 1), len(vocab))
for y in prefix[1:]: # Warmup state with prefix
model(get_input())
outputs.append(vocab[y])
for _ in range(num_predicts): # Predict num_predicts steps
Y = model(get_input())
Y, _ = model(get_input())
Comment (Contributor):
  1. Uppercase names are for global constants; please use lowercase identifiers for variables.
  2. Also, rather than doing two assignments, it is better to assign only what you use:
    Y = model(get_input())[0]

outc = int(Y.argmax(axis=1).reshape(1))
outputs.append(outc)
return ''.join([vocab.idx_to_token[i] for i in outputs])
@@ -172,12 +171,11 @@ def predict_char(prefix, num_predicts, model, vocab):

opt = objax.optimizer.Adam(model_vars)
ema = objax.optimizer.ExponentialMovingAverage(model_vars, momentum=0.999)
predict = ema.replace_vars(objax.Jit(lambda x: objax.functional.softmax(model(x)), model_vars))


def loss(x, label): # sum(label * log(softmax(logit)))
logit = model(x)
return objax.functional.loss.cross_entropy_logits(logit, label).mean()
logits, _ = model(x)
Comment (Contributor):
logits = model(x)[0]

return objax.functional.loss.cross_entropy_logits(logits, label).mean()


gv = objax.GradValues(loss, model.vars())
72 changes: 71 additions & 1 deletion objax/nn/layers.py
@@ -14,7 +14,7 @@

__all__ = ['BatchNorm', 'BatchNorm0D', 'BatchNorm1D', 'BatchNorm2D',
'Conv2D', 'ConvTranspose2D', 'Dropout', 'Linear',
'MovingAverage', 'ExponentialMovingAverage', 'Sequential',
'MovingAverage', 'ExponentialMovingAverage', 'SimpleRNN', 'Sequential',
'SyncedBatchNorm', 'SyncedBatchNorm0D', 'SyncedBatchNorm1D', 'SyncedBatchNorm2D']

from typing import Callable, Iterable, Tuple, Optional, Union, List
@@ -332,6 +332,76 @@ def __call__(self, x: JaxArray) -> JaxArray:
return self.avg.value


class SimpleRNN(Module):
"""Simple Recurrent Neural Network (RNN) block."""

def __init__(self,
nstate: int,
nin: int,
nout: int,
activation: Callable = jn.tanh,
w_init: Callable = kaiming_normal):
"""Creates an RNN instance.

Args:
nstate: number of hidden units.
nin: number of input units.
nout: number of output units.
activation: activation function for the hidden layer.
w_init: weight initializer for RNN model weights.
"""
assert nin > 0, 'nin should be larger than zero'
assert nout > 0, 'nout should be larger than zero'
assert nstate > 0, 'nstate should be larger than zero'
self.num_inputs = nin
self.num_outputs = nout
self.nstate = nstate
self.activation = activation

# Hidden layer parameters
self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
self.b_h = TrainVar(jn.zeros(self.nstate))

self.output_layer = Linear(self.nstate, self.num_outputs, w_init=w_init)

def __call__(self, inputs: JaxArray, initial_state: JaxArray = None,
only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
Comment on lines +368 to +369 (Contributor):
One argument per line if they don't all fit on one line.

"""Forward pass through RNN.

Args:
inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, nin``.
initial_state: optional initial state with dimensions ``batch_size, nstate``; zeros are used when omitted.
only_return_final: return only the last output if ``True``, or all outputs otherwise.

Returns:
Tuple with two elements:
First, output tensor with dimensions ``N * batch_size, nout``.
N = 1 if ``only_return_final`` is ``True`` and ``num_steps`` otherwise.
Second, state with dimensions ``batch_size, nstate``.
"""
outputs = []

if initial_state is None:
state = jn.zeros((inputs.shape[0], self.nstate))
else:
state = initial_state

for x in inputs:
state = self.activation(
jn.dot(x, self.w_xh.value)
+ jn.dot(state, self.w_hh.value)
Comment on lines +391 to +392:
num_inputs could be zero -- essentially empty inputs, but the internal state would still evolve over time.
Not sure whether we should use two weight matrices or one acting on the concatenated [h, x].

Comment:
Typically it's more efficient to act on one concatenated [h, x], but it depends on the system and sizes. At some point you can make this an __init__ mode parameter like Keras does. For now I'd suggest using the concatenated format.

Comment (Contributor):
Another nit: use x.dot(y) rather than jn.dot(x, y), since we might as well take advantage of the object-oriented API.

+ self.b_h.value
)
y = self.output_layer(state)
Comment:
Do we need output_layer, or can we directly return the internal state h and let the user do further transforms on it?

Comment (Contributor Author):
I opted for having an output_layer.

Comment:
Question why: this is something the user can do themselves afterwards, right? So is there any purpose in adding an output_layer?

Comment (Contributor):
I would drop the output layer; it forces a decision on the user about what type of output they want.

if not only_return_final:
outputs.append(y)

if only_return_final:
return y, state
else:
Comment (Contributor):
No need for else.

return jn.concatenate(outputs, axis=0), state
Comment (Contributor):
Should it be jn.stack?
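Taken together, the suggestions in the comments above (a single weight matrix acting on the concatenated [x, h], x.dot(y) instead of jn.dot, dropping the output layer, no else branch, and jn.stack for the outputs) could look roughly like the sketch below. This is illustrative only and not part of the diff; the class name ConcatSimpleRNN and the exact argument handling are assumptions.

import jax.numpy as jn
from objax import Module, TrainVar
from objax.nn.init import kaiming_normal


class ConcatSimpleRNN(Module):  # hypothetical name, not part of this PR
    """Recurrence with one weight matrix acting on the concatenated [x, h]."""

    def __init__(self, nstate, nin, activation=jn.tanh, w_init=kaiming_normal):
        self.nstate = nstate
        self.activation = activation
        self.w = TrainVar(w_init((nin + nstate, nstate)))  # single matrix for [x, h]
        self.b = TrainVar(jn.zeros(nstate))

    def __call__(self, inputs, initial_state=None, only_return_final=False):
        # inputs: (num_steps, batch_size, nin); state: (batch_size, nstate)
        state = jn.zeros((inputs.shape[1], self.nstate)) if initial_state is None else initial_state
        outputs = []
        for x in inputs:
            xh = jn.concatenate([x, state], axis=1)           # (batch_size, nin + nstate)
            state = self.activation(xh.dot(self.w.value) + self.b.value)
            outputs.append(state)                             # raw hidden state, no output layer
        # jn.stack keeps the step axis -> (num_steps, batch_size, nstate);
        # jn.concatenate(outputs, axis=0) would flatten it to (num_steps * batch_size, nstate).
        out = outputs[-1] if only_return_final else jn.stack(outputs)
        return out, state

With no output layer, a caller that wants logits can apply its own head afterwards, for example objax.nn.Linear(num_hiddens, vocab_size) on the returned outputs; whether to expose the split-versus-concatenated weights as an __init__ parameter remains the open question in the thread above.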



class Sequential(ModuleList):
"""Executes modules in the order they were passed to the constructor."""

81 changes: 0 additions & 81 deletions objax/zoo/rnn.py

This file was deleted.

50 changes: 50 additions & 0 deletions tests/simple_rnn.py
@@ -0,0 +1,50 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unittests for Simple RNN"""

import unittest

import jax.numpy as jn

from objax.nn.layers import SimpleRNN
from objax.functional import one_hot
from objax.functional.core.activation import relu
from objax.nn.init import identity


class TestSimpleRNN(unittest.TestCase):

def test_simple_rnn(self):
nin = nout = 3
batch_size = 1
num_hiddens = 1
model = SimpleRNN(num_hiddens, nin, nout, activation=relu, w_init=identity)

X = jn.arange(batch_size)
X_one_hot = one_hot(X, nin)

Z, _ = model(X_one_hot)

self.assertEqual(Z.shape, (batch_size, nout))
self.assertTrue(jn.array_equal(Z, X_one_hot))

# Test passing in an explicit initial state
state = jn.array([[2.]])
Z, _ = model(X_one_hot, state)
self.assertTrue(jn.array_equal(Z, jn.array([[3., 0., 0.]])))


if __name__ == '__main__':
unittest.main()
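As a sanity check on the test's expected values, here is the single recurrence step worked out by hand, assuming objax.nn.init.identity fills jn.eye(shape) and Linear computes x.dot(w) + b (both consistent with the assertions above).

import jax.numpy as jn

x = jn.array([[1., 0., 0.]])                  # one_hot(jn.arange(1), 3)
w_xh, w_hh, w_out = jn.eye(3, 1), jn.eye(1, 1), jn.eye(1, 3)

# Default zero initial state: relu(x.w_xh) = [[1.]], output = [[1., 0., 0.]] == x
state = jn.maximum(x.dot(w_xh), 0)
print(state.dot(w_out))                       # [[1. 0. 0.]]

# Explicit initial state [[2.]]: relu([[1.]] + [[2.]]) = [[3.]], output = [[3., 0., 0.]]
state = jn.maximum(x.dot(w_xh) + jn.array([[2.]]).dot(w_hh), 0)
print(state.dot(w_out))                       # [[3. 0. 0.]]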