diff --git a/examples/vgg/keras_vgg.py b/examples/vgg/keras_vgg.py
new file mode 100644
index 0000000..1a856c1
--- /dev/null
+++ b/examples/vgg/keras_vgg.py
@@ -0,0 +1,19 @@
+import numpy as np
+import tensorflow as tf
+
+import objax
+from objax.zoo import vgg
+
+mo = vgg.VGG16()
+vgg.load_pretrained_weights_from_keras(mo)
+print(mo.vars())
+
+mk = tf.keras.applications.VGG16(weights='imagenet')
+x = np.random.randn(4, 3, 224, 224)
+yk = mk(x.transpose((0, 2, 3, 1)))  # (4, 1000)
+
+for name, param in ((weight.name, weight.numpy()) for layer in mk.layers for weight in layer.weights):
+    print(f'{name:40s} {tuple(param.shape)}')
+
+yo = objax.functional.softmax(mo(x, training=False))
+print('Max difference:', np.abs(yk - yo).max())
diff --git a/examples/vgg/pytorch_vgg.py b/examples/vgg/pytorch_vgg.py
new file mode 100644
index 0000000..511f4c8
--- /dev/null
+++ b/examples/vgg/pytorch_vgg.py
@@ -0,0 +1,20 @@
+import jax.numpy as jn
+import torch
+import torchvision
+
+from objax.zoo import vgg
+
+mo = vgg.VGG16()
+vgg.load_pretrained_weights_from_pytorch(mo)
+print(mo.vars())
+
+mt = torchvision.models.vgg16(pretrained=True)
+mt.eval()  # Disable dropout; forgetting this silently changes the outputs.
+x = torch.randn((4, 3, 224, 224))
+yt = mt(x)  # (4, 1000)
+
+for name, param in mt.state_dict().items():
+    print(f'{name:40s} {tuple(param.shape)}')
+
+yo = mo(x.numpy(), training=False)
+print('Max difference:', jn.abs(yt.detach().numpy() - yo).max())
diff --git a/objax/functional/core/pooling.py b/objax/functional/core/pooling.py
index 0a9600c..6520207 100644
--- a/objax/functional/core/pooling.py
+++ b/objax/functional/core/pooling.py
@@ -22,7 +22,7 @@
 
 from objax.constants import ConvPadding
 from objax.typing import JaxArray, ConvPaddingInt
-from objax.util import to_padding, to_tuple
+from objax.util.util import to_padding, to_tuple
 
 
 def average_pool_2d(x: JaxArray,
diff --git a/objax/io/ops.py b/objax/io/ops.py
index bc500b0..76beacb 100644
--- a/objax/io/ops.py
+++ b/objax/io/ops.py
@@ -21,7 +21,7 @@
 import jax.numpy as jn
 import numpy as np
 
-from objax.util import Renamer
+from objax.util.util import Renamer
 from objax.variable import TrainRef, VarCollection
 
 
diff --git a/objax/module.py b/objax/module.py
index dc19e60..f85b469 100644
--- a/objax/module.py
+++ b/objax/module.py
@@ -22,7 +22,7 @@
 from jax.interpreters.pxla import ShardedDeviceArray
 
 from objax.typing import JaxArray
-from objax.util import override_args_kwargs, positional_args_names
+from objax.util.util import override_args_kwargs, positional_args_names
 from objax.variable import BaseVar, RandomState, VarCollection
 
 
diff --git a/objax/typing.py b/objax/typing.py
index 1c8a083..0ffd422 100644
--- a/objax/typing.py
+++ b/objax/typing.py
@@ -14,7 +14,7 @@
 
 """This module contains type declarations for Objax."""
 
-__all__ = ['FileOrStr', 'JaxArray', 'JaxDType']
+__all__ = ['ConvPaddingInt', 'FileOrStr', 'JaxArray', 'JaxDType']
 
 from typing import Union, IO, BinaryIO, Sequence, Tuple
 
diff --git a/objax/util/__init__.py b/objax/util/__init__.py
index 2b2dbff..3fcf77a 100644
--- a/objax/util/__init__.py
+++ b/objax/util/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import image
 from . import check
+from . import convert
+from . import image
 from .util import *
diff --git a/objax/util/convert/__init__.py b/objax/util/convert/__init__.py
new file mode 100644
index 0000000..c62d85b
--- /dev/null
+++ b/objax/util/convert/__init__.py
@@ -0,0 +1,3 @@
+from . import keras
+from . import pytorch
+from .convert import *
diff --git a/objax/util/convert/convert.py b/objax/util/convert/convert.py
new file mode 100644
index 0000000..cc59bd2
--- /dev/null
+++ b/objax/util/convert/convert.py
@@ -0,0 +1,28 @@
+__all__ = ['assign', 'import_weights']
+
+import re
+from typing import Dict, Callable
+
+import jax.numpy as jn
+import numpy as np
+
+from objax.variable import BaseVar, VarCollection
+
+
+def assign(x: BaseVar, v: np.ndarray):
+    """Reshape a source numpy array to the variable's shape and assign it."""
+    x.assign(jn.array(v.reshape(x.value.shape)))
+
+
+def import_weights(target_vc: VarCollection,
+                   source_numpy: Dict[str, np.ndarray],
+                   target_to_source_names: Dict[str, str],
+                   numpy_convert: Dict[str, Callable[[BaseVar, np.ndarray], None]]):
+    """Copy arrays from source_numpy into target_vc, dispatching on the
+    '(Module).variable' suffix of each target name via numpy_convert."""
+    module_var = re.compile(r'.*(\([^)]*\)\.[^(]*)$')
+    for k, v in target_vc.items():
+        s = target_to_source_names[k]
+        t = module_var.match(k).group(1)
+        if s not in source_numpy:
+            print(f'Skipping {k} ({s})')
+            continue
+        assert t in numpy_convert, f'Unhandled name {k}'
+        numpy_convert[t](v, source_numpy[s])
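+
+
+# A minimal usage sketch (`torch_model` is hypothetical; this mirrors what
+# objax/zoo/vgg.py's load_pretrained_weights_from_pytorch does):
+#   from objax.util.convert import import_weights, pytorch
+#   from objax.zoo.vgg import VGG16
+#   vc = VGG16().vars()
+#   source = {name: p.numpy() for name, p in torch_model.state_dict().items()}
+#   names = {k: pytorch.rename(k) for k in vc.keys()}
+#   import_weights(vc, source, names, pytorch.ARRAY_CONVERT)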
diff --git a/objax/util/convert/keras.py b/objax/util/convert/keras.py
new file mode 100644
index 0000000..a2d506d
--- /dev/null
+++ b/objax/util/convert/keras.py
@@ -0,0 +1,27 @@
+__all__ = ['ARRAY_CONVERT', 'rename']
+
+import re
+
+from .convert import assign
+
+ARRAY_CONVERT = {
+    # '(BatchNorm2D).beta': assign,
+    # '(BatchNorm2D).gamma': assign,
+    # '(BatchNorm2D).running_mean': assign,
+    # '(BatchNorm2D).running_var': assign,
+    '(Conv2D).b': assign,
+    '(Conv2D).w': assign,
+    '(Linear).b': assign,
+    '(Linear).w': assign,
+}
+
+
+def rename(x):
+    # x = x.replace('(BatchNorm2D).gamma', '(BatchNorm2D).weight').replace('(BatchNorm2D).beta', '(BatchNorm2D).bias')
+    x = re.sub(r'\([^)]*\)', '', x)
+    x = re.sub(r'^\.', '', x)
+    x = re.sub(r'\.w$', '/kernel', x)  # Escape the dot so only a '.w' suffix matches.
+    x = re.sub(r'\.b$', '/bias', x)
+    x = re.sub(r'\[|\]', '', x)
+    x = re.sub(r'\.', '_', x)
+    return x
diff --git a/objax/util/convert/pytorch.py b/objax/util/convert/pytorch.py
new file mode 100644
index 0000000..143e457
--- /dev/null
+++ b/objax/util/convert/pytorch.py
@@ -0,0 +1,26 @@
+__all__ = ['ARRAY_CONVERT', 'rename']
+
+import re
+
+from .convert import assign
+
+ARRAY_CONVERT = {
+    '(BatchNorm2D).beta': assign,
+    '(BatchNorm2D).gamma': assign,
+    '(BatchNorm2D).running_mean': assign,
+    '(BatchNorm2D).running_var': assign,
+    '(Conv2D).b': assign,
+    '(Conv2D).w': lambda x, y: assign(x, y.transpose((2, 3, 1, 0))),  # OIHW -> HWIO
+    '(Linear).b': assign,
+    '(Linear).w': lambda x, y: assign(x, y.T),
+}
+
+
+def rename(x):
+    x = x.replace('(BatchNorm2D).gamma', '(BatchNorm2D).weight').replace('(BatchNorm2D).beta', '(BatchNorm2D).bias')
+    x = re.sub(r'\([^)]*\)', '', x)
+    x = re.sub(r'^\.', '', x)
+    x = re.sub(r'\.w$', '.weight', x)  # Escape the dot so only a '.w' suffix matches.
+    x = re.sub(r'\.b$', '.bias', x)
+    x = x.replace('[', '.').replace(']', '')
+    return x
diff --git a/objax/variable.py b/objax/variable.py
index 1330baf..1f9c72a 100644
--- a/objax/variable.py
+++ b/objax/variable.py
@@ -24,8 +24,8 @@
 import numpy as np
 
 from objax.typing import JaxArray
-from objax.util import map_to_device, Renamer
 from objax.util.check import assert_assigned_type_and_shape_match
+from objax.util.util import map_to_device, Renamer
 
 
 def reduce_mean(x: JaxArray) -> JaxArray:
diff --git a/objax/zoo/vgg.py b/objax/zoo/vgg.py
index c60db77..20fadcf 100644
--- a/objax/zoo/vgg.py
+++ b/objax/zoo/vgg.py
@@ -1,144 +1,156 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Module with VGG-19 implementation.
-
-See https://arxiv.org/abs/1409.1556 for detail.
-"""
-
-import functools
-import os
-from urllib import request
-
-import jax.numpy as jn
-import numpy as np
+__all__ = ['VGG', 'VGG11', 'VGG11_BN', 'VGG13', 'VGG13_BN', 'VGG16', 'VGG16_BN', 'VGG19', 'VGG19_BN',
+           'load_pretrained_weights_from_keras', 'load_pretrained_weights_from_pytorch']
+
+from typing import Union, Sequence
 
 import objax
+from objax.util.convert import import_weights, pytorch
+
+# Integers are conv output channels; 'M' marks a 2x2 max-pooling step.
+OPS = dict(
+    vgg11=(64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'),
+    vgg13=(64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'),
+    vgg16=(64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'),
+    vgg19=(64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M')
+)
+
 
-_VGG19_URL = 'https://github.com/machrisaa/tensorflow-vgg'
-_VGG19_NPY = './objax/zoo/pretrained/vgg19.npy'
-_SYNSET_URL = 'https://raw.githubusercontent.com/machrisaa/tensorflow-vgg/master/synset.txt'
-_SYNSET_PATH = './objax/zoo/pretrained/synset.txt'
-
-
-def preprocess(x):
-    bgr_mean = [103.939, 116.779, 123.68]
-    red, green, blue = [x[:, i, :, :] for i in range(3)]
-    return jn.stack([blue - bgr_mean[0], green - bgr_mean[1], red - bgr_mean[2]], axis=1)
-
-
-def max_pool_2d(x):
-    return functools.partial(objax.functional.max_pool_2d,
-                             size=2, strides=2, padding=objax.constants.ConvPadding.VALID)(x)
-
-
-class VGG19(objax.nn.Sequential):
-    """VGG19 implementation."""
-
-    def __init__(self, pretrained=False):
-        """Creates VGG19 instance.
-
-        Args:
-            pretrained: if True load weights from ImageNet pretrained model.
-        """
-        if not os.path.exists(_VGG19_NPY):
-            raise FileNotFoundError(
-                'You must download vgg19.npy from %s and save it to %s' % (_VGG19_URL, _VGG19_NPY))
-        if not os.path.exists(_SYNSET_PATH):
-            request.urlretrieve(_SYNSET_URL, _SYNSET_PATH)
-        self.data_dict = np.load(_VGG19_NPY, encoding='latin1', allow_pickle=True).item()
-        self.pretrained = pretrained
-        self.ops = self.build()
-        super().__init__(self.ops)
-
-    def build(self):
-        # inputs in [0, 255]
-        self.preprocess = preprocess
-        self.conv1_1 = objax.nn.Conv2D(nin=3, nout=64, k=3)
-        self.relu1_1 = objax.functional.relu
-        self.conv1_2 = objax.nn.Conv2D(nin=64, nout=64, k=3)
-        self.relu1_2 = objax.functional.relu
-        self.pool1 = max_pool_2d
-
-        self.conv2_1 = objax.nn.Conv2D(nin=64, nout=128, k=3)
-        self.relu2_1 = objax.functional.relu
-        self.conv2_2 = objax.nn.Conv2D(nin=128, nout=128, k=3)
-        self.relu2_2 = objax.functional.relu
-        self.pool2 = max_pool_2d
-
-        self.conv3_1 = objax.nn.Conv2D(nin=128, nout=256, k=3)
-        self.relu3_1 = objax.functional.relu
-        self.conv3_2 = objax.nn.Conv2D(nin=256, nout=256, k=3)
-        self.relu3_2 = objax.functional.relu
-        self.conv3_3 = objax.nn.Conv2D(nin=256, nout=256, k=3)
-        self.relu3_3 = objax.functional.relu
-        self.conv3_4 = objax.nn.Conv2D(nin=256, nout=256, k=3)
-        self.relu3_4 = objax.functional.relu
-        self.pool3 = max_pool_2d
-
-        self.conv4_1 = objax.nn.Conv2D(nin=256, nout=512, k=3)
-        self.relu4_1 = objax.functional.relu
-        self.conv4_2 = objax.nn.Conv2D(nin=512, nout=512, k=3)
-        self.relu4_2 = objax.functional.relu
-        self.conv4_3 = objax.nn.Conv2D(nin=512, nout=512, k=3)
-        self.relu4_3 = objax.functional.relu
-        self.conv4_4 = objax.nn.Conv2D(nin=512, nout=512, k=3)
-        self.relu4_4 = objax.functional.relu
-        self.pool4 = max_pool_2d
-
-        self.conv5_1 = objax.nn.Conv2D(nin=512, nout=512, k=3)
-        self.relu5_1 = objax.functional.relu
-        self.conv5_2 = objax.nn.Conv2D(nin=512, nout=512, k=3)
-        self.relu5_2 = objax.functional.relu
-        self.conv5_3 = objax.nn.Conv2D(nin=512, nout=512, k=3)
-        self.relu5_3 = objax.functional.relu
-        self.conv5_4 = objax.nn.Conv2D(nin=512, nout=512, k=3)
-        self.relu5_4 = objax.functional.relu
-        self.pool5 = max_pool_2d
-
-        self.flatten = objax.functional.flatten
-        self.fc6 = objax.nn.Linear(nin=512 * 7 * 7, nout=4096)
-        self.relu6 = objax.functional.relu
-        self.fc7 = objax.nn.Linear(nin=4096, nout=4096)
-        self.relu7 = objax.functional.relu
-        self.fc8 = objax.nn.Linear(nin=4096, nout=1000)
-
-        if self.pretrained:
-            for it in self.data_dict:
-                if it.startswith('conv'):
-                    conv = getattr(self, it)
-                    kernel, bias = self.data_dict[it]
-                    conv.w = objax.TrainVar(jn.array(kernel))
-                    conv.b = objax.TrainVar(jn.array(bias[:, None, None]))
-                    setattr(self, it, conv)
-                elif it.startswith('fc'):
-                    linear = getattr(self, it)
-                    kernel, bias = self.data_dict[it]
-                    if it == 'fc6':
-                        kernel = kernel.reshape([7, 7, 512, -1]).transpose((2, 0, 1, 3)).reshape([512 * 7 * 7, -1])
-                    linear.w = objax.TrainVar(jn.array(kernel))
-                    linear.b = objax.TrainVar(jn.array(bias))
-                    setattr(self, it, linear)
-
-        ops = [self.conv1_1, self.relu1_1, self.conv1_2, self.relu1_2, self.pool1,
-               self.conv2_1, self.relu2_1, self.conv2_2, self.relu2_2, self.pool2,
-               self.conv3_1, self.relu3_1, self.conv3_2, self.relu3_2,
-               self.conv3_3, self.relu3_3, self.conv3_4, self.relu3_4, self.pool3,
-               self.conv4_1, self.relu4_1, self.conv4_2, self.relu4_2,
-               self.conv4_3, self.relu4_3, self.conv4_4, self.relu4_4, self.pool4,
-               self.conv5_1, self.relu5_1, self.conv5_2, self.relu5_2,
-               self.conv5_3, self.relu5_3, self.conv5_4, self.relu5_4, self.pool5,
-               self.flatten, self.fc6, self.relu6, self.fc7, self.relu7, self.fc8]
-
-        return ops
+class VGG(objax.Module):
+    def __init__(self, nin: int, nout: int, ops: Sequence[Union[str, int]], use_bn: bool):
+        self.use_bn = use_bn
+        self.ops = tuple(ops)
+        n = nin
+        self.features = objax.nn.Sequential()
+        for v in ops:
+            if v == 'M':
+                self.features.append(objax.functional.max_pool_2d)
+                continue
+            self.features.append(objax.nn.Conv2D(n, v, 3, padding=1))
+            if use_bn:
+                self.features.append(objax.nn.BatchNorm2D(v, momentum=0.1, eps=1e-5))
+            self.features.append(objax.functional.relu)
+            n = v
+
+        self.classifier = objax.nn.Sequential([objax.nn.Linear(512 * 7 * 7, 4096), objax.functional.relu,
+                                               objax.nn.Dropout(0.5),
+                                               objax.nn.Linear(4096, 4096), objax.functional.relu,
+                                               objax.nn.Dropout(0.5),
+                                               objax.nn.Linear(4096, nout)])
+
+    def __call__(self, *args, **kwargs):
+        features = objax.functional.flatten(self.features(*args, **kwargs))
+        return self.classifier(features, **kwargs)
+
+    def __repr__(self):
+        nin = self.features[0].w.value.shape[2]  # Conv kernels are (k, k, nin, nout).
+        nout = self.classifier[-1].w.value.shape[1]  # The first conv's nout is not the model's.
+        return f'{self.__class__.__name__}(nin={nin}, nout={nout}, ops={self.ops}, use_bn={self.use_bn})'
+
+
+class CustomVGG(VGG):
+    def __repr__(self):
+        nin = self.features[0].w.value.shape[2]
+        nout = self.classifier[-1].w.value.shape[1]
+        return f'{self.__class__.__name__}(nin={nin}, nout={nout})'
+
+
+class VGG11(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg11'], use_bn=False)
+
+
+class VGG11_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg11'], use_bn=True)
+
+
+class VGG13(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg13'], use_bn=False)
+
+
+class VGG13_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg13'], use_bn=True)
+
+
+class VGG16(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg16'], use_bn=False)
+
+
+class VGG16_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg16'], use_bn=True)
+
+
+class VGG19(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg19'], use_bn=False)
+
+
+class VGG19_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg19'], use_bn=True)
+
+
+def load_pretrained_weights_from_keras(m: VGG):
+    import tensorflow as tf
+    assert hasattr(tf.keras.applications, m.__class__.__name__), \
+        f'No Keras pretrained model for {m.__class__.__name__}'
+    # 1. Get the Keras model.
+    keras_model = getattr(tf.keras.applications, m.__class__.__name__)(weights='imagenet')
+
+    # 2. Get the Keras model weights.
+    keras_numpy = {weight.name.split(':')[0]: weight.numpy()  # Remove :0 at the end of the variable name.
+                   for layer in keras_model.layers for weight in layer.weights}
+    # 2.1 Flattening differs between NHWC and NCHW: convert the first linear layer post-flattening.
+    nhwc_kernel = keras_numpy['fc1/kernel'].reshape((7, 7, 512, 4096))
+    keras_numpy['fc1/kernel'] = nhwc_kernel.transpose((2, 0, 1, 3)).reshape((-1, 4096))
+
+    # 3. Map Objax names to Keras names.
+    # The architectures are syntactically different (Objax uses Sequential while Keras does not),
+    # so we have to map the names semi-manually; there is no automatic way to do it.
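+    # For example, a (hypothetical) Objax name 'features(Sequential)[2](Conv2D).w'
+    # renames below to 'features2/kernel', and the loop that follows turns the
+    # sequential position 2 into Keras's 'block1_conv2/kernel'.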
+    keras_names = {k: objax.util.convert.keras.rename(k) for k in m.vars().keys()}
+    to_keras = {
+        'classifier0': 'fc1',
+        'classifier3': 'fc2',
+        'classifier6': 'predictions',
+    }
+
+    # The features in Keras are of the form "block{i}_conv{j}/variable".
+    # In Objax they are of the form "features{pos}/variable".
+    # Below we convert list position to block_conv.
+    target_to_source_names = {}
+    block_id, conv_id, seq_id = 1, 1, 0
+    for k, v in keras_names.items():
+        if '/' not in v:
+            target_to_source_names[k] = v
+            continue
+        layer, variable = v.split('/')
+        if layer.startswith('features'):
+            new_seq_id = int(layer[8:])
+            if new_seq_id - seq_id == 2:
+                conv_id += 1
+            elif new_seq_id - seq_id == 3:
+                block_id += 1
+                conv_id = 1
+            else:
+                assert new_seq_id == seq_id
+            seq_id = new_seq_id
+            target_to_source_names[k] = f'block{block_id}_conv{conv_id}/{variable}'
+        elif layer.startswith('classifier'):
+            target_to_source_names[k] = f'{to_keras[layer]}/{variable}'
+        else:
+            target_to_source_names[k] = v
+
+    objax.util.convert.import_weights(m.vars(), keras_numpy, target_to_source_names,
+                                      objax.util.convert.keras.ARRAY_CONVERT)
+
+
+def load_pretrained_weights_from_pytorch(m: VGG):
+    import torchvision
+    assert hasattr(torchvision.models, m.__class__.__name__.lower()), \
+        f'No TorchVision pretrained model for {m.__class__.__name__.lower()}'
+    torch_model = getattr(torchvision.models, m.__class__.__name__.lower())(pretrained=True)
+    torch_model.eval()  # Just a safety precaution.
+    torch_numpy = {name: param.numpy() for name, param in torch_model.state_dict().items()}
+    target_to_source_names = {k: pytorch.rename(k) for k in m.vars().keys()}
+    import_weights(m.vars(), torch_numpy, target_to_source_names, pytorch.ARRAY_CONVERT)
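+    # For reference, a (hypothetical) Objax name 'features(Sequential)[0](Conv2D).w'
+    # renames to 'features.0.weight', matching torchvision's state_dict keys, while
+    # pytorch.ARRAY_CONVERT transposes its OIHW kernel into Objax's HWIO layout.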