From e4b565f63821607cc35e59239dd5a245cfb19dee Mon Sep 17 00:00:00 2001
From: David Berthelot
Date: Tue, 17 Nov 2020 22:43:34 -0800
Subject: [PATCH 1/3] APIs to import weights from external frameworks. Demonstrated on PyTorch VGG.

The core function is objax.util.convert.import_weights:
- It takes a variable collection `target_vc` (basically the model variables for which to set the weights).
- Then `source_numpy` is a dictionary of numpy values and their names (in a possibly different naming convention than Objax's, for example PyTorch's).
- `source_names` maps Objax names to the `source_numpy` names, so that from a variable we can find the numpy value to set its weight.
- `numpy_convert` is a dictionary that maps module variable names to actions (functions). The actions are used to perform conversions (such as transpositions or reshapings).

Here is an example:

```python
ARRAY_CONVERT = {
    '(BatchNorm2D).beta': assign,
    '(BatchNorm2D).gamma': assign,
    '(BatchNorm2D).running_mean': assign,
    '(BatchNorm2D).running_var': assign,
    '(Conv2D).b': assign,
    '(Conv2D).w': lambda x, y: assign(x, y.transpose((2, 3, 1, 0))),
    '(Linear).b': assign,
    '(Linear).w': lambda x, y: assign(x, y.T),
}
```
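For illustration, here is roughly how the pieces fit together to import
torchvision's VGG16 weights. This is a sketch only; it mirrors what the
load_pretrained_weights_from_pytorch helper added in this patch does:

```python
import torchvision

from objax.util.convert import import_weights, pytorch
from objax.zoo import vgg

model = vgg.vgg16(use_bn=False)  # target Objax module
torch_model = torchvision.models.vgg16(pretrained=True)
torch_model.eval()
# Source weights: numpy arrays keyed by PyTorch names such as 'features.0.weight'.
source_numpy = {name: p.numpy() for name, p in torch_model.state_dict().items()}
# Objax variable name -> PyTorch name (e.g. a Conv2D weight maps to 'features.0.weight').
source_names = {k: pytorch.rename(k) for k in model.vars().keys()}
import_weights(model.vars(), source_numpy, source_names, pytorch.ARRAY_CONVERT)
```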
---
 examples/vgg_pytorch/pytorch_vgg.py |  25 ++++
 objax/util/__init__.py              |   3 +-
 objax/util/convert/__init__.py      |   2 +
 objax/util/convert/convert.py       |  28 ++++
 objax/util/convert/pytorch.py       |  26 ++++
 objax/zoo/vgg.py                    | 206 +++++++++------------------
 6 files changed, 148 insertions(+), 142 deletions(-)
 create mode 100644 examples/vgg_pytorch/pytorch_vgg.py
 create mode 100644 objax/util/convert/__init__.py
 create mode 100644 objax/util/convert/convert.py
 create mode 100644 objax/util/convert/pytorch.py

diff --git a/examples/vgg_pytorch/pytorch_vgg.py b/examples/vgg_pytorch/pytorch_vgg.py
new file mode 100644
index 0000000..7c76a3e
--- /dev/null
+++ b/examples/vgg_pytorch/pytorch_vgg.py
@@ -0,0 +1,25 @@
+import jax.numpy as jn
+import torch
+import torchvision
+
+from objax.zoo import vgg
+
+
+def delta(x, y):  # pytorch, jax
+    return jn.abs(x.detach().numpy() - y).max()
+
+
+mo = vgg.vgg16(use_bn=False)
+vgg.load_pretrained_weights_from_pytorch(mo)
+print(mo.vars())
+
+mt = torchvision.models.vgg16(pretrained=True)
+mt.eval()  # Wow that's error prone
+x = torch.randn((4, 3, 224, 224))
+yt = mt(x)  # (4, 1000)
+
+for name, param in mt.state_dict().items():
+    print(f'{name:40s} {tuple(param.shape)}')
+
+yo = mo(x.numpy(), training=False)
+print('Max difference:', jn.abs(yt.detach().numpy() - yo).max())
diff --git a/objax/util/__init__.py b/objax/util/__init__.py
index 2b2dbff..3fcf77a 100644
--- a/objax/util/__init__.py
+++ b/objax/util/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from . import image
 from . import check
+from . import convert
+from . import image
 from .util import *
diff --git a/objax/util/convert/__init__.py b/objax/util/convert/__init__.py
new file mode 100644
index 0000000..589631c
--- /dev/null
+++ b/objax/util/convert/__init__.py
@@ -0,0 +1,2 @@
+from . import pytorch
+from .convert import *
diff --git a/objax/util/convert/convert.py b/objax/util/convert/convert.py
new file mode 100644
index 0000000..46fefb6
--- /dev/null
+++ b/objax/util/convert/convert.py
@@ -0,0 +1,28 @@
+__all__ = ['assign', 'import_weights']
+
+import re
+from typing import Dict, Callable
+
+import jax.numpy as jn
+import numpy as np
+
+import objax
+
+
+def assign(x: objax.BaseVar, v: np.ndarray):
+    x.assign(jn.array(v.reshape(x.value.shape)))
+
+
+def import_weights(target_vc: objax.VarCollection,
+                   source_numpy: Dict[str, np.ndarray],
+                   source_names: Dict[str, str],
+                   numpy_convert: Dict[str, Callable[[objax.BaseVar, np.ndarray], None]]):
+    module_var = re.compile(r'.*(\([^)]*\)\.[^(]*)$')
+    for k, v in target_vc.items():
+        s = source_names[k]
+        t = module_var.match(k).group(1)
+        if s not in source_numpy:
+            print(f'Skipping {k} ({s})')
+            continue
+        assert t in numpy_convert, f'Unhandled name {k}'
+        numpy_convert[t](v, source_numpy[s])
diff --git a/objax/util/convert/pytorch.py b/objax/util/convert/pytorch.py
new file mode 100644
index 0000000..c352e76
--- /dev/null
+++ b/objax/util/convert/pytorch.py
@@ -0,0 +1,26 @@
+__all__ = ['ARRAY_CONVERT', 'rename']
+
+import re
+
+from objax.util.convert import assign
+
+ARRAY_CONVERT = {
+    '(BatchNorm2D).beta': assign,
+    '(BatchNorm2D).gamma': assign,
+    '(BatchNorm2D).running_mean': assign,
+    '(BatchNorm2D).running_var': assign,
+    '(Conv2D).b': assign,
+    '(Conv2D).w': lambda x, y: assign(x, y.transpose((2, 3, 1, 0))),
+    '(Linear).b': assign,
+    '(Linear).w': lambda x, y: assign(x, y.T),
+}
+
+
+def rename(x):
+    x = x.replace('(BatchNorm2D).gamma', '(BatchNorm2D).weight').replace('(BatchNorm2D).beta', '(BatchNorm2D).bias')
+    x = re.sub(r'\([^)]*\)', '', x)
+    x = re.sub(r'^\.', '', x)
+    x = re.sub(r'\.w$', '.weight', x)
+    x = re.sub(r'\.b$', '.bias', x)
+    x = x.replace('[', '.').replace(']', '')
+    return x
diff --git a/objax/zoo/vgg.py b/objax/zoo/vgg.py
index c60db77..6f449d0 100644
--- a/objax/zoo/vgg.py
+++ b/objax/zoo/vgg.py
@@ -1,144 +1,68 @@
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Module with VGG-19 implementation.
-
-See https://arxiv.org/abs/1409.1556 for detail.
-""" - -import functools -import os -from urllib import request - -import jax.numpy as jn -import numpy as np +__all__ = ['VGG', 'load_pretrained_weights_from_pytorch', 'vgg11', 'vgg13', 'vgg16', 'vgg19'] + +from typing import Union, Sequence import objax +from objax.util.convert import import_weights, pytorch + + +class VGG(objax.Module): + def __init__(self, nin: int, nout: int, ops: Sequence[Union[str, int]], use_bn: bool, name: str): + self.name = name + ('_bn' if use_bn else '') + self.ops = tuple(ops) + n = nin + self.features = objax.nn.Sequential() + for v in ops: + if v == 'M': + self.features.append(objax.functional.max_pool_2d) + continue + self.features.append(objax.nn.Conv2D(n, v, 3, padding=1)) + if use_bn: + self.features.append(objax.nn.BatchNorm2D(v, momentum=0.1, eps=1e-5)) + self.features.append(objax.functional.relu) + n = v + + self.classifier = objax.nn.Sequential([objax.nn.Linear(512 * 7 * 7, 4096), objax.functional.relu, + objax.nn.Dropout(0.5), + objax.nn.Linear(4096, 4096), objax.functional.relu, + objax.nn.Dropout(0.5), + objax.nn.Linear(4096, nout)]) + + def __call__(self, *args, **kwargs): + features = objax.functional.flatten(self.features(*args, **kwargs)) + return self.classifier(features, **kwargs) + + def __repr__(self): + use_bn = self.name.endswith('_bn') + name = self.name[:-3] if use_bn else self.name + return f'{self.__class__.__name__}(nin={self.features[0].w.value.shape[2]}, ' \ + f'nout={self.features[0].w.value.shape[3]}, ops={self.ops}, use_bn={use_bn}, name={repr(name)})' + + +def load_pretrained_weights_from_pytorch(m: VGG): + import torchvision + torch_model = getattr(torchvision.models, m.name)(pretrained=True) + torch_model.eval() # Just a safety precaution. + numpy_arrays = {name: param.numpy() for name, param in torch_model.state_dict().items()} + numpy_names = {k: pytorch.rename(k) for k in m.vars().keys()} + import_weights(m.vars(), numpy_arrays, numpy_names, pytorch.ARRAY_CONVERT) + + +def vgg11(use_bn: bool): + ops = 64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M' + return VGG(3, 1000, ops, use_bn=use_bn, name='vgg11') + + +def vgg13(use_bn: bool): + ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M' + return VGG(3, 1000, ops, use_bn=use_bn, name='vgg13') + + +def vgg16(use_bn: bool): + ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M' + return VGG(3, 1000, ops, use_bn=use_bn, name='vgg16') + -_VGG19_URL = 'https://github.com/machrisaa/tensorflow-vgg' -_VGG19_NPY = './objax/zoo/pretrained/vgg19.npy' -_SYNSET_URL = 'https://raw.githubusercontent.com/machrisaa/tensorflow-vgg/master/synset.txt' -_SYNSET_PATH = './objax/zoo/pretrained/synset.txt' - - -def preprocess(x): - bgr_mean = [103.939, 116.779, 123.68] - red, green, blue = [x[:, i, :, :] for i in range(3)] - return jn.stack([blue - bgr_mean[0], green - bgr_mean[1], red - bgr_mean[2]], axis=1) - - -def max_pool_2d(x): - return functools.partial(objax.functional.max_pool_2d, - size=2, strides=2, padding=objax.constants.ConvPadding.VALID)(x) - - -class VGG19(objax.nn.Sequential): - """VGG19 implementation.""" - - def __init__(self, pretrained=False): - """Creates VGG19 instance. - - Args: - pretrained: if True load weights from ImageNet pretrained model. 
- """ - if not os.path.exists(_VGG19_NPY): - raise FileNotFoundError( - 'You must download vgg19.npy from %s and save it to %s' % (_VGG19_URL, _VGG19_NPY)) - if not os.path.exists(_SYNSET_PATH): - request.urlretrieve(_SYNSET_URL, _SYNSET_PATH) - self.data_dict = np.load(_VGG19_NPY, encoding='latin1', allow_pickle=True).item() - self.pretrained = pretrained - self.ops = self.build() - super().__init__(self.ops) - - def build(self): - # inputs in [0, 255] - self.preprocess = preprocess - self.conv1_1 = objax.nn.Conv2D(nin=3, nout=64, k=3) - self.relu1_1 = objax.functional.relu - self.conv1_2 = objax.nn.Conv2D(nin=64, nout=64, k=3) - self.relu1_2 = objax.functional.relu - self.pool1 = max_pool_2d - - self.conv2_1 = objax.nn.Conv2D(nin=64, nout=128, k=3) - self.relu2_1 = objax.functional.relu - self.conv2_2 = objax.nn.Conv2D(nin=128, nout=128, k=3) - self.relu2_2 = objax.functional.relu - self.pool2 = max_pool_2d - - self.conv3_1 = objax.nn.Conv2D(nin=128, nout=256, k=3) - self.relu3_1 = objax.functional.relu - self.conv3_2 = objax.nn.Conv2D(nin=256, nout=256, k=3) - self.relu3_2 = objax.functional.relu - self.conv3_3 = objax.nn.Conv2D(nin=256, nout=256, k=3) - self.relu3_3 = objax.functional.relu - self.conv3_4 = objax.nn.Conv2D(nin=256, nout=256, k=3) - self.relu3_4 = objax.functional.relu - self.pool3 = max_pool_2d - - self.conv4_1 = objax.nn.Conv2D(nin=256, nout=512, k=3) - self.relu4_1 = objax.functional.relu - self.conv4_2 = objax.nn.Conv2D(nin=512, nout=512, k=3) - self.relu4_2 = objax.functional.relu - self.conv4_3 = objax.nn.Conv2D(nin=512, nout=512, k=3) - self.relu4_3 = objax.functional.relu - self.conv4_4 = objax.nn.Conv2D(nin=512, nout=512, k=3) - self.relu4_4 = objax.functional.relu - self.pool4 = max_pool_2d - - self.conv5_1 = objax.nn.Conv2D(nin=512, nout=512, k=3) - self.relu5_1 = objax.functional.relu - self.conv5_2 = objax.nn.Conv2D(nin=512, nout=512, k=3) - self.relu5_2 = objax.functional.relu - self.conv5_3 = objax.nn.Conv2D(nin=512, nout=512, k=3) - self.relu5_3 = objax.functional.relu - self.conv5_4 = objax.nn.Conv2D(nin=512, nout=512, k=3) - self.relu5_4 = objax.functional.relu - self.pool5 = max_pool_2d - - self.flatten = objax.functional.flatten - self.fc6 = objax.nn.Linear(nin=512 * 7 * 7, nout=4096) - self.relu6 = objax.functional.relu - self.fc7 = objax.nn.Linear(nin=4096, nout=4096) - self.relu7 = objax.functional.relu - self.fc8 = objax.nn.Linear(nin=4096, nout=1000) - - if self.pretrained: - for it in self.data_dict: - if it.startswith('conv'): - conv = getattr(self, it) - kernel, bias = self.data_dict[it] - conv.w = objax.TrainVar(jn.array(kernel)) - conv.b = objax.TrainVar(jn.array(bias[:, None, None])) - setattr(self, it, conv) - elif it.startswith('fc'): - linear = getattr(self, it) - kernel, bias = self.data_dict[it] - if it == 'fc6': - kernel = kernel.reshape([7, 7, 512, -1]).transpose((2, 0, 1, 3)).reshape([512 * 7 * 7, -1]) - linear.w = objax.TrainVar(jn.array(kernel)) - linear.b = objax.TrainVar(jn.array(bias)) - setattr(self, it, linear) - - ops = [self.conv1_1, self.relu1_1, self.conv1_2, self.relu1_2, self.pool1, - self.conv2_1, self.relu2_1, self.conv2_2, self.relu2_2, self.pool2, - self.conv3_1, self.relu3_1, self.conv3_2, self.relu3_2, - self.conv3_3, self.relu3_3, self.conv3_4, self.relu3_4, self.pool3, - self.conv4_1, self.relu4_1, self.conv4_2, self.relu4_2, - self.conv4_3, self.relu4_3, self.conv4_4, self.relu4_4, self.pool4, - self.conv5_1, self.relu5_1, self.conv5_2, self.relu5_2, - self.conv5_3, self.relu5_3, self.conv5_4, self.relu5_4, 
-               self.pool5,
-               self.flatten, self.fc6, self.relu6, self.fc7, self.relu7, self.fc8]
-
-        return ops
+def vgg19(use_bn: bool):
+    ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'
+    return VGG(3, 1000, ops, use_bn=use_bn, name='vgg19')

From b375c9124c58162007eae21842af328b0bde36c3 Mon Sep 17 00:00:00 2001
From: David Berthelot
Date: Tue, 17 Nov 2020 23:49:23 -0800
Subject: [PATCH 2/3] Reworked imports to avoid cyclic dependencies.
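
The convert module added in the previous commit imports objax.variable (for
BaseVar and VarCollection), while objax.variable, objax.module, objax.io and
the pooling code in turn import helpers from objax.util. Since
objax/util/__init__.py now pulls in convert, importing the objax.util package
from those modules creates a cycle; they now import the leaf submodules
directly instead. A sketch of the chain being broken (derived from the diffs
below):

```python
# objax/util/__init__.py:         from . import convert
# objax/util/convert/convert.py:  from objax.variable import BaseVar, VarCollection
# objax/variable.py (before):     from objax.util import map_to_device, Renamer   # closes the cycle
# objax/variable.py (after):      leaf import, bypasses objax/util/__init__.py
from objax.util.util import map_to_device, Renamer
```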
---
 objax/functional/core/pooling.py | 2 +-
 objax/io/ops.py                  | 2 +-
 objax/module.py                  | 2 +-
 objax/typing.py                  | 2 +-
 objax/util/convert/convert.py    | 8 ++++----
 objax/util/convert/pytorch.py    | 2 +-
 objax/variable.py                | 2 +-
 7 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/objax/functional/core/pooling.py b/objax/functional/core/pooling.py
index 0a9600c..6520207 100644
--- a/objax/functional/core/pooling.py
+++ b/objax/functional/core/pooling.py
@@ -22,7 +22,7 @@
 
 from objax.constants import ConvPadding
 from objax.typing import JaxArray, ConvPaddingInt
-from objax.util import to_padding, to_tuple
+from objax.util.util import to_padding, to_tuple
 
 
 def average_pool_2d(x: JaxArray,
diff --git a/objax/io/ops.py b/objax/io/ops.py
index bc500b0..76beacb 100644
--- a/objax/io/ops.py
+++ b/objax/io/ops.py
@@ -21,7 +21,7 @@
 import jax.numpy as jn
 import numpy as np
 
-from objax.util import Renamer
+from objax.util.util import Renamer
 from objax.variable import TrainRef, VarCollection
 
 
diff --git a/objax/module.py b/objax/module.py
index dc19e60..f85b469 100644
--- a/objax/module.py
+++ b/objax/module.py
@@ -22,7 +22,7 @@
 from jax.interpreters.pxla import ShardedDeviceArray
 
 from objax.typing import JaxArray
-from objax.util import override_args_kwargs, positional_args_names
+from objax.util.util import override_args_kwargs, positional_args_names
 from objax.variable import BaseVar, RandomState, VarCollection
 
 
diff --git a/objax/typing.py b/objax/typing.py
index 1c8a083..0ffd422 100644
--- a/objax/typing.py
+++ b/objax/typing.py
@@ -14,7 +14,7 @@
 
 """This module contains type declarations for Objax."""
 
-__all__ = ['FileOrStr', 'JaxArray', 'JaxDType']
+__all__ = ['ConvPaddingInt', 'FileOrStr', 'JaxArray', 'JaxDType']
 
 from typing import Union, IO, BinaryIO, Sequence, Tuple
 
diff --git a/objax/util/convert/convert.py b/objax/util/convert/convert.py
index 46fefb6..9c0cdcc 100644
--- a/objax/util/convert/convert.py
+++ b/objax/util/convert/convert.py
@@ -6,17 +6,17 @@
 import jax.numpy as jn
 import numpy as np
 
-import objax
+from objax.variable import BaseVar, VarCollection
 
 
-def assign(x: objax.BaseVar, v: np.ndarray):
+def assign(x: BaseVar, v: np.ndarray):
     x.assign(jn.array(v.reshape(x.value.shape)))
 
 
-def import_weights(target_vc: objax.VarCollection,
+def import_weights(target_vc: VarCollection,
                    source_numpy: Dict[str, np.ndarray],
                    source_names: Dict[str, str],
-                   numpy_convert: Dict[str, Callable[[objax.BaseVar, np.ndarray], None]]):
+                   numpy_convert: Dict[str, Callable[[BaseVar, np.ndarray], None]]):
     module_var = re.compile(r'.*(\([^)]*\)\.[^(]*)$')
     for k, v in target_vc.items():
         s = source_names[k]
diff --git a/objax/util/convert/pytorch.py b/objax/util/convert/pytorch.py
index c352e76..143e457 100644
--- a/objax/util/convert/pytorch.py
+++ b/objax/util/convert/pytorch.py
@@ -2,7 +2,7 @@
 
 import re
 
-from objax.util.convert import assign
+from .convert import assign
 
 ARRAY_CONVERT = {
     '(BatchNorm2D).beta': assign,
diff --git a/objax/variable.py b/objax/variable.py
index 1330baf..1f9c72a 100644
--- a/objax/variable.py
+++ b/objax/variable.py
@@ -24,8 +24,8 @@
 import numpy as np
 
 from objax.typing import JaxArray
-from objax.util import map_to_device, Renamer
 from objax.util.check import assert_assigned_type_and_shape_match
+from objax.util.util import map_to_device, Renamer
 
 
 def reduce_mean(x: JaxArray) -> JaxArray:

From 8a2df3bf71b382e316aaef8b6dcbb5dc4709551a Mon Sep 17 00:00:00 2001
From: David Berthelot
Date: Sun, 22 Nov 2020 12:10:14 -0800
Subject: [PATCH 3/3] VGG import weights from Keras as well.
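
Keras names VGG layers semantically (block1_conv1, ..., fc1, fc2,
predictions) while the Objax model names its variables positionally through
Sequential, so the name map is built in two steps: keras.rename first
flattens an Objax variable name into "layer/variable" form, then the loader
converts sequence positions into block/conv indices. A schematic
walk-through (the Objax name string shown is illustrative, not copied from a
real run):

```python
# Two-step name mapping, schematically:
#   '(VGG16).features(Sequential)[0](Conv2D).w'    # Objax variable name (illustrative)
#     -- keras.rename -->           'features0/kernel'
#     -- position-to-block map -->  'block1_conv1/kernel'  # Keras weight name
```

Keras stores convolution kernels as HWIO, matching Objax, hence the plain
assign entries in keras.ARRAY_CONVERT; only the first Linear kernel needs
the NHWC-vs-NCHW flattening conversion done in
load_pretrained_weights_from_keras.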
---
 examples/vgg/keras_vgg.py                    |  19 +++
 examples/{vgg_pytorch => vgg}/pytorch_vgg.py |   7 +-
 objax/util/convert/__init__.py               |   1 +
 objax/util/convert/convert.py                |   4 +-
 objax/util/convert/keras.py                  |  27 ++++
 objax/zoo/vgg.py                             | 140 +++++++++++++++----
 6 files changed, 164 insertions(+), 34 deletions(-)
 create mode 100644 examples/vgg/keras_vgg.py
 rename examples/{vgg_pytorch => vgg}/pytorch_vgg.py (81%)
 create mode 100644 objax/util/convert/keras.py

diff --git a/examples/vgg/keras_vgg.py b/examples/vgg/keras_vgg.py
new file mode 100644
index 0000000..1a856c1
--- /dev/null
+++ b/examples/vgg/keras_vgg.py
@@ -0,0 +1,19 @@
+import numpy as np
+import tensorflow as tf
+
+import objax
+from objax.zoo import vgg
+
+mo = vgg.VGG16()
+vgg.load_pretrained_weights_from_keras(mo)
+print(mo.vars())
+
+mk = tf.keras.applications.VGG16(weights='imagenet')
+x = np.random.randn(4, 3, 224, 224)
+yk = mk(x.transpose((0, 2, 3, 1)))  # (4, 1000)
+
+for name, param in ((weight.name, weight.numpy()) for layer in mk.layers for weight in layer.weights):
+    print(f'{name:40s} {tuple(param.shape)}')
+
+yo = objax.functional.softmax(mo(x, training=False))
+print('Max difference:', np.abs(yk - yo).max())
diff --git a/examples/vgg_pytorch/pytorch_vgg.py b/examples/vgg/pytorch_vgg.py
similarity index 81%
rename from examples/vgg_pytorch/pytorch_vgg.py
rename to examples/vgg/pytorch_vgg.py
index 7c76a3e..511f4c8 100644
--- a/examples/vgg_pytorch/pytorch_vgg.py
+++ b/examples/vgg/pytorch_vgg.py
@@ -4,12 +4,7 @@
 
 from objax.zoo import vgg
 
-
-def delta(x, y):  # pytorch, jax
-    return jn.abs(x.detach().numpy() - y).max()
-
-
-mo = vgg.vgg16(use_bn=False)
+mo = vgg.VGG16()
 vgg.load_pretrained_weights_from_pytorch(mo)
 print(mo.vars())
 
diff --git a/objax/util/convert/__init__.py b/objax/util/convert/__init__.py
index 589631c..c62d85b 100644
--- a/objax/util/convert/__init__.py
+++ b/objax/util/convert/__init__.py
@@ -1,2 +1,3 @@
+from . import keras
 from . import pytorch
 from .convert import *
diff --git a/objax/util/convert/convert.py b/objax/util/convert/convert.py
index 9c0cdcc..cc59bd2 100644
--- a/objax/util/convert/convert.py
+++ b/objax/util/convert/convert.py
@@ -15,11 +15,11 @@ def assign(x: BaseVar, v: np.ndarray):
 
 def import_weights(target_vc: VarCollection,
                    source_numpy: Dict[str, np.ndarray],
-                   source_names: Dict[str, str],
+                   target_to_source_names: Dict[str, str],
                    numpy_convert: Dict[str, Callable[[BaseVar, np.ndarray], None]]):
     module_var = re.compile(r'.*(\([^)]*\)\.[^(]*)$')
     for k, v in target_vc.items():
-        s = source_names[k]
+        s = target_to_source_names[k]
         t = module_var.match(k).group(1)
         if s not in source_numpy:
             print(f'Skipping {k} ({s})')
diff --git a/objax/util/convert/keras.py b/objax/util/convert/keras.py
new file mode 100644
index 0000000..a2d506d
--- /dev/null
+++ b/objax/util/convert/keras.py
@@ -0,0 +1,27 @@
+__all__ = ['ARRAY_CONVERT', 'rename']
+
+import re
+
+from .convert import assign
+
+ARRAY_CONVERT = {
+    # '(BatchNorm2D).beta': assign,
+    # '(BatchNorm2D).gamma': assign,
+    # '(BatchNorm2D).running_mean': assign,
+    # '(BatchNorm2D).running_var': assign,
+    '(Conv2D).b': assign,
+    '(Conv2D).w': assign,
+    '(Linear).b': assign,
+    '(Linear).w': assign,
+}
+
+
+def rename(x):
+    # x = x.replace('(BatchNorm2D).gamma', '(BatchNorm2D).weight').replace('(BatchNorm2D).beta', '(BatchNorm2D).bias')
+    x = re.sub(r'\([^)]*\)', '', x)
+    x = re.sub(r'^\.', '', x)
+    x = re.sub(r'\.w$', '/kernel', x)
+    x = re.sub(r'\.b$', '/bias', x)
+    x = re.sub(r'\[|\]', '', x)
+    x = re.sub(r'\.', '_', x)
+    return x
diff --git a/objax/zoo/vgg.py b/objax/zoo/vgg.py
index 6f449d0..20fadcf 100644
--- a/objax/zoo/vgg.py
+++ b/objax/zoo/vgg.py
@@ -1,14 +1,22 @@
-__all__ = ['VGG', 'load_pretrained_weights_from_pytorch', 'vgg11', 'vgg13', 'vgg16', 'vgg19']
+__all__ = ['VGG', 'VGG11', 'VGG11_BN', 'VGG13', 'VGG13_BN', 'VGG16', 'VGG16_BN', 'VGG19', 'VGG19_BN',
+           'load_pretrained_weights_from_keras', 'load_pretrained_weights_from_pytorch']
 
 from typing import Union, Sequence
 
 import objax
 from objax.util.convert import import_weights, pytorch
 
+OPS = dict(
+    vgg11=(64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'),
+    vgg13=(64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'),
+    vgg16=(64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'),
+    vgg19=(64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M')
+)
+
 
 class VGG(objax.Module):
-    def __init__(self, nin: int, nout: int, ops: Sequence[Union[str, int]], use_bn: bool, name: str):
-        self.name = name + ('_bn' if use_bn else '')
+    def __init__(self, nin: int, nout: int, ops: Sequence[Union[str, int]], use_bn: bool):
+        self.use_bn = use_bn
         self.ops = tuple(ops)
         n = nin
         self.features = objax.nn.Sequential()
@@ -33,36 +41,116 @@ def __call__(self, *args, **kwargs):
         return self.classifier(features, **kwargs)
 
     def __repr__(self):
-        use_bn = self.name.endswith('_bn')
-        name = self.name[:-3] if use_bn else self.name
-        return f'{self.__class__.__name__}(nin={self.features[0].w.value.shape[2]}, ' \
-               f'nout={self.features[0].w.value.shape[3]}, ops={self.ops}, use_bn={use_bn}, name={repr(name)})'
+        nin, nout = self.features[0].w.value.shape[2:]
+        return f'{self.__class__.__name__}(nin={nin}, nout={nout}, ops={self.ops}, use_bn={self.use_bn})'
 
 
-def load_pretrained_weights_from_pytorch(m: VGG):
-    import torchvision
-    torch_model = getattr(torchvision.models, m.name)(pretrained=True)
-    torch_model.eval()  # Just a safety precaution.
-    numpy_arrays = {name: param.numpy() for name, param in torch_model.state_dict().items()}
-    numpy_names = {k: pytorch.rename(k) for k in m.vars().keys()}
-    import_weights(m.vars(), numpy_arrays, numpy_names, pytorch.ARRAY_CONVERT)
+class CustomVGG(VGG):
+    def __repr__(self):
+        nin, nout = self.features[0].w.value.shape[2:]
+        return f'{self.__class__.__name__}(nin={nin}, nout={nout})'
+
+
+class VGG11(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg11'], use_bn=False)
+
+
+class VGG11_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg11'], use_bn=True)
+
+
+class VGG13(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg13'], use_bn=False)
+
 
+class VGG13_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg13'], use_bn=True)
 
-def vgg11(use_bn: bool):
-    ops = 64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
-    return VGG(3, 1000, ops, use_bn=use_bn, name='vgg11')
 
+class VGG16(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg16'], use_bn=False)
 
-def vgg13(use_bn: bool):
-    ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'
-    return VGG(3, 1000, ops, use_bn=use_bn, name='vgg13')
 
+class VGG16_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg16'], use_bn=True)
 
-def vgg16(use_bn: bool):
-    ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'
-    return VGG(3, 1000, ops, use_bn=use_bn, name='vgg16')
 
+class VGG19(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg19'], use_bn=False)
 
-def vgg19(use_bn: bool):
-    ops = 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'
-    return VGG(3, 1000, ops, use_bn=use_bn, name='vgg19')
+
+class VGG19_BN(CustomVGG):
+    def __init__(self, nin: int = 3, nout: int = 1000):
+        super().__init__(nin, nout, ops=OPS['vgg19'], use_bn=True)
+
+
+def load_pretrained_weights_from_keras(m: VGG):
+    import tensorflow as tf
+    assert hasattr(tf.keras.applications, m.__class__.__name__), \
+        f'No Keras pretrained model for {m.__class__.__name__}'
+    # 1. Get Keras model
+    keras_model = getattr(tf.keras.applications, m.__class__.__name__)(weights='imagenet')
+
+    # 2. Get Keras model weights
+    keras_numpy = {weight.name.split(':')[0]: weight.numpy()  # Remove :0 at the end of the variable name.
+                   for layer in keras_model.layers for weight in layer.weights}
+    # 2.1 Flattening differs between NHWC and NCHW: convert first linear layer post-flattening.
+    nhwc_kernel = keras_numpy['fc1/kernel'].reshape((7, 7, 512, 4096))
+    keras_numpy['fc1/kernel'] = nhwc_kernel.transpose((2, 0, 1, 3)).reshape((-1, 4096))
+
+    # 3. Map Objax names to Keras names.
+    # The architectures are syntactically different (Objax uses Sequential while Keras does not).
+    # So we have to map the names semi-manually since there's no automatic way to do it.
+    keras_names = {k: objax.util.convert.keras.rename(k) for k in m.vars().keys()}
+    to_keras = {
+        'classifier0': 'fc1',
+        'classifier3': 'fc2',
+        'classifier6': 'predictions',
+    }
+
+    # The features in Keras are of the form "block{i}_conv{j}/variable",
+    # in Objax they are of the form "features{pos}/variable".
+    # Below we convert list position to block_conv.
+    target_to_source_names = {}
+    block_id, conv_id, seq_id = 1, 1, 0
+    for k, v in keras_names.items():
+        if '/' not in v:
+            target_to_source_names[k] = v
+            continue
+        layer, variable = v.split('/')
+        if layer.startswith('features'):
+            new_seq_id = int(layer[8:])
+            if new_seq_id - seq_id == 2:
+                conv_id += 1
+            elif new_seq_id - seq_id == 3:
+                block_id += 1
+                conv_id = 1
+            else:
+                assert new_seq_id == seq_id
+            seq_id = new_seq_id
+            target_to_source_names[k] = f'block{block_id}_conv{conv_id}/{variable}'
+        elif layer.startswith('classifier'):
+            target_to_source_names[k] = f'{to_keras[layer]}/{variable}'
+        else:
+            target_to_source_names[k] = v
+
+    objax.util.convert.import_weights(m.vars(), keras_numpy, target_to_source_names,
+                                      objax.util.convert.keras.ARRAY_CONVERT)
+
+
+def load_pretrained_weights_from_pytorch(m: VGG):
+    import torchvision
+    assert hasattr(torchvision.models, m.__class__.__name__.lower()), \
+        f'No TorchVision pretrained model for {m.__class__.__name__.lower()}'
+    torch_model = getattr(torchvision.models, m.__class__.__name__.lower())(pretrained=True)
+    torch_model.eval()  # Just a safety precaution.
+    torch_numpy = {name: param.numpy() for name, param in torch_model.state_dict().items()}
+    target_to_source_names = {k: pytorch.rename(k) for k in m.vars().keys()}
+    import_weights(m.vars(), torch_numpy, target_to_source_names, pytorch.ARRAY_CONVERT)