diff --git a/examples/text_generation/shakespeare_rnn.py b/examples/text_generation/shakespeare_rnn.py deleted file mode 100644 index b6dcb3a..0000000 --- a/examples/text_generation/shakespeare_rnn.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import collections -import random -import re - -import jax -import jax.numpy as jn -import tensorflow_datasets as tfds - -import objax -from objax.functional import one_hot -from objax.zoo.rnn import RNN - - -def tokenize(lines, token_type='word'): - """Split the lines list into word or char tokens depending on token_type.""" - - if token_type == 'word': - return [line.split(' ') for line in lines] - elif token_type == 'char': - return [list(line) for line in lines] - else: - raise ValueError('ERROR: unknown token type', token_type) - - -def count_corpus(token_list): - """Return a Counter of the tokens in token_list. - - Args: - token_list: list of token lists - """ - tokens = [tk for tokens in token_list for tk in tokens] - return collections.Counter(tokens) - - -class Vocabulary: - """Vocabulary extracts set of unique tokens and - constructs token to index and index to token lookup tables. - """ - - def __init__(self, token_list): - counter = count_corpus(token_list) - - self.token_freqs = sorted(counter.items(), key=lambda x: x[0]) - self.token_freqs.sort(key=lambda x: x[1], reverse=True) - self.unk, uniq_tokens = 0, [''] - uniq_tokens += [token for token, freq in self.token_freqs] - self.idx_to_token, self.token_to_idx = [], dict() - - for token in uniq_tokens: - self.idx_to_token.append(token) - self.token_to_idx[token] = len(self.idx_to_token) - 1 - - def __len__(self): - return len(self.idx_to_token) - - def __getitem__(self, tokens): - if not isinstance(tokens, (list, tuple)): - return self.token_to_idx.get(tokens, self.unk) - return [self.__getitem__(token) for token in tokens] - - def to_tokens(self, indices): - if not isinstance(indices, (list, tuple)): - return self.idx_to_token[indices] - return [self.idx_to_token[index] for index in indices] - - -def seq_data_iter(corpus, batch_size, num_steps): - # Offset the iterator over the data for uniform starts - corpus = corpus[random.randint(0, num_steps):] - # Subtract 1 extra since we need to account for label - num_examples = ((len(corpus) - 1) // num_steps) - example_indices = list(range(0, num_examples * num_steps, num_steps)) - random.shuffle(example_indices) - - def data(pos): - # This returns a sequence of length `num_steps` starting from `pos` - return corpus[pos: pos + num_steps] - - # Discard half empty batches - num_batches = num_examples // batch_size - for i in range(0, batch_size * num_batches, batch_size): - # `batch_size` indicates the random examples read each time - batch_indices = example_indices[i:(i + batch_size)] - X = [data(j) for j in batch_indices] - Y = [data(j + 1) for j in batch_indices] - yield X, Y - - -class DataLoader: - """An iterator to load sequence data.""" - - def __init__(self, batch_size, 
num_steps, token_type): - self.data_iter_fn = seq_data_iter - self.corpus, self.vocab = load_shakespeare_corpus(token_type) - self.batch_size, self.num_steps = batch_size, num_steps - - def __iter__(self): - return self.data_iter_fn(self.corpus, self.batch_size, self.num_steps) - - -def load_shakespeare_corpus(token_type='char'): - """Load the tiny_shakespeare TFDS and return its corpus and vodabulary.""" - data = tfds.as_numpy(tfds.load(name='tiny_shakespeare', batch_size=-1)) - train_data = data['train'] - input_string = train_data['text'][0].decode() # decode binary string - re.sub('[^A-Za-z]+', ' ', input_string.strip().lower()) - lines = input_string.splitlines() - - token_list = tokenize(lines, token_type) - - vocab = Vocabulary(token_list) - corpus = [vocab[tk] for tokens in token_list for tk in tokens] - return corpus, vocab - - -def load_shakespeare(batch_size, num_steps, token_type): - data_iter = DataLoader(batch_size, num_steps, token_type) - return data_iter, data_iter.vocab - - -batch_size, num_steps = 10, 10 -num_epochs = 500 -num_hiddens = 256 -lr = 0.0001 -theta = 1 -print(jax.local_devices()) - -train_iter, vocab = load_shakespeare(batch_size, num_steps, 'char') -vocab_size = len(vocab) - -model = RNN(num_hiddens, vocab_size, vocab_size) -model_vars = model.vars() -model.init_state(batch_size) - -# Sample call for forward pass -X = jn.arange(batch_size * num_steps).reshape(batch_size, num_steps).T -X_one_hot = one_hot(X, vocab_size) -Z = model(X_one_hot) -print("X_one_hot.shape:", X_one_hot.shape) -print("Z.shape:", Z.shape) - - -def predict_char(prefix, num_predicts, model, vocab): - model.init_state(batch_size=1) - outputs = [vocab[prefix[0]]] - get_input = lambda: one_hot(jn.array([outputs[-1]]).reshape(1, 1), len(vocab)) - for y in prefix[1:]: # Warmup state with prefix - model(get_input()) - outputs.append(vocab[y]) - for _ in range(num_predicts): # Predict num_predicts steps - Y = model(get_input()) - outc = int(Y.argmax(axis=1).reshape(1)) - outputs.append(outc) - return ''.join([vocab.idx_to_token[i] for i in outputs]) - - -print(predict_char('to be or not to be', 10, model, vocab)) - -opt = objax.optimizer.Adam(model_vars) -model_ema = objax.optimizer.ExponentialMovingAverageModule(model, momentum=0.999) -predict = objax.Jit(objax.nn.Sequential([model_ema, objax.functional.softmax])) - - -@objax.Function.with_vars(model.vars()) -def loss(x, label): # sum(label * log(softmax(logit))) - logit = model(x) - return objax.functional.loss.cross_entropy_logits(logit, label).mean() - - -gv = objax.GradValues(loss, model.vars()) - - -def clip_gradients(grads, theta): - total_grad_norm = jn.linalg.norm([jn.linalg.norm(g) for g in grads]) - scale_factor = jn.minimum(theta / total_grad_norm, 1.) 
- return [g * scale_factor for g in grads] - - -@objax.Function.with_vars(model.vars() + opt.vars() + model_ema.vars()) -def train_op(x, xl): - g, v = gv(x, xl) # returns gradients, loss - clipped_g = clip_gradients(g, theta) - opt(lr, clipped_g) - model_ema.update_ema() - return v - - -train_op = objax.Jit(train_op) - -# Training -for epoch in range(num_epochs): - for test_data, labels in train_iter: - X = jn.array(test_data).T - X_one_hot = one_hot(X, vocab_size) - Y = jn.array(labels).T - Y_one_hot = one_hot(Y, vocab_size) - flat_labels = jn.concatenate(Y_one_hot, axis=0) - v = train_op(X_one_hot, flat_labels) - if epoch % 10 == 0: - print("loss:", float(v[0])) - -print(predict_char('to be or not to be', 40, model, vocab)) diff --git a/objax/io/checkpoint.py b/objax/io/checkpoint.py deleted file mode 100644 index e23ce01..0000000 --- a/objax/io/checkpoint.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ['Checkpoint'] - -import glob -import os -from typing import Callable, Optional - -from objax.io.ops import load_var_collection, save_var_collection -from objax.typing import FileOrStr -from objax.variable import VarCollection - - -class Checkpoint: - """Helper class which performs saving and restoring of the variables. - - Variables are stored in the checkpoint files. One checkpoint file stores a single snapshot of the variables. - Different checkpoint files store different snapshots of the variables (for example at different training step). - Each checkpoint has associated index, which is used to identify time when snapshot of the variables was made. - Typically training step or training epoch are used as an index. - """ - - DIR_NAME: str = 'ckpt' - """Name of the subdirectory of model directory where checkpoints will be saved.""" - - FILE_MATCH: str = '*.npz' - """File pattern which is used to search for checkpoint files.""" - - FILE_FORMAT: str = '%010d.npz' - """Format of the filename of one checkpoint file.""" - - LOAD_FN: Callable[[FileOrStr, VarCollection], None] = staticmethod(load_var_collection) - """Load function, which loads variables collection from given file.""" - - SAVE_FN: Callable[[FileOrStr, VarCollection], None] = staticmethod(save_var_collection) - """Save function, which saves variables collection into given file.""" - - def __init__(self, logdir: str, keep_ckpts: int, makedir: bool = True, verbose: bool = True): - """Creates instance of the Checkpoint class. - - Args: - logdir: model directory. Checkpoints will be saved in the subdirectory of model directory. - keep_ckpts: maximum number of checkpoints to keep. - makedir: if True then directory for checkpoints will be created, - otherwise it's expected that directory already exists. - verbose: if True then print when data is restored from checkpoint. 
- """ - self.logdir = logdir - self.keep_ckpts = keep_ckpts - self.verbose = verbose - if makedir: - os.makedirs(os.path.join(logdir, self.DIR_NAME), exist_ok=True) - - @staticmethod - def checkpoint_idx(filename: str): - """Returns index of checkpoint from given checkpoint filename. - - Args: - filename: checkpoint filename. - - Returns: - checkpoint index. - """ - return int(os.path.basename(filename).split('.')[0]) - - def restore(self, vc: VarCollection, idx: Optional[int] = None): - """Restores values of all variables of given variables collection from the checkpoint. - - Old values from the variables collection will be replaced with the new values read from checkpoint. - If variable does not exist in the variables collection, it won't be restored from checkpoint. - - Args: - vc: variables collection to restore. - idx: if provided then checkpoint index to use, if None then latest checkpoint will be restored. - - Returns: - idx: index of the restored checkpoint. - ckpt: full path to the restored checkpoint. - """ - assert isinstance(vc, VarCollection), f'Must pass a VarCollection to restore; received type {type(vc)}.' - if idx is None: - all_ckpts = glob.glob(os.path.join(self.logdir, self.DIR_NAME, self.FILE_MATCH)) - if not all_ckpts: - if self.verbose: - print('No checkpoints found. Skipping restoring variables.') - return 0, '' - idx = self.checkpoint_idx(max(all_ckpts)) - ckpt = os.path.join(self.logdir, self.DIR_NAME, self.FILE_FORMAT % idx) - if self.verbose: - print('Resuming from', ckpt) - self.LOAD_FN(ckpt, vc) - return idx, ckpt - - def save(self, vc: VarCollection, idx: int): - """Saves variables collection to checkpoint with given index. - - Args: - vc: variables collection to save. - idx: index of the new checkpoint where variables should be saved. - """ - assert isinstance(vc, VarCollection), f'Must pass a VarCollection to save; received type {type(vc)}.' - self.SAVE_FN(os.path.join(self.logdir, self.DIR_NAME, self.FILE_FORMAT % idx), vc) - for ckpt in sorted(glob.glob(os.path.join(self.logdir, self.DIR_NAME, self.FILE_MATCH)))[:-self.keep_ckpts]: - os.remove(ckpt) diff --git a/objax/nn/layers.py b/objax/nn/layers.py deleted file mode 100644 index 506a8eb..0000000 --- a/objax/nn/layers.py +++ /dev/null @@ -1,487 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -__all__ = ['BatchNorm', 'BatchNorm0D', 'BatchNorm1D', 'BatchNorm2D', - 'Conv2D', 'ConvTranspose2D', 'Dropout', 'Linear', - 'MovingAverage', 'ExponentialMovingAverage', 'Sequential', - 'SyncedBatchNorm', 'SyncedBatchNorm0D', 'SyncedBatchNorm1D', 'SyncedBatchNorm2D'] - -from typing import Callable, Iterable, Tuple, Optional, Union, List, Dict - -from jax import numpy as jn, random as jr, lax - -from objax import functional, random, util -from objax.constants import ConvPadding -from objax.module import ModuleList, Module -from objax.nn.init import kaiming_normal, xavier_normal -from objax.typing import JaxArray, ConvPaddingInt -from objax.util import class_name -from objax.variable import TrainVar, StateVar - - -class BatchNorm(Module): - """Applies a batch normalization on different ranks of an input tensor. - - The module follows the operation described in Algorithm 1 of - `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift - `_. - """ - - def __init__(self, dims: Iterable[int], redux: Iterable[int], momentum: float = 0.999, eps: float = 1e-6): - """Creates a BatchNorm module instance. - - Args: - dims: shape of the batch normalization state variables. - redux: list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes. - momentum: value used to compute exponential moving average of batch statistics. - eps: small value which is used for numerical stability. - """ - super().__init__() - dims = tuple(dims) - self.momentum = momentum - self.eps = eps - self.redux = tuple(redux) - self.running_mean = StateVar(jn.zeros(dims)) - self.running_var = StateVar(jn.ones(dims)) - self.beta = TrainVar(jn.zeros(dims)) - self.gamma = TrainVar(jn.ones(dims)) - - def __call__(self, x: JaxArray, training: bool) -> JaxArray: - """Performs batch normalization of input tensor. - - Args: - x: input tensor. - training: if True compute batch normalization in training mode (accumulating batch statistics), - otherwise compute in evaluation mode (using already accumulated batch statistics). - - Returns: - Batch normalized tensor. - """ - if training: - m = x.mean(self.redux, keepdims=True) - v = ((x - m) ** 2).mean(self.redux, keepdims=True) # Note: x^2 - m^2 is not numerically stable. - self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value) - self.running_var.value += (1 - self.momentum) * (v - self.running_var.value) - else: - m, v = self.running_mean.value, self.running_var.value - y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value - return y - - def __repr__(self): - args = dict(dims=self.beta.value.shape, redux=self.redux, momentum=self.momentum, eps=self.eps) - args = ', '.join(f'{x}={y}' for x, y in args.items()) - return f'{class_name(self)}({args})' - - -class BatchNorm0D(BatchNorm): - """Applies a 0D batch normalization on a 2D-input batch of shape (N,C). - - The module follows the operation described in Algorithm 1 of - `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift - `_. - """ - - def __init__(self, nin: int, momentum: float = 0.999, eps: float = 1e-6): - """Creates a BatchNorm0D module instance. - - Args: - nin: number of channels in the input example. - momentum: value used to compute exponential moving average of batch statistics. - eps: small value which is used for numerical stability. 
- """ - super().__init__((1, nin), (0,), momentum, eps) - - def __repr__(self): - return f'{class_name(self)}(nin={self.beta.value.shape[1]}, momentum={self.momentum}, eps={self.eps})' - - -class BatchNorm1D(BatchNorm): - """Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L). - - The module follows the operation described in Algorithm 1 of - `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift - `_. - """ - - def __init__(self, nin: int, momentum: float = 0.999, eps: float = 1e-6): - """Creates a BatchNorm1D module instance. - - Args: - nin: number of channels in the input example. - momentum: value used to compute exponential moving average of batch statistics. - eps: small value which is used for numerical stability. - """ - super().__init__((1, nin, 1), (0, 2), momentum, eps) - - def __repr__(self): - return f'{class_name(self)}(nin={self.beta.value.shape[1]}, momentum={self.momentum}, eps={self.eps})' - - -class BatchNorm2D(BatchNorm): - """Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W). - - The module follows the operation described in Algorithm 1 of - `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift - `_. - """ - - def __init__(self, nin: int, momentum: float = 0.999, eps: float = 1e-6): - """Creates a BatchNorm2D module instance. - - Args: - nin: number of channels in the input example. - momentum: value used to compute exponential moving average of batch statistics. - eps: small value which is used for numerical stability. - """ - super().__init__((1, nin, 1, 1), (0, 2, 3), momentum, eps) - - def __repr__(self): - return f'{class_name(self)}(nin={self.beta.value.shape[1]}, momentum={self.momentum}, eps={self.eps})' - - -class Conv2D(Module): - """Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W).""" - - def __init__(self, - nin: int, - nout: int, - k: Union[Tuple[int, int], int], - strides: Union[Tuple[int, int], int] = 1, - dilations: Union[Tuple[int, int], int] = 1, - groups: int = 1, - padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME, - use_bias: bool = True, - w_init: Callable = kaiming_normal): - """Creates a Conv2D module instance. - - Args: - nin: number of channels of the input tensor. - nout: number of channels of the output tensor. - k: size of the convolution kernel, either tuple (height, width) or single number if they're the same. - strides: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same. - dilations: spacing between kernel points (also known as astrous convolution), - either tuple (dilation_y, dilation_x) or single number if they're the same. - groups: number of input and output channels group. When groups > 1 convolution operation is applied - individually for each group. nin and nout must both be divisible by groups. - padding: padding of the input tensor, either Padding.SAME, Padding.VALID or numerical values. - use_bias: if True then convolution will have bias term. - w_init: initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix). 
- """ - super().__init__() - assert nin % groups == 0, 'nin should be divisible by groups' - assert nout % groups == 0, 'nout should be divisible by groups' - self.b = TrainVar(jn.zeros((nout, 1, 1))) if use_bias else None - self.w = TrainVar(w_init((*util.to_tuple(k, 2), nin // groups, nout))) # HWIO - self.padding = util.to_padding(padding, 2) - self.strides = util.to_tuple(strides, 2) - self.dilations = util.to_tuple(dilations, 2) - self.groups = groups - self.w_init = w_init - - def __call__(self, x: JaxArray) -> JaxArray: - """Returns the results of applying the convolution to input x.""" - nin = self.w.value.shape[2] * self.groups - assert x.shape[1] == nin, (f'Attempting to convolve an input with {x.shape[1]} input channels ' - f'when the convolution expects {nin} channels. For reference, ' - f'self.w.value.shape={self.w.value.shape} and x.shape={x.shape}.') - y = lax.conv_general_dilated(x, self.w.value, self.strides, self.padding, - rhs_dilation=self.dilations, - feature_group_count=self.groups, - dimension_numbers=('NCHW', 'HWIO', 'NCHW')) - if self.b: - y += self.b.value - return y - - def __repr__(self): - args = dict(nin=self.w.value.shape[2] * self.groups, nout=self.w.value.shape[3], k=self.w.value.shape[:2], - strides=self.strides, dilations=self.dilations, groups=self.groups, padding=self.padding, - use_bias=self.b is not None) - args = ', '.join(f'{k}={repr(v)}' for k, v in args.items()) - return f'{class_name(self)}({args}, w_init={util.repr_function(self.w_init)})' - - -class ConvTranspose2D(Conv2D): - """Applies a 2D transposed convolution on a 4D-input batch of shape (N,C,H,W). - - This module can be seen as a transformation going in the opposite direction of a normal convolution, i.e., - from something that has the shape of the output of some convolution to something that has the shape of its input - while maintaining a connectivity pattern that is compatible with said convolution. - Note that ConvTranspose2D is consistent with - `Conv2DTranspose `_, - of Tensorflow but is not consistent with - `ConvTranspose2D `_ - of PyTorch due to kernel transpose and padding. - """ - - def __init__(self, - nin: int, - nout: int, - k: Union[Tuple[int, int], int], - strides: Union[Tuple[int, int], int] = 1, - dilations: Union[Tuple[int, int], int] = 1, - padding: Union[ConvPadding, str, ConvPaddingInt] = ConvPadding.SAME, - use_bias: bool = True, - w_init: Callable = kaiming_normal): - """Creates a ConvTranspose2D module instance. - - Args: - nin: number of channels of the input tensor. - nout: number of channels of the output tensor. - k: size of the convolution kernel, either tuple (height, width) or single number if they're the same. - strides: convolution strides, either tuple (stride_y, stride_x) or single number if they're the same. - dilations: spacing between kernel points (also known as astrous convolution), - either tuple (dilation_y, dilation_x) or single number if they're the same. - padding: padding of the input tensor, either Padding.SAME, Padding.VALID or numerical values. - use_bias: if True then convolution will have bias term. - w_init: initializer for convolution kernel (a function that takes in a HWIO shape and returns a 4D matrix). 
- """ - super().__init__(nin=nout, nout=nin, k=k, strides=strides, dilations=dilations, padding=padding, - use_bias=False, w_init=w_init) - self.b = TrainVar(jn.zeros((nout, 1, 1))) if use_bias else None - - def __call__(self, x: JaxArray) -> JaxArray: - """Returns the results of applying the transposed convolution to input x.""" - y = lax.conv_transpose(x, self.w.value, self.strides, self.padding, - rhs_dilation=self.dilations, - dimension_numbers=('NCHW', 'HWIO', 'NCHW'), transpose_kernel=True) - if self.b: - y += self.b.value - return y - - def __repr__(self): - args = dict(nin=self.w.value.shape[3], nout=self.w.value.shape[2], k=self.w.value.shape[:2], - strides=self.strides, dilations=self.dilations, padding=self.padding, - use_bias=self.b is not None) - args = ', '.join(f'{k}={repr(v)}' for k, v in args.items()) - return f'{class_name(self)}({args}, w_init={util.repr_function(self.w_init)})' - - -class Dropout(Module): - """In the training phase, a dropout layer zeroes some elements of the input tensor with probability 1-keep and - scale the other elements by a factor of 1/keep.""" - - def __init__(self, keep: float, generator=random.DEFAULT_GENERATOR): - """Creates Dropout module instance. - - Args: - keep: probability to keep element of the tensor. - generator: optional argument with instance of ObJAX random generator. - """ - self.keygen = generator - self.keep = keep - - def __call__(self, x: JaxArray, training: bool, dropout_keep: Optional[float] = None) -> JaxArray: - """Performs dropout of input tensor. - - Args: - x: input tensor. - training: if True then apply dropout to the input, otherwise keep input tensor unchanged. - dropout_keep: optional argument, when set overrides dropout keep probability. - - Returns: - Tensor with applied dropout. - """ - keep = dropout_keep or self.keep - if not training or keep >= 1: - return x - keep_mask = jr.bernoulli(self.keygen(), keep, x.shape) - return jn.where(keep_mask, x / keep, 0) - - def __repr__(self): - return f'{class_name(self)}(keep={self.keep})' - - -class ExponentialMovingAverage(Module): - """computes exponential moving average (also called EMA or EWMA) of an input batch.""" - - def __init__(self, shape: Tuple[int, ...], momentum: float = 0.999, init_value: float = 0): - """Creates a ExponentialMovingAverage module instance. - - Args: - shape: shape of the input tensor. - momentum: momentum for exponential decrease of accumulated value. - init_value: initial value for exponential moving average. - """ - self.momentum = momentum - self.init_value = init_value - self.avg = StateVar(jn.zeros(shape) + init_value) - - def __call__(self, x: JaxArray) -> JaxArray: - """Update the statistics using x and return the exponential moving average.""" - self.avg.value += (self.avg.value - x) * (self.momentum - 1) - return self.avg.value - - def __repr__(self): - s = self.avg.value.shape - return f'{class_name(self)}(shape={s}, momentum={self.momentum}, init_value={self.init_value})' - - -class Linear(Module): - """Applies a linear transformation on an input batch.""" - - def __init__(self, nin: int, nout: int, use_bias: bool = True, w_init: Callable = xavier_normal): - """Creates a Linear module instance. - - Args: - nin: number of channels of the input tensor. - nout: number of channels of the output tensor. - use_bias: if True then linear layer will have bias term. - w_init: weight initializer for linear layer (a function that takes in a IO shape and returns a 2D matrix). 
- """ - super().__init__() - self.w_init = w_init - self.b = TrainVar(jn.zeros(nout)) if use_bias else None - self.w = TrainVar(w_init((nin, nout))) - - def __call__(self, x: JaxArray) -> JaxArray: - """Returns the results of applying the linear transformation to input x.""" - y = jn.dot(x, self.w.value) - if self.b: - y += self.b.value - return y - - def __repr__(self): - s = self.w.value.shape - args = f'nin={s[0]}, nout={s[1]}, use_bias={self.b is not None}, w_init={util.repr_function(self.w_init)}' - return f'{class_name(self)}({args})' - - -class MovingAverage(Module): - """Computes moving average of an input batch.""" - - def __init__(self, shape: Tuple[int, ...], buffer_size: int, init_value: float = 0): - """Creates a MovingAverage module instance. - - Args: - shape: shape of the input tensor. - buffer_size: buffer size for moving average. - init_value: initial value for moving average buffer. - """ - self.init_value = init_value - self.buffer = StateVar(jn.zeros((buffer_size,) + shape) + init_value) - - def __call__(self, x: JaxArray) -> JaxArray: - """Update the statistics using x and return the moving average.""" - self.buffer.value = jn.concatenate([self.buffer.value[1:], x[None]]) - return self.buffer.value.mean(0) - - def __repr__(self): - s = self.buffer.value.shape - return f'{class_name(self)}(shape={s[1:]}, buffer_size={s[0]}, init_value={self.init_value})' - - -class Sequential(ModuleList): - """Executes modules in the order they were passed to the constructor.""" - - @staticmethod - def run_layer(layer: int, f: Callable, args: List, kwargs: Dict): - try: - return f(*args, **util.local_kwargs(kwargs, f)) - except Exception as e: - raise type(e)(f'Sequential layer[{layer}] {f} {e}') from e - - def __call__(self, *args, **kwargs) -> Union[JaxArray, List[JaxArray]]: - """Execute the sequence of operations contained on ``*args`` and ``**kwargs`` and return result.""" - if not self: - return args if len(args) > 1 else args[0] - for i, f in enumerate(self[:-1]): - args = self.run_layer(i, f, args, kwargs) - if not isinstance(args, tuple): - args = (args,) - return self.run_layer(len(self) - 1, self[-1], args, kwargs) - - def __getitem__(self, key: Union[int, slice]): - value = list.__getitem__(self, key) - if isinstance(key, slice): - return Sequential(value) - return value - - -class SyncedBatchNorm(BatchNorm): - """Synchronized batch normalization which aggregates batch statistics across all devices (GPUs/TPUs).""" - - def __call__(self, x: JaxArray, training: bool, batch_norm_update: bool = True) -> JaxArray: - if training: - m = functional.parallel.pmean(x.mean(self.redux, keepdims=True)) - v = functional.parallel.pmean(((x - m) ** 2).mean(self.redux, keepdims=True)) - if batch_norm_update: - self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value) - self.running_var.value += (1 - self.momentum) * (v - self.running_var.value) - else: - m, v = self.running_mean.value, self.running_var.value - y = self.gamma.value * (x - m) * functional.rsqrt(v + self.eps) + self.beta.value - return y - - -class SyncedBatchNorm0D(SyncedBatchNorm): - """Applies a 0D synchronized batch normalization on a 2D-input batch of shape (N,C). - - Synchronized batch normalization aggregated batch statistics across all devices (GPUs/TPUs) on each call. - Compared to regular batch norm this usually leads to better accuracy at a slight performance cost. 
- """ - - def __init__(self, nin: int, momentum: float = 0.999, eps: float = 1e-6): - """Creates a SyncedBatchNorm0D module instance. - - Args: - nin: number of channels in the input example. - momentum: value used to compute exponential moving average of batch statistics. - eps: small value which is used for numerical stability. - """ - super().__init__((1, nin), (0,), momentum, eps) - - def __repr__(self): - return f'{class_name(self)}(nin={self.beta.value.shape[1]}, momentum={self.momentum}, eps={self.eps})' - - -class SyncedBatchNorm1D(SyncedBatchNorm): - """Applies a 1D synchronized batch normalization on a 3D-input batch of shape (N,C,L). - - Synchronized batch normalization aggregated batch statistics across all devices (GPUs/TPUs) on each call. - Compared to regular batch norm this usually leads to better accuracy at a slight performance cost. - """ - - def __init__(self, nin: int, momentum: float = 0.999, eps: float = 1e-6): - """Creates a SyncedBatchNorm1D module instance. - - Args: - nin: number of channels in the input example. - momentum: value used to compute exponential moving average of batch statistics. - eps: small value which is used for numerical stability. - """ - super().__init__((1, nin, 1), (0, 2), momentum, eps) - - def __repr__(self): - return f'{class_name(self)}(nin={self.beta.value.shape[1]}, momentum={self.momentum}, eps={self.eps})' - - -class SyncedBatchNorm2D(SyncedBatchNorm): - """Applies a 2D synchronized batch normalization on a 4D-input batch of shape (N,C,H,W). - - Synchronized batch normalization aggregated batch statistics across all devices (GPUs/TPUs) on each call. - Compared to regular batch norm this usually leads to better accuracy at a slight performance cost. - """ - - def __init__(self, nin: int, momentum: float = 0.999, eps: float = 1e-6): - """Creates a SyncedBatchNorm2D module instance. - - Args: - nin: number of channels in the input example. - momentum: value used to compute exponential moving average of batch statistics. - eps: small value which is used for numerical stability. - """ - super().__init__((1, nin, 1, 1), (0, 2, 3), momentum, eps) - - def __repr__(self): - return f'{class_name(self)}(nin={self.beta.value.shape[1]}, momentum={self.momentum}, eps={self.eps})' diff --git a/objax/nn/rnn_redesign.py b/objax/nn/rnn_redesign.py new file mode 100644 index 0000000..05506a7 --- /dev/null +++ b/objax/nn/rnn_redesign.py @@ -0,0 +1,233 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax.numpy as jn +import objax + +from objax.typing import JaxArray +from typing import Callable, Tuple, Union + + +class MyRnnCell(objax.Module): + """ Simple RNN cell.""" + + def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh): + """Creates a MyRnnCell instance. + Args: + nin: dimension of the input tensor. + nstate: hidden state tensor has dimensions ``nin`` by ``nstate``. + activation: activation function for the hidden state layer. 
+ """ + self.op = objax.nn.Sequential([objax.nn.Linear(nin + nstate, nstate), + objax.nn.Linear(nstate, nstate), activation]) + + def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray: + """Updates and returns hidden state based on input sequence ``x``and input ``state``.""" + return self.op(jn.concatenate((x, state), axis=0)) + + +class FactorizedRnnCell(objax.Module): + """ Factorized version of RNN cell.""" + + def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh): + """Creates a MyRnnCell instance. + Args: + nin: dimension of the input tensor. + nstate: hidden state tensor has dimensions ``nin`` by ``nstate``. + activation: activation function for the hidden state layer. + """ + self.win = objax.nn.Linear(nin, nstate, use_bias=False) + self.wn = objax.nn.Linear(nstate, nstate) + self.activation = activation + self.nstate = nstate + + def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray: + """Updates and returns hidden state based on input sequence ``x``and input ``state``.""" + self.factor = self.win(x) + # TODO(aterzis): Replace with scan(?) + output = [] + for i in range(x.shape[0]): + state = self.activation(self.wn(state) + self.factor[i]) + output_i = jn.reshape(state, (1, self.nstate)) + output.append(output_i) + outputs = jn.concatenate(output, axis=0) + return jn.reshape(state, (1, self.nstate)) + + +class DDLRnnCell(objax.Module): + """ Another simple RNN cell.""" + + def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh): + """ Creates a DDLRnnCell instance. + Args: + nin: dimension of the input tensor. + nstate: hidden state tensor has dimensions ``nin`` by ``nstate``. + activation: activation function for the hidden state layer. + """ + self.wxh = objax.nn.Linear(nin, nstate, use_bias=False) + self.whh = objax.nn.Linear(nstate, nstate) + self.activation = activation + + def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray: + """Updates and returns hidden state based on input sequence ``x`` and input ``state``.""" + return self.activation(self.whh(state) + self.wxh(x)) + + +def output_layer(state: int, nout: int): + return objax.nn.Linear(state, nout) + + +class RNN(objax.Module): + """Simple Recurrent Neural Network (RNN). + State update is done by the provided RNN cell and output is generated by the + provided output layer. + """ + + def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]): + """Creates an RNN instance. + Args: + cell: RNN cell. + output_layer: output layer can be a function or another module. + """ + self.cell = cell + self.output_layer = output_layer # Is it better inside or outside? + + def single(self, state_i: JaxArray, x_i: JaxArray) -> Tuple[JaxArray, JaxArray]: + """Execute one step of the RNN. + Args: + state_i: current state. + x_i: input. + Returns: + next state and next output. + """ + next_state = self.cell(state_i, x_i) + next_output = self.output_layer(next_state) + return next_state, next_output + + def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]: + """Sequentially processes input to generate output. + Args: + x: input tensor with dimensions ``batch_size`` by ``sequence_length`` by ``nin`` + state: Initial RNN state with dimensions ``batch_size`` by ``state``. + Returns: + Tuple with final RNN state and output with dimensions ``sequence_length`` by ``batch_size`` by ``nout``, + where ``nout`` is the output dimension of the output layer (or ``state`` if there is no output layer). 
+ """ + final_state, output = objax.functional.scan(self.single, state, x.transpose((1, 0, 2))) + return output, final_state + + +class VectorizedRNN(objax.Module): + """Vectorized Recurrent Neural Network (RNN). + State update is done by the provided RNN cell and output is generated by the + provided output layer. + """ + + def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]): + """Creates an RNN instance. + Args: + cell: RNN cell. + output_layer: output layer can be a function or another module. + """ + self.cell = cell + self.output_layer = output_layer # Is it better inside or outside? + + def single(self, state_i: JaxArray, x_i: JaxArray) -> Tuple[JaxArray, JaxArray]: + """Execute one step of the RNN. + Args: + state_i: current state. + x_i: input. + Returns: + next state and next output. + """ + next_state = self.cell(state_i, x_i) + next_output = self.output_layer(next_state) + return next_state, next_output + + def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]: + """Sequentially processes input to generate output. + Args: + x: input tensor with dimensions ``sequence_length`` by ``nin`` + state: Initial RNN state with dimensions ``(nstate,)``. + Returns: + Tuple with final RNN state and output with dimensions ``sequence_length`` by ``nout``, + where ``nout`` is the output dimension of the output layer (or ``nstate`` if there is no output layer). + """ + final_state, output = objax.functional.scan(self.single, state, x) + return output, final_state + + +class FactorizedRNN(objax.Module): + """Factorized Recurrent Neural Network (RNN). + State update is done by the provided RNN cell and output is generated by the + provided output layer. + """ + + def __init__(self, cell: objax.Module, output_layer: Union[objax.Module, Callable]): + """Creates an RNN instance. + Args: + cell: RNN cell. + output_layer: output layer can be a function or another module. + """ + self.cell = cell + self.output_layer = output_layer # Is it better inside or outside? + + + def __call__(self, x: JaxArray, state: JaxArray) -> Tuple[JaxArray, JaxArray]: + """Sequentially processes input to generate output. + Args: + x: input tensor with dimensions ``sequence_length`` by ``nin`` + state: Initial RNN state with dimensions ``(state,)``. + Returns: + Tuple with final RNN state and output with dimensions ``sequence_length`` by ``nout``, + where ``nout`` is the output dimension of the output layer (or ``state`` if there is no output layer). 
+ """ + out = self.cell(state, x) + output = self.output_layer(out) + return output, out[-1] + +seq, ns, nin, nout, batch = 7, 10, 3, 4, 64 + +# RNN example +rnn_cell = DDLRnnCell(nin, ns) +out_layer = output_layer(ns, nout) + +r = RNN(rnn_cell, out_layer) +x = objax.random.normal((batch, seq, nin)) +s = jn.zeros((batch, ns)) +y1 = r(x, s) + +# Vectorized version +r = VectorizedRNN(rnn_cell, out_layer) +rnn_vec = objax.Vectorize(r, batch_axis=(0, 0)) +y4 = rnn_vec(x, s) + +assert jn.array_equal(y4[1], y1[1]) +assert jn.array_equal(y4[0], y1[0].transpose((1, 0, 2))) + +# Factorized version +factorized_cell = FactorizedRnnCell(nin, ns) +out = factorized_cell(s[0], x[0, :, :]) +out2 = out_layer(out) + +print("s.shape", s.shape) +print("out2.shape", out2.shape) + +f = FactorizedRNN(factorized_cell, out_layer) +y5 = f(x[0, :, :], s[0]) +print("y5[0].shape", y5[0].shape) +print("y5[1].shape", y5[1].shape) +assert jn.array_equal(y5[0], out2) +assert jn.array_equal(y5[1], out[-1]) diff --git a/objax/zoo/gru.py b/objax/zoo/gru.py new file mode 100644 index 0000000..d9fab2d --- /dev/null +++ b/objax/zoo/gru.py @@ -0,0 +1,93 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +import jax.numpy as jn + +from objax import Module +from objax.nn.init import kaiming_normal +from objax.typing import JaxArray +from objax.variable import TrainVar, StateVar +from objax.functional import sigmoid + + +class GRU(Module): + """ Gated Recurrent Unit (GRU) block.""" + + def __init__(self, + nstate: int, + nin: int, + nout: int, + w_init: Callable = kaiming_normal): + """Creates a GRU instance. + + Args: + nstate: number of hidden units. + nin: number of input units. + nout: number of output units. + w_init: weight initializer for GRU model weights. + """ + self.num_inputs = nin + self.num_outputs = nout + self.nstate = nstate + + # Update gate parameters + self.w_xz = TrainVar(w_init((self.num_inputs, self.nstate))) + self.w_hz = TrainVar(w_init((self.nstate, self.nstate))) + self.b_z = TrainVar(jn.zeros(self.nstate)) + + # Reset gate parameters + self.w_xr = TrainVar(w_init((self.num_inputs, self.nstate))) + self.w_hr = TrainVar(w_init((self.nstate, self.nstate))) + self.b_r = TrainVar(jn.zeros(self.nstate)) + + # Candidate hidden state parameters + self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate))) + self.w_hh = TrainVar(w_init((self.nstate, self.nstate))) + self.b_h = TrainVar(jn.zeros(self.nstate)) + + # Output layer parameters + self.w_hq = TrainVar(w_init((self.nstate, self.num_outputs))) + self.b_q = TrainVar(jn.zeros(self.num_outputs)) + + def init_state(self, batch_size): + """Initialize hidden state for input batch of size ``batch_size``.""" + self.state = StateVar(jn.zeros((batch_size, self.nstate))) + + def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray: + """Forward pass through GRU. + + Args: + inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``. 
+ only_return_final: return only the last output if ``True``, or all output otherwise. + + Returns: + Output tensor with dimensions ``num_steps * batch_size, vocabulary_size``. + """ + # Dimensions: num_steps, batch_size, vocab_size + outputs = [] + for x in inputs: + update_gate = sigmoid(jn.dot(x, self.w_xz.value) + jn.dot(self.state.value, self.w_hz.value) + + self.b_z.value) + reset_gate = sigmoid(jn.dot(x, self.w_xr.value) + jn.dot(self.state.value, self.w_hr.value) + + self.b_r.value) + candidate_state = jn.tanh(jn.dot(x, self.w_xh.value) + + jn.dot(reset_gate * self.state.value, self.w_hh.value) + self.b_h.value) + self.state.value = update_gate * self.state.value + (1 - update_gate) * candidate_state + y = jn.dot(self.state.value, self.w_hq.value) + self.b_q.value + outputs.append(y) + if only_return_final: + return outputs[-1] + return jn.concatenate(outputs, axis=0) diff --git a/objax/zoo/rnn.py b/objax/zoo/rnn.py index ddc3c15..cf9b104 100644 --- a/objax/zoo/rnn.py +++ b/objax/zoo/rnn.py @@ -24,58 +24,6 @@ class RNN(Module): - """ Recurrent Neural Network (RNN) block.""" - - def __init__(self, - nstate: int, - nin: int, - nout: int, - activation: Callable = jn.tanh, - w_init: Callable = kaiming_normal): - """Creates an RNN instance. - - Args: - nstate: number of hidden units. - nin: number of input units. - nout: number of output units. - activation: actication function for hidden layer. - w_init: weight initializer for RNN model weights. - """ - self.num_inputs = nin - self.num_outputs = nout - self.nstate = nstate - self.activation = activation - - # Hidden layer parameters - self.w_xh = TrainVar(w_init((self.num_inputs, self.nstate))) - self.w_hh = TrainVar(w_init((self.nstate, self.nstate))) - self.b_h = TrainVar(jn.zeros(self.nstate)) - - self.output_layer = Linear(self.nstate, self.num_outputs) - def init_state(self, batch_size): """Initialize hidden state for input batch of size ``batch_size``.""" self.state = StateVar(jn.zeros((batch_size, self.nstate))) - - def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray: - """Forward pass through RNN. - - Args: - inputs: ``JaxArray`` with dimensions ``num_steps, batch_size, vocabulary_size``. - only_return_final: return only the last output if ``True``, or all output otherwise.` - - Returns: - Output tensor with dimensions ``num_steps * batch_size, vocabulary_size``. - """ - # Dimensions: num_steps, batch_size, vocab_size - outputs = [] - for x in inputs: - self.state.value = self.activation( - jn.dot(x, self.w_xh.value) - + jn.dot(self.state.value, self.w_hh.value) - + self.b_h.value) - y = self.output_layer(self.state.value) - outputs.append(y) - if only_return_final: - return outputs[-1] - return jn.concatenate(outputs, axis=0) diff --git a/requirements.txt b/requirements.txt index 8f33c1e..10fd7ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ scipy -numpy>=1.18.0 -pillow +numpy~=1.17.4 +pillow~=7.2.0 jaxlib jax tensorboard>=2.3.0 parameterized + +setuptools~=46.1.3 \ No newline at end of file diff --git a/tests/testrandom.py b/tests/testrandom.py deleted file mode 100644 index 0e34da2..0000000 --- a/tests/testrandom.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unittests for objax.random.""" - -import unittest - -import numpy as np -import scipy.stats - -import objax - - -class TestRandom(unittest.TestCase): - def helper_test_randint(self, shape, low, high): - """Helper function to test objax.random.randint.""" - value = objax.random.randint(shape, low, high) - self.assertEqual(value.shape, shape) - self.assertTrue(np.all(value >= low)) - self.assertTrue(np.all(value < high)) - - def test_randint(self): - """Test for objax.random.randint.""" - objax.random.DEFAULT_GENERATOR.seed(123) - self.helper_test_randint(shape=(3, 4), low=1, high=10) - self.helper_test_randint(shape=(5,), low=0, high=5) - self.helper_test_randint(shape=(), low=-5, high=5) - - def helper_test_normal(self, shape, stddev): - """Helper function to test objax.random.normal.""" - value = objax.random.normal(shape, stddev=stddev) - self.assertEqual(value.shape, shape) - - def test_normal(self): - """Test for objax.random.normal.""" - objax.random.DEFAULT_GENERATOR.seed(123) - self.helper_test_normal(shape=(4, 2, 3), stddev=1.0) - self.helper_test_normal(shape=(2, 3), stddev=2.0) - self.helper_test_normal(shape=(5,), stddev=2.0) - self.helper_test_normal(shape=(), stddev=10.0) - value = np.array(objax.random.normal((1000, 100))) - self.assertAlmostEqual(value.mean(), 0, delta=0.01) - self.assertAlmostEqual(value.std(), 1, delta=0.01) - value = np.array(objax.random.normal((1000, 100), mean=0, stddev=2)) - self.assertAlmostEqual(value.mean(), 0, delta=0.02) - self.assertAlmostEqual(value.std(), 2, delta=0.01) - value = np.array(objax.random.normal((1000, 100), mean=1, stddev=1.5)) - self.assertAlmostEqual(value.mean(), 1, delta=0.015) - self.assertAlmostEqual(value.std(), 1.5, delta=0.01) - - def helper_test_truncated_normal(self, shape, stddev, bound): - """Helper function to test objax.random.truncated_normal.""" - value = objax.random.truncated_normal(shape, stddev=stddev, lower=-bound, upper=bound) - self.assertEqual(value.shape, shape) - self.assertTrue(np.all(value >= -bound * stddev)) - self.assertTrue(np.all(value <= bound * stddev)) - - def test_truncated_normal(self): - """Test for objax.random.truncated_normal.""" - objax.random.DEFAULT_GENERATOR.seed(123) - self.helper_test_truncated_normal(shape=(5, 7), stddev=1.0, bound=2.0) - self.helper_test_truncated_normal(shape=(4,), stddev=2.0, bound=4.0) - self.helper_test_truncated_normal(shape=(), stddev=1.0, bound=4.0) - value = np.array(objax.random.truncated_normal((1000, 100))) - truncated_std = scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1) - self.assertAlmostEqual(value.mean(), 0, delta=0.01) - self.assertAlmostEqual(value.std(), truncated_std, delta=0.01) - self.assertAlmostEqual(value.min(), -1.9, delta=0.1) - self.assertAlmostEqual(value.max(), 1.9, delta=0.1) - value = np.array(objax.random.truncated_normal((1000, 100), stddev=2, lower=-3, upper=3)) - truncated_std = scipy.stats.truncnorm.std(a=-3, b=3, loc=0., scale=2) - self.assertAlmostEqual(value.mean(), 0, delta=0.02) - self.assertAlmostEqual(value.std(), truncated_std, delta=0.01) - self.assertAlmostEqual(value.min(), -5.9, delta=0.1) - 
self.assertAlmostEqual(value.max(), 5.9, delta=0.1) - value = np.array(objax.random.truncated_normal((1000, 100), stddev=1.5)) - truncated_std = scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.5) - self.assertAlmostEqual(value.mean(), 0, delta=0.015) - self.assertAlmostEqual(value.std(), truncated_std, delta=0.01) - self.assertAlmostEqual(value.min(), -2.9, delta=0.1) - self.assertAlmostEqual(value.max(), 2.9, delta=0.1) - - def helper_test_uniform(self, shape): - """Helper function to test objax.random.uniform.""" - value = objax.random.uniform(shape) - self.assertEqual(value.shape, shape) - self.assertTrue(np.all(value >= 0.0)) - self.assertTrue(np.all(value < 1.0)) - - def test_uniform(self): - """Test for objax.random.uniform.""" - objax.random.DEFAULT_GENERATOR.seed(123) - self.helper_test_uniform(shape=(4, 3)) - self.helper_test_uniform(shape=(5,)) - self.helper_test_uniform(shape=()) - - def test_generator(self): - """Test for objax.random.Generator.""" - g1 = objax.random.Generator(0) - g2 = objax.random.Generator(0) - g3 = objax.random.Generator(1) - value1 = objax.random.randint((3, 4), low=0, high=65536, generator=g1) - value2 = objax.random.randint((3, 4), low=0, high=65536, generator=g2) - value3 = objax.random.randint((3, 4), low=0, high=65536, generator=g3) - self.assertTrue(np.all(value1 == value2)) - self.assertFalse(np.all(value1 == value3)) - - g4 = objax.random.Generator(123) - value = [objax.random.randint(shape=(1,), low=0, high=65536, generator=g4) for _ in range(2)] - self.assertNotEqual(value[0], value[1]) - - -if __name__ == '__main__': - unittest.main()
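A minimal usage sketch for the GRU module added in objax/zoo/gru.py above. This is not part of the diff: the sizes are illustrative assumptions and it assumes the new file is importable as objax.zoo.gru. The module expects inputs of shape (num_steps, batch_size, vocabulary_size) and init_state must be called before the forward pass:

import objax
from objax.zoo.gru import GRU

num_steps, batch_size, vocab_size, nstate = 5, 8, 32, 64  # illustrative sizes
gru = GRU(nstate=nstate, nin=vocab_size, nout=vocab_size)
gru.init_state(batch_size)  # allocate the hidden StateVar before calling the module

x = objax.random.normal((num_steps, batch_size, vocab_size))
y_all = gru(x)  # shape (num_steps * batch_size, vocab_size)

gru.init_state(batch_size)  # reset the state before a fresh pass
y_last = gru(x, only_return_final=True)  # shape (batch_size, vocab_size)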
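The TODO in FactorizedRnnCell above asks whether the Python loop over time steps could be replaced with scan. The sketch below is one possible answer using jax.lax.scan; it is an assumption about how that TODO might be resolved, not part of the diff, and the class name is hypothetical. The carry is the hidden state and the stacked per-step states are the outputs, so it mirrors the loop-based cell:

from typing import Callable

import jax
import objax
from objax.typing import JaxArray


class ScanFactorizedRnnCell(objax.Module):
    """Hypothetical scan-based variant of FactorizedRnnCell."""

    def __init__(self, nin: int, nstate: int, activation: Callable = objax.functional.tanh):
        self.win = objax.nn.Linear(nin, nstate, use_bias=False)
        self.wn = objax.nn.Linear(nstate, nstate)
        self.activation = activation

    def __call__(self, state: JaxArray, x: JaxArray) -> JaxArray:
        factor = self.win(x)  # precompute the input contribution for all steps, shape (seq, nstate)

        def step(carry, factor_i):
            carry = self.activation(self.wn(carry) + factor_i)
            return carry, carry  # new carry and per-step output are the same tensor

        final_state, outputs = jax.lax.scan(step, state, factor)
        return outputs  # shape (seq, nstate), matching the loop-based cell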