From 25b97002f4ad900001105278075ea26b57d543d1 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Sat, 22 Aug 2020 09:26:16 -0700
Subject: [PATCH 01/10] RNN -> Recurrent neural network

---
 examples/rnn/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/rnn/README.md b/examples/rnn/README.md
index 187498f..337be32 100644
--- a/examples/rnn/README.md
+++ b/examples/rnn/README.md
@@ -2,7 +2,7 @@
 
 # Examples
 
-This directory contains examples using RNNs.
+This directory contains examples using Recurrent Neural Networks (RNNs).
 See:
 
 * `shakespeare.py` - predict characters from Shakespeare's plays.

From 9e54c36c7c2769b0a5c5f4d92454eacdcb33be26 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Thu, 10 Sep 2020 00:18:22 -0700
Subject: [PATCH 02/10] GRU module

---
 .../rnn/{shakespeare.py => shakespeare_rnn.py} |  0
 objax/zoo/gru.py                               | 93 +++++++++++++++++++
 requirements.txt                               |  6 +-
 3 files changed, 97 insertions(+), 2 deletions(-)
 rename examples/rnn/{shakespeare.py => shakespeare_rnn.py} (100%)
 create mode 100644 objax/zoo/gru.py

diff --git a/examples/rnn/shakespeare.py b/examples/rnn/shakespeare_rnn.py
similarity index 100%
rename from examples/rnn/shakespeare.py
rename to examples/rnn/shakespeare_rnn.py
diff --git a/objax/zoo/gru.py b/objax/zoo/gru.py
new file mode 100644
index 0000000..d9fab2d
--- /dev/null
+++ b/objax/zoo/gru.py
@@ -0,0 +1,93 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable
+
+import jax.numpy as jn
+
+from objax import Module
+from objax.nn.init import kaiming_normal
+from objax.typing import JaxArray
+from objax.variable import TrainVar, StateVar
+from objax.functional import sigmoid
+
+
+class GRU(Module):
+    """ Gated Recurrent Unit (GRU) block."""
+
+    def __init__(self,
+                 nstate: int,
+                 nin: int,
+                 nout: int,
+                 w_init: Callable = kaiming_normal):
+        """Creates a GRU instance.
+
+        Args:
+            nstate: number of hidden units.
+            nin: number of input units.
+            nout: number of output units.
+            w_init: weight initializer for GRU model weights.
+        """
+        self.num_inputs = nin
+        self.num_outputs = nout
+        self.nstate = nstate
+
+        # Update gate parameters
+        self.w_xz = TrainVar(w_init((self.num_inputs, self.nstate)))
+        self.w_hz = TrainVar(w_init((self.nstate, self.nstate)))
+        self.b_z = TrainVar(jn.zeros(self.nstate))
+
+        # Reset gate parameters
+        self.w_xr = TrainVar(w_init((self.num_inputs, self.nstate)))
+        self.w_hr = TrainVar(w_init((self.nstate, self.nstate)))
+        self.b_r = TrainVar(jn.zeros(self.nstate))
+
+        # Candidate hidden state parameters
+        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
+        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
+        self.b_h = TrainVar(jn.zeros(self.nstate))
+
+        # Output layer parameters
+        self.w_hq = TrainVar(w_init((self.nstate, self.num_outputs)))
+        self.b_q = TrainVar(jn.zeros(self.num_outputs))
+
+    def init_state(self, batch_size):
+        """Initialize hidden state for input batch of size ``batch_size``."""
+        self.state = StateVar(jn.zeros((batch_size, self.nstate)))
+
+    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
+        """Forward pass through GRU.
+
+        Args:
+            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``.
+            only_return_final: return only the last output if ``True``, or all output otherwise.
+
+        Returns:
+            Output tensor with dimensions ``num_steps * batch_size, vocabulary_size``.
+        """
+        # Dimensions: num_steps, batch_size, vocab_size
+        outputs = []
+        for x in inputs:
+            update_gate = sigmoid(jn.dot(x, self.w_xz.value) + jn.dot(self.state.value, self.w_hz.value)
+                                  + self.b_z.value)
+            reset_gate = sigmoid(jn.dot(x, self.w_xr.value) + jn.dot(self.state.value, self.w_hr.value)
+                                 + self.b_r.value)
+            candidate_state = jn.tanh(jn.dot(x, self.w_xh.value)
+                                      + jn.dot(reset_gate * self.state.value, self.w_hh.value) + self.b_h.value)
+            self.state.value = update_gate * self.state.value + (1 - update_gate) * candidate_state
+            y = jn.dot(self.state.value, self.w_hq.value) + self.b_q.value
+            outputs.append(y)
+        if only_return_final:
+            return outputs[-1]
+        return jn.concatenate(outputs, axis=0)
diff --git a/requirements.txt b/requirements.txt
index 8f33c1e..10fd7ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,9 @@
 scipy
-numpy>=1.18.0
-pillow
+numpy~=1.17.4
+pillow~=7.2.0
 jaxlib
 jax
 tensorboard>=2.3.0
 parameterized
+
+setuptools~=46.1.3
\ No newline at end of file

From efcde5c688bded0f3f6ab246e8e1639a295ffcf5 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Fri, 18 Sep 2020 16:47:49 -0700
Subject: [PATCH 03/10] Add Tutorials to the documentation.

---
 docs/source/index.rst     | 1 +
 docs/source/tutorials.rst | 7 +++++++
 2 files changed, 8 insertions(+)
 create mode 100644 docs/source/tutorials.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 71928b0..a4e3a6a 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -90,6 +90,7 @@ Read more about this in :doc:`advanced/jit`.
    notebooks/Logistic_Regression
    notebooks/Custom_Networks
    examples
+   tutorials
 
 .. toctree::
    :maxdepth: 2
diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst
new file mode 100644
index 0000000..ebf1912
--- /dev/null
+++ b/docs/source/tutorials.rst
@@ -0,0 +1,7 @@
+Tutorials
+=========
+
+This section includes various tutorials for Objax.
+
+* `MNIST Tutorial `_
+* `Metric learning for image similarity search `_

From 2e4569bc57ecf709b44d19aad3774c62c63f5061 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Mon, 28 Sep 2020 23:18:04 -0700
Subject: [PATCH 04/10] Move RNN cell to layers.py

---
 objax/nn/layers.py | 55 +++++++++++++++++++++++++++++++++++++++++++++-
 objax/zoo/rnn.py   | 52 -------------------------------------------
 2 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/objax/nn/layers.py b/objax/nn/layers.py
index d11b8fa..65c3ce1 100644
--- a/objax/nn/layers.py
+++ b/objax/nn/layers.py
@@ -14,7 +14,7 @@
 
 __all__ = ['BatchNorm', 'BatchNorm0D', 'BatchNorm1D', 'BatchNorm2D', 'Conv2D', 'ConvTranspose2D', 'Dropout', 'Linear',
-           'MovingAverage', 'ExponentialMovingAverage', 'Sequential',
+           'MovingAverage', 'ExponentialMovingAverage', 'RNN', 'Sequential',
            'SyncedBatchNorm', 'SyncedBatchNorm0D', 'SyncedBatchNorm1D', 'SyncedBatchNorm2D']
 
 from typing import Callable, Iterable, Tuple, Optional, Union, List
@@ -327,6 +327,59 @@ def __call__(self, x: JaxArray) -> JaxArray:
         self.avg.value += (self.avg.value - x) * (self.momentum - 1)
         return self.avg.value
 
+
+class RNN(Module):
+    """ Recurrent Neural Network (RNN) block."""
+
+    def __init__(self, 
+                 nstate: int,
+                 nin: int,
+                 nout: int,
+                 activation: Callable = jn.tanh,
+                 w_init: Callable = kaiming_normal):
+        """Creates an RNN instance.
+
+        Args:
+            nstate: number of hidden units.
+            nin: number of input units.
+            nout: number of output units.
+            activation: activation function for hidden layer.
+            w_init: weight initializer for RNN model weights.
+        """
+        self.num_inputs = nin
+        self.num_outputs = nout
+        self.nstate = nstate
+        self.activation = activation
+
+        # Hidden layer parameters
+        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
+        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
+        self.b_h = TrainVar(jn.zeros(self.nstate))
+
+        self.output_layer = Linear(self.nstate, self.num_outputs)
+
+    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
+        """Forward pass through RNN.
+
+        Args:
+            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``.
+            only_return_final: return only the last output if ``True``, or all output otherwise.
+
+        Returns:
+            Output tensor with dimensions ``num_steps * batch_size, vocabulary_size``.
+        """
+        # Dimensions: num_steps, batch_size, vocab_size
+        outputs = []
+        for x in inputs:
+            self.state.value = self.activation(
+                jn.dot(x, self.w_xh.value)
+                + jn.dot(self.state.value, self.w_hh.value)
+                + self.b_h.value)
+            y = self.output_layer(self.state.value)
+            outputs.append(y)
+        if only_return_final:
+            return outputs[-1]
+        return jn.concatenate(outputs, axis=0)
+
 
 class Sequential(ModuleList):
     """Executes modules in the order they were passed to the constructor."""
diff --git a/objax/zoo/rnn.py b/objax/zoo/rnn.py
index ddc3c15..cf9b104 100644
--- a/objax/zoo/rnn.py
+++ b/objax/zoo/rnn.py
@@ -24,58 +24,6 @@
 
 
 class RNN(Module):
-    """ Recurrent Neural Network (RNN) block."""
-
-    def __init__(self,
-                 nstate: int,
-                 nin: int,
-                 nout: int,
-                 activation: Callable = jn.tanh,
-                 w_init: Callable = kaiming_normal):
-        """Creates an RNN instance.
-
-        Args:
-            nstate: number of hidden units.
-            nin: number of input units.
-            nout: number of output units.
-            activation: actication function for hidden layer.
-            w_init: weight initializer for RNN model weights.
-        """
-        self.num_inputs = nin
-        self.num_outputs = nout
-        self.nstate = nstate
-        self.activation = activation
-
-        # Hidden layer parameters
-        self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate)))
-        self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
-        self.b_h = TrainVar(jn.zeros(self.nstate))
-
-        self.output_layer = Linear(self.nstate, self.num_outputs)
-
     def init_state(self, batch_size):
         """Initialize hidden state for input batch of size ``batch_size``."""
         self.state = StateVar(jn.zeros((batch_size, self.nstate)))
-
-    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
-        """Forward pass through RNN.
-
-        Args:
-            inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``.
-            only_return_final: return only the last output if ``True``, or all output otherwise.`
-
-        Returns:
-            Output tensor with dimensions ``num_steps * batch_size, vocabulary_size``.
-        """
-        # Dimensions: num_steps, batch_size, vocab_size
-        outputs = []
-        for x in inputs:
-            self.state.value = self.activation(
-                jn.dot(x, self.w_xh.value)
-                + jn.dot(self.state.value, self.w_hh.value)
-                + self.b_h.value)
-            y = self.output_layer(self.state.value)
-            outputs.append(y)
-        if only_return_final:
-            return outputs[-1]
-        return jn.concatenate(outputs, axis=0)

From fb8b56271d43b21029ecf4c4acf23f00b00bef8b Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Tue, 29 Sep 2020 08:50:31 -0700
Subject: [PATCH 05/10] Fix empty spaces

---
 objax/nn/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/objax/nn/layers.py b/objax/nn/layers.py
index 65c3ce1..0e03b09 100644
--- a/objax/nn/layers.py
+++ b/objax/nn/layers.py
@@ -330,7 +330,7 @@ def __call__(self, x: JaxArray) -> JaxArray:
 class RNN(Module):
     """ Recurrent Neural Network (RNN) block."""
 
-    def __init__(self, 
+    def __init__(self,
                  nstate: int,
                  nin: int,
                  nout: int,

From 2df126e4dc6ebc6fb1ae1f18d85f0a3237711325 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Thu, 17 Dec 2020 10:37:56 -0800
Subject: [PATCH 06/10] Possible RNN design for comment.

---
 objax/nn/rnn_redesign.py | 68 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 objax/nn/rnn_redesign.py

diff --git a/objax/nn/rnn_redesign.py b/objax/nn/rnn_redesign.py
new file mode 100644
index 0000000..5f5411c
--- /dev/null
+++ b/objax/nn/rnn_redesign.py
@@ -0,0 +1,68 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Tuple, Union
+
+import jax.numpy as jn
+
+import objax
+from objax.typing import JaxArray
+
+
+class MyRnnCell(objax.Module):
+    def __init__(self, nin: int, state: int, activation: Callable = objax.functional.tanh):
+        self.op = objax.nn.Sequential([objax.nn.Linear(nin + state, state), objax.functional.relu,
+                                       objax.nn.Linear(state, state), activation])
+
+    def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
+        return self.op(jn.concatenate((x, state), axis=1))
+
+
+class DDLRnnCell(objax.Module):
+    def __init__(self, nin: int, state: int, activation: Callable = objax.functional.tanh):
+        self.wxh = objax.nn.Linear(nin, state, use_bias=False)
+        self.whh = objax.nn.Linear(state, state)
+        self.activation = activation
+
+    def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
+        return self.activation(self.whh(state) + self.wxh(x))
+
+
+def output_layer(state: int, nout: int):
+    return objax.nn.Linear(state, nout)
+
+
+class RNN(objax.Module):
+    def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
+        self.cell = cell
+        self.output_layer = output_layer  # Is it better inside or outside?
+
+    def single(self, state_i: JaxArray, x_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        next_state = self.cell(state_i, x_i)
+        next_output = self.output_layer(next_state)
+        return next_state, next_output
+
+    def __call__(self, state: JaxArray, x: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        # x = (batch, sequence, nin)  state = (batch, state)
+        return objax.functional.scan(self.single, state, x.transpose((1, 0, 2)))  # final state, outputs
+
+
+seq, ns, nin, nout, batch = 7, 10, 3, 4, 64
+r = RNN(MyRnnCell(nin, ns), output_layer(ns, nout))
+x = objax.random.normal((batch, seq, nin))
+s = jn.zeros((batch, ns))
+y1 = r(s, x)
+
+r = RNN(DDLRnnCell(nin, ns), lambda x: x)
+y2 = r(s, x)

From 0f3dc99e0982f888e621ba73cd0e61254efd3450 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Tue, 12 Jan 2021 19:36:57 -0800
Subject: [PATCH 07/10] Documentation.

---
 objax/nn/rnn_redesign.py | 53 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 51 insertions(+), 2 deletions(-)

diff --git a/objax/nn/rnn_redesign.py b/objax/nn/rnn_redesign.py
index 5f5411c..cf3a804 100644
--- a/objax/nn/rnn_redesign.py
+++ b/objax/nn/rnn_redesign.py
@@ -21,21 +21,41 @@
 
 
 class MyRnnCell(objax.Module):
+    """ Simple RNN cell."""
+
     def __init__(self, nin: int, state: int, activation: Callable = objax.functional.tanh):
+        """Creates a MyRnnCell instance.
+
+        Args:
+            nin: dimension of the input tensor.
+            state: hidden state tensor has dimensions ``nin`` by ``state``.
+            activation: activation function for the hidden state layer.
+        """
-        self.op = objax.nn.Sequential([objax.nn.Linear(nin + state, state), objax.functional.relu,
-                                       objax.nn.Linear(state, state), activation])
+        self.op = objax.nn.Sequential([objax.nn.Linear(nin + state, state),
+                                       objax.nn.Linear(state, state), activation])
 
     def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
+        """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
         return self.op(jn.concatenate((x, state), axis=1))
 
 
 class DDLRnnCell(objax.Module):
+    """ Another simple RNN cell."""
+
     def __init__(self, nin: int, state: int, activation: Callable = objax.functional.tanh):
+        """ Creates a DDLRnnCell instance.
+
+        Args:
+            nin: dimension of the input tensor.
+            state: hidden state tensor has dimension ``nin`` by ``state``.
+            activation: activation function for the hidden state layer.
+        """
         self.wxh = objax.nn.Linear(nin, state, use_bias=False)
         self.whh = objax.nn.Linear(state, state)
         self.activation = activation
 
     def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
+        """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
         return self.activation(self.whh(state) + self.wxh(x))
 
 
 def output_layer(state: int, nout: int):
     return objax.nn.Linear(state, nout)
 
 
 class RNN(objax.Module):
+    """Simple Recurrent Neural Network (RNN).
+
+    State update is done by the provided RNN cell and output is generated by the
+    provided output layer.
+    """
+
     def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
+        """Creates an RNN instance.
+
+        Args:
+            cell: RNN cell.
+            output_layer: output layer can be a function or another module.
+        """
         self.cell = cell
         self.output_layer = output_layer  # Is it better inside or outside?
 
     def single(self, state_i: JaxArray, x_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Execute one step of the RNN.
+
+        Args:
+            state_i: current state.
+            x_i: input.
+
+        Returns:
+            next state and next output.
+        """
         next_state = self.cell(state_i, x_i)
         next_output = self.output_layer(next_state)
         return next_state, next_output
 
     def __call__(self, state: JaxArray, x: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Sequentially processes input to generate output.
+
+        Args:
+            state: Initial RNN state with dimensions ``batch_size`` by ``state``.
+            x: input tensor with dimensions ``batch_size`` by ``sequence_length`` by ``nin``
+
+        Returns:
+            Tuple with final RNN state and output with dimensions ``sequence_length`` by ``batch_size`` by ``nout``,
+            where ``nout`` is the output dimension of the output layer (or ``state`` if there is no output layer).
+        """
         return objax.functional.scan(self.single, state, x.transpose((1, 0, 2)))  # final state, outputs

From b8f646269bc543d7917b68f302add4277a731b46 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Wed, 27 Jan 2021 16:31:02 -0800
Subject: [PATCH 08/10] Added vectorized implementation

---
 objax/nn/rnn_redesign.py  |  94 +++++++++++++++------
 objax/nn/rnn_redesign2.py | 169 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 238 insertions(+), 25 deletions(-)
 create mode 100644 objax/nn/rnn_redesign2.py

diff --git a/objax/nn/rnn_redesign.py b/objax/nn/rnn_redesign.py
index cf3a804..28a7474 100644
--- a/objax/nn/rnn_redesign.py
+++ b/objax/nn/rnn_redesign.py
@@ -12,46 +12,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Tuple, Union
 
 import jax.numpy as jn
-
 import objax
+
 from objax.typing import JaxArray
+from typing import Callable, Tuple, Union
 
 
 class MyRnnCell(objax.Module):
     """ Simple RNN cell."""
 
-    def __init__(self, nin: int, state: int, activation: Callable = objax.functional.tanh):
+    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
         """Creates a MyRnnCell instance.
-
         Args:
             nin: dimension of the input tensor.
-            state: hidden state tensor has dimensions ``nin`` by ``state``.
+            nstate: hidden state tensor has dimensions ``nin`` by ``nstate``.
             activation: activation function for the hidden state layer.
         """
-        self.op = objax.nn.Sequential([objax.nn.Linear(nin + state, state),
-                                       objax.nn.Linear(state, state), activation])
+        self.op = objax.nn.Sequential([objax.nn.Linear(nin + nstate, nstate),
+                                       objax.nn.Linear(nstate, nstate), activation])
 
     def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
         """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
-        return self.op(jn.concatenate((x, state), axis=1))
+        return self.op(jn.concatenate((x, state), axis=0))
 
 
 class DDLRnnCell(objax.Module):
     """ Another simple RNN cell."""
 
-    def __init__(self, nin: int, state: int, activation: Callable = objax.functional.tanh):
+    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
         """ Creates a DDLRnnCell instance.
-
         Args:
             nin: dimension of the input tensor.
-            state: hidden state tensor has dimension ``nin`` by ``state``.
+            nstate: hidden state tensor has dimensions ``nin`` by ``nstate``.
             activation: activation function for the hidden state layer.
         """
-        self.wxh = objax.nn.Linear(nin, state, use_bias=False)
-        self.whh = objax.nn.Linear(state, state)
+        self.wxh = objax.nn.Linear(nin, nstate, use_bias=False)
+        self.whh = objax.nn.Linear(nstate, nstate)
         self.activation = activation
 
     def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
         """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
         return self.activation(self.whh(state) + self.wxh(x))
 
 
 def output_layer(state: int, nout: int):
     return objax.nn.Linear(state, nout)
 
 
 class RNN(objax.Module):
     """Simple Recurrent Neural Network (RNN).
-
     State update is done by the provided RNN cell and output is generated by the
     provided output layer.
     """
 
     def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
         """Creates an RNN instance.
-
         Args:
             cell: RNN cell.
             output_layer: output layer can be a function or another module.
         """
         self.cell = cell
         self.output_layer = output_layer  # Is it better inside or outside?
 
     def single(self, state_i: JaxArray, x_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
         """Execute one step of the RNN.
-
         Args:
             state_i: current state.
             x_i: input.
-
         Returns:
             next state and next output.
         """
         next_state = self.cell(state_i, x_i)
         next_output = self.output_layer(next_state)
         return next_state, next_output
 
-    def __call__(self, state: JaxArray, x: JaxArray) -> Tuple[JaxArray, JaxArray]:
+    def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
         """Sequentially processes input to generate output.
-
         Args:
-            state: Initial RNN state with dimensions ``batch_size`` by ``state``.
             x: input tensor with dimensions ``batch_size`` by ``sequence_length`` by ``nin``
+            state: Initial RNN state with dimensions ``batch_size`` by ``state``.
         Returns:
             Tuple with final RNN state and output with dimensions ``sequence_length`` by ``batch_size`` by ``nout``,
             where ``nout`` is the output dimension of the output layer (or ``state`` if there is no output layer).
         """
-        return objax.functional.scan(self.single, state, x.transpose((1, 0, 2)))  # final state, outputs
+        final_state, output = objax.functional.scan(self.single, state, x.transpose((1, 0, 2)))
+        return output, final_state
+
+
+class VectorizedRNN(objax.Module):
+    """Vectorized Recurrent Neural Network (RNN).
+    State update is done by the provided RNN cell and output is generated by the
+    provided output layer.
+    """
+
+    def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
+        """Creates an RNN instance.
+        Args:
+            cell: RNN cell.
+            output_layer: output layer can be a function or another module.
+        """
+        self.cell = cell
+        self.output_layer = output_layer  # Is it better inside or outside?
+
+    def single(self, state_i: JaxArray, x_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
+
+        """Execute one step of the RNN.
+        Args:
+            state_i: current state.
+            x_i: input.
+        Returns:
+            next state and next output.
+        """
+        next_state = self.cell(state_i, x_i)
+        next_output = self.output_layer(next_state)
+        return next_state, next_output
+
+    def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Sequentially processes input to generate output.
+        Args:
+            x: input tensor with dimensions ``sequence_length`` by ``nin``
+            state: Initial RNN state with dimensions ``(nstate,)``.
+        Returns:
+            Tuple with final RNN state and output with dimensions ``sequence_length`` by ``nout``,
+            where ``nout`` is the output dimension of the output layer (or ``nstate`` if there is no output layer).
+        """
+        final_state, output = objax.functional.scan(self.single, state, x)
+        return output, final_state
 
 
 seq, ns, nin, nout, batch = 7, 10, 3, 4, 64
-r = RNN(MyRnnCell(nin, ns), output_layer(ns, nout))
+
+rnn_cell = DDLRnnCell(nin, ns)
+out_layer = output_layer(ns, nout)
+
+r = RNN(rnn_cell, out_layer)
 x = objax.random.normal((batch, seq, nin))
 s = jn.zeros((batch, ns))
-y1 = r(s, x)
+y1 = r(x, s)
 
-r = RNN(DDLRnnCell(nin, ns), lambda x: x)
-y2 = r(s, x)
+# Vectorized version
+r = VectorizedRNN(rnn_cell, out_layer)
+rnn_vec = objax.Vectorize(r, batch_axis=(0, 0))
+y4 = rnn_vec(x, s)
+
+assert jn.array_equal(y4[1], y1[1])
+assert jn.array_equal(y4[0], y1[0].transpose((1, 0, 2)))
diff --git a/objax/nn/rnn_redesign2.py b/objax/nn/rnn_redesign2.py
new file mode 100644
index 0000000..27816d4
--- /dev/null
+++ b/objax/nn/rnn_redesign2.py
@@ -0,0 +1,169 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Tuple, Union
+
+import jax.numpy as jn
+
+import objax
+from objax.typing import JaxArray
+
+
+class MyRnnCell(objax.Module):
+    """ Simple RNN cell."""
+
+    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
+        """Creates a MyRnnCell instance.
+
+        Args:
+            nin: dimension of the input tensor.
+            state: hidden state tensor has dimensions ``nin`` by ``nstate``.
+            activation: activation function for the hidden state layer.
+        """
+        self.op = objax.nn.Sequential([objax.nn.Linear(nin + nstate, nstate),
+                                       objax.nn.Linear(nstate, nstate), activation])
+
+    def __call__(self, x: JaxArray, state: JaxArray) -> JaxArray:
+        """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
+        return self.op(jn.concatenate((x, state), axis=1))
+
+
+class DDLRnnCell(objax.Module):
+    """ Another simple RNN cell."""
+
+    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
+        """ Creates a DDLRnnCell instance.
+
+        Args:
+            nin: dimension of the input tensor.
+            nstate: hidden state tensor has dimension ``nin`` by ``nstate``.
+            activation: activation function for the hidden state layer.
+        """
+        self.wxh = objax.nn.Linear(nin, nstate, use_bias=False)
+        self.whh = objax.nn.Linear(nstate, nstate)
+        self.activation = activation
+
+    def __call__(self, x: JaxArray, state: JaxArray) -> JaxArray:
+        """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
+        return self.activation(self.whh(state) + self.wxh(x))
+
+
+def output_layer(nstate: int, nout: int):
+    return objax.nn.Linear(nstate, nout)
+
+
+class RNN(objax.Module):
+    """Simple Recurrent Neural Network (RNN).
+
+    The RNN cell provided as input updates the network's state while the provided output layer generates
+    the network's output.
+    """
+
+    def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
+        """Creates an RNN instance.
+
+        Args:
+            cell: RNN cell.
+            output_layer: output layer can be a function or another module.
+        """
+        self.cell = cell
+        self.output_layer = output_layer  # Is it better inside or outside?
+
+    def single(self, x_i: JaxArray, state_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Execute one step of the RNN.
+
+        Args:
+            x_i: input.
+            state_i: current state.
+
+        Returns:
+            next output and next state.
+        """
+        next_state = self.cell(x_i, state_i)
+        print("next_state.shape", next_state.shape)
+        next_output = self.output_layer(next_state)
+        print("next_output.shape", next_output.shape)
+        return next_output, next_state
+
+    def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Sequentially processes input to generate output.
+
+        Args:
+            x: input tensor with dimensions ``batch_size`` by ``sequence_length`` by ``nin``
+            state: Initial RNN state with dimensions ``batch_size`` by ``nstate``.
+
+        Returns:
+            Tuple with output with dimensions ``sequence_length`` by ``batch_size`` by ``nout``,
+            where ``nout`` is the output dimension of the output layer (or ``nstate`` if there is no output layer)
+            and state.
+        """
+        return objax.functional.scan(self.single, x.transpose((1, 0, 2)), state)  # outputs, final state
+
+
+class no_batch_RNN(objax.Module):
+    """Simple Recurrent Neural Network (RNN).
+
+    The RNN cell provided as input updates the network's state while the provided output layer generates
+    the network's output.
+    """
+
+    def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
+        """Creates an RNN instance.
+
+        Args:
+            cell: RNN cell.
+            output_layer: output layer can be a function or another module.
+        """
+        self.cell = cell
+        self.output_layer = output_layer  # Is it better inside or outside?
+
+    def single(self, x_i: JaxArray, state_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Execute one step of the RNN.
+
+        Args:
+            x_i: input.
+            state_i: current state.
+
+        Returns:
+            next state and next output.
+        """
+        next_state = self.cell(x_i, state_i)
+        next_output = self.output_layer(next_state)
+        return next_output, next_state
+
+    def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Sequentially processes input to generate output.
+
+        Args:
+            x: input tensor with dimensions ``sequence_length`` by ``nin``
+            state: Initial RNN state with dimensions ``batch_size`` by ``nstate``.
+
+        Returns:
+            Tuple with output with dimensions ``sequence_length`` by ``nout``,
+            where ``nout`` is the output dimension of the output layer (or ``nstate`` if there is no output layer)
+            and state.
+        """
+        return objax.functional.scan(self.single, x, state)  # outputs, final state
+
+
+seq, ns, nin, nout, batch = 7, 10, 3, 4, 64
+r = RNN(MyRnnCell(nin, ns), output_layer(ns, nout))
+x = objax.random.normal((batch, seq, nin))
+s = jn.zeros((batch, ns))
+
+print("x.shape", x.shape)
+print("s.shape:", s.shape)
+
+#y1 = r(x, s)
+
+r = RNN(DDLRnnCell(nin, ns), lambda x: x)
+y2 = r(x, s)

From e35205c9042a8bec87f15ddc552f6e5cd975e59b Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Wed, 27 Jan 2021 16:32:09 -0800
Subject: [PATCH 09/10] removed unneeded file

---
 objax/nn/rnn_redesign2.py | 169 --------------------------------------
 1 file changed, 169 deletions(-)
 delete mode 100644 objax/nn/rnn_redesign2.py

diff --git a/objax/nn/rnn_redesign2.py b/objax/nn/rnn_redesign2.py
deleted file mode 100644
index 27816d4..0000000
--- a/objax/nn/rnn_redesign2.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Callable, Tuple, Union
-
-import jax.numpy as jn
-
-import objax
-from objax.typing import JaxArray
-
-
-class MyRnnCell(objax.Module):
-    """ Simple RNN cell."""
-
-    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
-        """Creates a MyRnnCell instance.
-
-        Args:
-            nin: dimension of the input tensor.
-            state: hidden state tensor has dimensions ``nin`` by ``nstate``.
-            activation: activation function for the hidden state layer.
-        """
-        self.op = objax.nn.Sequential([objax.nn.Linear(nin + nstate, nstate),
-                                       objax.nn.Linear(nstate, nstate), activation])
-
-    def __call__(self, x: JaxArray, state: JaxArray) -> JaxArray:
-        """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
-        return self.op(jn.concatenate((x, state), axis=1))
-
-
-class DDLRnnCell(objax.Module):
-    """ Another simple RNN cell."""
-
-    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
-        """ Creates a DDLRnnCell instance.
-
-        Args:
-            nin: dimension of the input tensor.
-            nstate: hidden state tensor has dimension ``nin`` by ``nstate``.
-            activation: activation function for the hidden state layer.
-        """
-        self.wxh = objax.nn.Linear(nin, nstate, use_bias=False)
-        self.whh = objax.nn.Linear(nstate, nstate)
-        self.activation = activation
-
-    def __call__(self, x: JaxArray, state: JaxArray) -> JaxArray:
-        """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
-        return self.activation(self.whh(state) + self.wxh(x))
-
-
-def output_layer(nstate: int, nout: int):
-    return objax.nn.Linear(nstate, nout)
-
-
-class RNN(objax.Module):
-    """Simple Recurrent Neural Network (RNN).
-
-    The RNN cell provided as input updates the network's state while the provided output layer generates
-    the network's output.
-    """
-
-    def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
-        """Creates an RNN instance.
-
-        Args:
-            cell: RNN cell.
-            output_layer: output layer can be a function or another module.
-        """
-        self.cell = cell
-        self.output_layer = output_layer  # Is it better inside or outside?
-
-    def single(self, x_i: JaxArray, state_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
-        """Execute one step of the RNN.
-
-        Args:
-            x_i: input.
-            state_i: current state.
-
-        Returns:
-            next output and next state.
-        """
-        next_state = self.cell(x_i, state_i)
-        print("next_state.shape", next_state.shape)
-        next_output = self.output_layer(next_state)
-        print("next_output.shape", next_output.shape)
-        return next_output, next_state
-
-    def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
-        """Sequentially processes input to generate output.
-
-        Args:
-            x: input tensor with dimensions ``batch_size`` by ``sequence_length`` by ``nin``
-            state: Initial RNN state with dimensions ``batch_size`` by ``nstate``.
-
-        Returns:
-            Tuple with output with dimensions ``sequence_length`` by ``batch_size`` by ``nout``,
-            where ``nout`` is the output dimension of the output layer (or ``nstate`` if there is no output layer)
-            and state.
-        """
-        return objax.functional.scan(self.single, x.transpose((1, 0, 2)), state)  # outputs, final state
-
-
-class no_batch_RNN(objax.Module):
-    """Simple Recurrent Neural Network (RNN).
-
-    The RNN cell provided as input updates the network's state while the provided output layer generates
-    the network's output.
-    """
-
-    def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
-        """Creates an RNN instance.
-
-        Args:
-            cell: RNN cell.
-            output_layer: output layer can be a function or another module.
-        """
-        self.cell = cell
-        self.output_layer = output_layer  # Is it better inside or outside?
-
-    def single(self, x_i: JaxArray, state_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
-        """Execute one step of the RNN.
-
-        Args:
-            x_i: input.
-            state_i: current state.
-
-        Returns:
-            next state and next output.
-        """
-        next_state = self.cell(x_i, state_i)
-        next_output = self.output_layer(next_state)
-        return next_output, next_state
-
-    def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
-        """Sequentially processes input to generate output.
-
-        Args:
-            x: input tensor with dimensions ``sequence_length`` by ``nin``
-            state: Initial RNN state with dimensions ``batch_size`` by ``nstate``.
-
-        Returns:
-            Tuple with output with dimensions ``sequence_length`` by ``nout``,
-            where ``nout`` is the output dimension of the output layer (or ``nstate`` if there is no output layer)
-            and state.
-        """
-        return objax.functional.scan(self.single, x, state)  # outputs, final state
-
-
-seq, ns, nin, nout, batch = 7, 10, 3, 4, 64
-r = RNN(MyRnnCell(nin, ns), output_layer(ns, nout))
-x = objax.random.normal((batch, seq, nin))
-s = jn.zeros((batch, ns))
-
-print("x.shape", x.shape)
-print("s.shape:", s.shape)
-
-#y1 = r(x, s)
-
-r = RNN(DDLRnnCell(nin, ns), lambda x: x)
-y2 = r(x, s)

From 74271d6d7770d6c115488055f38bc43760a34e54 Mon Sep 17 00:00:00 2001
From: Andreas Terzis
Date: Tue, 16 Mar 2021 12:12:35 -0700
Subject: [PATCH 10/10] Updated design

---
 objax/nn/rnn_redesign.py | 76 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 74 insertions(+), 2 deletions(-)

diff --git a/objax/nn/rnn_redesign.py b/objax/nn/rnn_redesign.py
index 28a7474..05506a7 100644
--- a/objax/nn/rnn_redesign.py
+++ b/objax/nn/rnn_redesign.py
@@ -38,6 +38,34 @@ def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
         return self.op(jn.concatenate((x, state), axis=0))
 
 
+class FactorizedRnnCell(objax.Module):
+    """ Factorized version of RNN cell."""
+
+    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
+        """Creates a FactorizedRnnCell instance.
+        Args:
+            nin: dimension of the input tensor.
+            nstate: hidden state tensor has dimensions ``nin`` by ``nstate``.
+            activation: activation function for the hidden state layer.
+        """
+        self.win = objax.nn.Linear(nin, nstate, use_bias=False)
+        self.wn = objax.nn.Linear(nstate, nstate)
+        self.activation = activation
+        self.nstate = nstate
+
+    def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
+        """Updates and returns hidden state based on input sequence ``x`` and input ``state``."""
+        self.factor = self.win(x)
+        # TODO(aterzis): Replace with scan(?)
+        output = []
+        for i in range(x.shape[0]):
+            state = self.activation(self.wn(state) + self.factor[i])
+            output_i = jn.reshape(state, (1, self.nstate))
+            output.append(output_i)
+        outputs = jn.concatenate(output, axis=0)
+        return outputs
+
+
 class DDLRnnCell(objax.Module):
     """ Another simple RNN cell."""
 
@@ -117,7 +145,6 @@ def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callabl
         self.output_layer = output_layer  # Is it better inside or outside?
 
     def single(self, state_i: JaxArray, x_i: JaxArray) -> Tuple[JaxArray, JaxArray]:
-
         """Execute one step of the RNN.
         Args:
             state_i: current state.
@@ -142,8 +169,38 @@ def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
         final_state, output = objax.functional.scan(self.single, state, x)
         return output, final_state
 
+
+class FactorizedRNN(objax.Module):
+    """Factorized Recurrent Neural Network (RNN).
+    State update is done by the provided RNN cell and output is generated by the
+    provided output layer.
+    """
+
+    def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]):
+        """Creates an RNN instance.
+        Args:
+            cell: RNN cell.
+            output_layer: output layer can be a function or another module.
+        """
+        self.cell = cell
+        self.output_layer = output_layer  # Is it better inside or outside?
+
+    def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
+        """Sequentially processes input to generate output.
+        Args:
+            x: input tensor with dimensions ``sequence_length`` by ``nin``
+            state: Initial RNN state with dimensions ``(state,)``.
+        Returns:
+            Tuple with final RNN state and output with dimensions ``sequence_length`` by ``nout``,
+            where ``nout`` is the output dimension of the output layer (or ``state`` if there is no output layer).
+        """
+        out = self.cell(state, x)
+        output = self.output_layer(out)
+        return output, out[-1]
+
 
 seq, ns, nin, nout, batch = 7, 10, 3, 4, 64
 
+# RNN example
 rnn_cell = DDLRnnCell(nin, ns)
 out_layer = output_layer(ns, nout)
 
@@ -152,10 +209,25 @@ def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]:
 s = jn.zeros((batch, ns))
 y1 = r(x, s)
 
 # Vectorized version
 r = VectorizedRNN(rnn_cell, out_layer)
 rnn_vec = objax.Vectorize(r, batch_axis=(0, 0))
 y4 = rnn_vec(x, s)
 
 assert jn.array_equal(y4[1], y1[1])
 assert jn.array_equal(y4[0], y1[0].transpose((1, 0, 2)))
+
+# Factorized version
+factorized_cell = FactorizedRnnCell(nin, ns)
+out = factorized_cell(s[0], x[0, :, :])
+out2 = out_layer(out)
+
+print("s.shape", s.shape)
+print("out2.shape", out2.shape)
+
+f = FactorizedRNN(factorized_cell, out_layer)
+y5 = f(x[0, :, :], s[0])
+print("y5[0].shape", y5[0].shape)
+print("y5[1].shape", y5[1].shape)
+assert jn.array_equal(y5[0], out2)
+assert jn.array_equal(y5[1], out[-1])