diff --git a/examples/super_convergence/super_convergence.ipynb b/examples/super_convergence/super_convergence.ipynb new file mode 100644 index 0000000..eb4bb03 --- /dev/null +++ b/examples/super_convergence/super_convergence.ipynb @@ -0,0 +1,2065 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "vdYsiFP6jjPY" + }, + "source": [ + "# Super Convergence\n", + "Reference Papers\n", + "1. [Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates](https://arxiv.org/abs/1708.07120)\n", + "2. [Cyclical Learning Rates for Training Neural Networks](https://arxiv.org/abs/1506.01186)\n", + "\n", + "Blogs\n", + "1. [Finding Good Learning Rate and The One Cycle Policy.](https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6)\n", + "2. [The Learning Rate Finder Technique: How Reliable Is It?](https://blog.dataiku.com/the-learning-rate-finder-technique-how-reliable-is-it)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "k7tqeeGB6peL" + }, + "outputs": [], + "source": [ + "import objax\n", + "from objax.nn import Conv2D\n", + "from time import time\n", + "import jax.numpy as jn\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from objax.util import EasyDict\n", + "import os\n", + "import tensorflow_datasets as tfds\n", + "from tqdm.notebook import tqdm\n", + "from tqdm import trange\n", + "import math\n", + "from copy import deepcopy\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 222, + "referenced_widgets": [ + "e7b07465fa61490dbb41545aac4fa9b4", + "1096b35a43204819b793ae71ca903009", + "80beed2bab07454eaf0eb26399ea19b9", + "559a266f43484b378a1fe4783178274b", + "ef01d022b37b4806bbbc05f5c2bade19", + "7b3f8d1c9ab944c69c0146a0951a8807", + "36a97e7e6f934026817e68b6daab5e8b", + "3d1b718eec564ac6866d732e93c52362" + 
] + }, + "id": "le17weRd81kv", + "outputId": "610f91ce-ab0e-4a73-8260-871f3def5a19" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your\n", + "local data directory. If you'd instead prefer to read directly from our public\n", + "GCS bucket (recommended if you're running on GCP), you can instead set\n", + "data_dir=gs://tfds-data/datasets.\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mDownloading and preparing dataset mnist/3.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /root/TFDS/mnist/3.0.0...\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e7b07465fa61490dbb41545aac4fa9b4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1mDataset mnist downloaded and prepared to /root/TFDS/mnist/3.0.0. 
Subsequent calls will reuse this data.\u001b[0m\n" + ] + } + ], + "source": [ + "DATA_DIR = os.path.join(os.path.expanduser('~'), 'TFDS')\n", + "data = tfds.as_numpy(tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR))\n", + "train = EasyDict(image=data['train']['image'].transpose(0, 3, 1, 2) / 255, label=data['train']['label'])\n", + "test = EasyDict(image=data['test']['image'].transpose(0, 3, 1, 2) / 255, label=data['test']['label'])\n", + "del data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "xs2WTHVn_2-K" + }, + "outputs": [], + "source": [ + "# Image Augmentation\n", + "def augment(x, shift=4):\n", + "    \"\"\"Randomly shift all images in the batch by up to `shift` pixels in any direction.\n", + "\n", + "    `x` is indexed as (batch, colors, height, width). It is zero-padded by `shift` on\n", + "    each spatial side and cropped back to its original height/width at one random\n", + "    offset per call, so every image in the batch receives the same shift.\n", + "    \"\"\"\n", + "    height, width = x.shape[2:4]\n", + "    x_pad = np.pad(x, [[0, 0], [0, 0], [shift, shift], [shift, shift]])\n", + "    # A crop offset equal to `shift` is the identity, so offsets must cover\n", + "    # [0, 2 * shift]; sampling only [0, shift) always moved the digits down/right.\n", + "    rx, ry = np.random.randint(0, 2 * shift + 1, size=2)\n", + "    return x_pad[:, :, rx:rx + height, ry:ry + width]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "H-E-Tw1oA-LF" + }, + "outputs": [], + "source": [ + "def acc_loss_plot(loss, accuracy):\n", + " fig, a = plt.subplots(2, figsize=(10, 10))\n", + " epochs = range(len(loss))\n", + " a[0].plot(epochs, loss)\n", + " a[0].set_ylabel(\"loss\")\n", + " a[0].set_xlabel(\"epochs\")\n", + " a[1].plot(epochs, accuracy)\n", + " a[1].set_ylabel(\"accuracy\")\n", + " a[1].set_xlabel(\"epochs\")\n", + " plt.show()\n", + "\n", + "def lr_plot(learning_rates, losses):\n", + " plt.ylabel(\"loss\")\n", + " plt.xlabel(\"learning rate\")\n", + " plt.xscale(\"log\")\n", + " plt.plot(learning_rates, losses)\n", + "\n", + "def lr_mom_plot(learning_rates, momentums):\n", + " fig, a = plt.subplots(1, 2, figsize=(10,3))\n", + " a[0].set_title(\"Learning Rate\")\n", + " a[0].set_xlabel('iterations')\n", + " a[0].set_ylabel('lr')\n", + " a[0].plot(learning_rates)\n", + " a[1].set_title(\"Momentum\")\n", + " a[1].set_xlabel('iterations')\n", + " a[1].set_ylabel('momentum')\n", + " a[1].plot(momentums)\n", + " plt.show()" +
] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "UL10j4Me88c5" + }, + "outputs": [], + "source": [ + "def simple_net_block(nin, nout):\n", + "    \"\"\"Build one stage: conv, leaky ReLU, max-pool, conv, leaky ReLU.\"\"\"\n", + "    layers = [\n", + "        objax.nn.Conv2D(nin, nout, k=3),\n", + "        objax.functional.leaky_relu,\n", + "        objax.functional.max_pool_2d,\n", + "        objax.nn.Conv2D(nout, nout, k=3),\n", + "        objax.functional.leaky_relu,\n", + "    ]\n", + "    return objax.nn.Sequential(layers)\n", + "\n", + "class SimpleNet(objax.Module):\n", + "    \"\"\"Small all-convolutional classifier: stem conv, two conv stages, conv head.\n", + "\n", + "    Calling it returns raw logits when `training=True` and softmax\n", + "    probabilities otherwise.\n", + "    \"\"\"\n", + "\n", + "    def __init__(self, nclass, colors, n):\n", + "        stem = [objax.nn.Conv2D(colors, n, k=3), objax.functional.leaky_relu]\n", + "        self.pre_conv = objax.nn.Sequential(stem)\n", + "        self.block1 = simple_net_block(n, 2 * n)\n", + "        self.block2 = simple_net_block(2 * n, 4 * n)\n", + "        self.post_conv = objax.nn.Conv2D(4 * n, nclass, k=3)\n", + "\n", + "    def __call__(self, x, training=False):\n", + "        # x = (batch, colors, height, width)\n", + "        features = self.block2(self.block1(self.pre_conv(x)))\n", + "        # Average the head over both spatial axes: logits = (batch, nclass)\n", + "        logits = self.post_conv(features).mean((2, 3))\n", + "        return logits if training else objax.functional.softmax(logits)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "7mWj4QxWVbI3" + }, + "outputs": [], + "source": [ + "batch = 512\n", + "test_batch = 2048\n", + "epochs = 50\n", + "lr = 0.003\n", + "train_size = train.image.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 238 + }, + "id": "DBKXaiasNYPP", + "outputId": "1f11718a-4c5d-4abb-aca5-bb0aff1191d3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(SimpleNet).pre_conv(Sequential)[0](Conv2D).b 16 (16, 1, 1)\n", + "(SimpleNet).pre_conv(Sequential)[0](Conv2D).w 144 (3, 3, 1, 16)\n", + "(SimpleNet).block1(Sequential)[0](Conv2D).b 32 (32, 1, 1)\n", + "(SimpleNet).block1(Sequential)[0](Conv2D).w 4608 (3, 3, 16, 32)\n", + "(SimpleNet).block1(Sequential)[3](Conv2D).b 32 
(32, 1, 1)\n", + "(SimpleNet).block1(Sequential)[3](Conv2D).w 9216 (3, 3, 32, 32)\n", + "(SimpleNet).block2(Sequential)[0](Conv2D).b 64 (64, 1, 1)\n", + "(SimpleNet).block2(Sequential)[0](Conv2D).w 18432 (3, 3, 32, 64)\n", + "(SimpleNet).block2(Sequential)[3](Conv2D).b 64 (64, 1, 1)\n", + "(SimpleNet).block2(Sequential)[3](Conv2D).w 36864 (3, 3, 64, 64)\n", + "(SimpleNet).post_conv(Conv2D).b 10 (10, 1, 1)\n", + "(SimpleNet).post_conv(Conv2D).w 5760 (3, 3, 64, 10)\n", + "+Total(12) 75242\n" + ] + } + ], + "source": [ + "model = SimpleNet(nclass=10, colors=1, n=16)\n", + "original_weights = model.vars().tensors() # saving the initial weights, all the below experiments will be using this same weights.\n", + "print(model.vars())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oGUj3wrhVknG" + }, + "source": [ + "## Regular training with SGD" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 884 + }, + "id": "sWgSBId6-kAF", + "outputId": "e90bef4d-8e1e-411b-d307-9e12650ac007" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 2/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 86593.05img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0001 Accuracy 13.08 Loss 2.33\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 3/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88523.38img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0002 Accuracy 9.80 Loss 2.33\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 4/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88991.30img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0003 Accuracy 9.80 Loss 2.32\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 5/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 89247.93img/s]" 
+ ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0004 Accuracy 9.80 Loss 2.32\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 6/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88412.02img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0005 Accuracy 9.80 Loss 2.32\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 7/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88727.59img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0006 Accuracy 9.80 Loss 2.32\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 8/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88869.77img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0007 Accuracy 9.82 Loss 2.31\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 9/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 87998.60img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0008 Accuracy 9.82 Loss 2.31\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 10/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88086.23img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0009 Accuracy 9.92 Loss 2.31\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 11/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88946.04img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0010 Accuracy 10.27 Loss 2.30\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 12/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 87248.79img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0011 Accuracy 11.19 Loss 2.30\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 13/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 
85079.93img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0012 Accuracy 14.57 Loss 2.29\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 14/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88835.86img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0013 Accuracy 15.08 Loss 2.29\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 15/50 : 14%|█▍ | 8704/60416 [00:00<00:00, 85697.91img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0014 Accuracy 17.52 Loss 2.28\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 16/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 87876.57img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0015 Accuracy 23.02 Loss 2.27\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 17/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88567.80img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0016 Accuracy 24.27 Loss 2.26\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 18/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88092.66img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0017 Accuracy 25.01 Loss 2.24\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 19/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88833.41img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0018 Accuracy 31.00 Loss 2.22\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 20/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88844.85img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0019 Accuracy 34.81 Loss 2.18\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 21/50 : 15%|█▌ | 
9216/60416 [00:00<00:00, 88184.70img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0020 Accuracy 30.33 Loss 2.12\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 22/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88825.05img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0021 Accuracy 38.27 Loss 2.03\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 23/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88235.83img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0022 Accuracy 29.19 Loss 1.88\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 24/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 89164.96img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0023 Accuracy 30.77 Loss 1.71\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 25/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88137.45img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0024 Accuracy 44.48 Loss 1.64\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 26/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88553.59img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0025 Accuracy 47.86 Loss 1.51\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 27/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 86943.64img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0026 Accuracy 43.22 Loss 1.37\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 28/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 87624.38img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0027 Accuracy 49.07 Loss 1.23\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"Epoch 29/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 85400.42img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0028 Accuracy 61.34 Loss 1.08\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 30/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88383.11img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0029 Accuracy 54.22 Loss 0.99\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 31/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88590.33img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0030 Accuracy 67.99 Loss 0.87\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 32/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 86999.99img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0031 Accuracy 73.26 Loss 0.75\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 33/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88843.01img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0032 Accuracy 77.06 Loss 0.68\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 34/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 86265.11img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0033 Accuracy 72.43 Loss 0.62\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 35/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 82148.97img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0034 Accuracy 76.37 Loss 0.57\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 36/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88946.86img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0035 Accuracy 80.99 Loss 0.53\n" + ] + }, + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + "Epoch 37/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 89022.04img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0036 Accuracy 80.53 Loss 0.53\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 38/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88539.60img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0037 Accuracy 82.85 Loss 0.46\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 39/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88502.70img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0038 Accuracy 78.69 Loss 0.45\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 40/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 87474.47img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0039 Accuracy 85.29 Loss 0.44\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 41/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88521.55img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0040 Accuracy 85.52 Loss 0.41\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 42/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 89200.97img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0041 Accuracy 86.48 Loss 0.39\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 43/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88478.19img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0042 Accuracy 88.26 Loss 0.38\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 44/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88767.93img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0043 Accuracy 77.10 Loss 0.37\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + "Epoch 45/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88623.64img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0044 Accuracy 88.17 Loss 0.38\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 46/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 89394.47img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0045 Accuracy 88.27 Loss 0.35\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 47/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88684.63img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0046 Accuracy 86.68 Loss 0.34\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 48/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88895.52img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0047 Accuracy 88.00 Loss 0.38\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 49/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 88530.47img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0048 Accuracy 89.35 Loss 0.32\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 50/50 : 15%|█▌ | 9216/60416 [00:00<00:00, 89427.56img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0049 Accuracy 90.10 Loss 0.31\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0050 Accuracy 90.06 Loss 0.31\n", + "Total training time: 78.19884443283081\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r" + ] + } + ], + "source": [ + "opt = objax.optimizer.SGD(model.vars())\n", + "predict = objax.Jit(model)\n", + "\n", + "def loss(x, y):\n", + " logits = model(x, training=True)\n", + " 
xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()\n", + " wd_loss = sum(jn.abs(v.value).sum() for k,v in model.vars().items() if k.endswith('.w'))\n", + " return xe_loss + wd_loss * 1e-5\n", + "\n", + "gv = objax.GradValues(loss, model.vars())\n", + "\n", + "def train_op(x, y, lr):\n", + " g, v = gv(x, y)\n", + " opt(lr=lr, grads=g)\n", + " return v\n", + "\n", + "train_op = objax.Jit(train_op, model.vars() + opt.vars())\n", + "\n", + "train_start = time()\n", + "loss_hist, accuracy_hist = [], []\n", + "\n", + "for epoch in range(epochs):\n", + "    loop = trange(0, train_size, batch,\n", + "                  leave=False, unit='img', unit_scale=batch,\n", + "                  desc='Epoch %d/%d ' % (1 + epoch, epochs))\n", + "    batch_loss = []\n", + "    for it in loop:\n", + "        # select random images from training set\n", + "        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])\n", + "\n", + "        # feed the batch\n", + "        v = train_op(augment(train.image[sel]), train.label[sel], lr)\n", + "        batch_loss.append(v[0])\n", + "\n", + "    # Mean training loss of this epoch; named epoch_loss so the loss() function\n", + "    # defined earlier in this cell is not rebound to a number.\n", + "    epoch_loss = sum(batch_loss)/len(batch_loss)\n", + "    loss_hist.append(epoch_loss)\n", + "\n", + "    # Eval: count correct predictions over the whole test set, batch by batch\n", + "    accuracy = 0\n", + "    for it in trange(0, test.image.shape[0], test_batch, leave=False, desc='Evaluating'):\n", + "        x = test.image[it: it + test_batch]\n", + "        xl = test.label[it: it + test_batch]\n", + "        accuracy += (np.argmax(predict(x), axis=1) == xl).sum()\n", + "    accuracy /= test.image.shape[0]\n", + "    accuracy_hist.append(accuracy)\n", + "    print(f'Epoch {epoch + 1:04d} Accuracy {100 * accuracy:.2f} Loss {epoch_loss:.2f}')\n", + "\n", + "train_end = time()\n", + "print('Total training time: ', (train_end-train_start))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 606 + }, + "id": "7CM6S1zmO7k8", + "outputId": "67775e1b-f78c-491d-848c-5750c08d3ac4" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "acc_loss_plot(loss_hist, accuracy_hist)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oznhDcfmVw7J" + }, + "source": [ + "## Learning Rate Finder\n", + "\n", + "1. Save the weights\n", + "2. Find the best LR\n", + "3. Restore the weights\n", + "\n", + "## How to find the best learning rate?\n", + "We need to pick a learning rate that is neither too low nor too high.\n", + "As you can see in the plot below, learning rates that are too low ($<10^{-2}$) or too high ($>1$) both result in a large loss,\n", + "so we choose a value between these ranges, usually where the loss curve slopes downward most steeply." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 283 + }, + "id": "4ZPsnS20CQSw", + "outputId": "20a88ea7-152c-498a-82de-ef5c203c9210" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "# find best learning rate for faster convergence\n", + "model.vars().assign(original_weights) # restoring the initial weights\n", + "opt = objax.optimizer.SGD(model.vars())\n", + "predict = objax.Jit(model)\n", + "\n", + "def loss(x, y):\n", + " logits = model(x, training=True)\n", + " xe_loss = objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()\n", + " wd_loss = sum(jn.abs(v.value).sum() for k,v in model.vars().items() if k.endswith('.w'))\n", + " return xe_loss + wd_loss * 1e-5\n", + "\n", + "gv = objax.GradValues(loss, model.vars())\n", + "\n", + "def train_op(x, y, lr):\n", + " g, v = gv(x, y)\n", + " opt(lr=lr, grads=g)\n", + " return v\n", + "\n", + "train_op = objax.Jit(train_op, model.vars() + opt.vars())\n", + "\n", + "def find_lr(init_value=1e-8, final_value=10., beta=0.98):\n", + "    \"\"\"Learning-rate range test: one pass over random batches with a growing LR.\n", + "\n", + "    The learning rate starts at `init_value` and is multiplied by a constant\n", + "    factor after every step, growing towards `final_value`. A bias-corrected\n", + "    exponential moving average of the batch loss (coefficient `beta`) is\n", + "    recorded at each step. The sweep stops early once that average exceeds\n", + "    4x the best value seen so far or is no longer finite.\n", + "\n", + "    Returns two parallel lists: the learning rates used and the smoothed losses.\n", + "    Note that every step updates the model weights through `train_op`.\n", + "    \"\"\"\n", + "    num_batches = (train_size//batch) - 1\n", + "    lr_update_factor = (final_value / init_value) ** (1/num_batches)\n", + "    avg_loss, best_loss, batch_num = 0., 0., 0\n", + "    current_lr = init_value\n", + "\n", + "    loop = trange(0, train_size, batch, leave=False, unit='img',\n", + "                  unit_scale=batch)\n", + "\n", + "    learning_rates, losses = [], []\n", + "    for it in loop:\n", + "        batch_num += 1\n", + "\n", + "        # select random images from training set\n", + "        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])\n", + "\n", + "        # feed the batch; keep the loss as a plain float for the bookkeeping below\n", + "        batch_loss = float(train_op(train.image[sel], train.label[sel], current_lr)[0])\n", + "\n", + "        # Compute the smoothed loss (moving average with bias correction)\n", + "        avg_loss = beta * avg_loss + (1-beta) * batch_loss\n", + "        smoothed_loss = avg_loss / (1 - beta ** batch_num)\n", + "\n", + "        # Stop if the loss is exploding. NaN never satisfies the `>` test,\n", + "        # so non-finite values are checked explicitly.\n", + "        if not math.isfinite(smoothed_loss) or (batch_num > 1 and smoothed_loss > 4 * best_loss):\n", + "            return learning_rates, losses\n", + "\n", + "        # Record the best loss\n", + "        if smoothed_loss < best_loss or batch_num==1:\n", + "            best_loss = smoothed_loss\n", + "\n", + "        losses.append(smoothed_loss)\n", + "        learning_rates.append(current_lr)\n", + "        current_lr *= lr_update_factor\n", + "\n", + "    return learning_rates, losses\n", + "\n", + "learning_rates, losses = find_lr()\n", + "lr_plot(learning_rates[10:-5], losses[10:-5])" +
] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IVKw_KZCXTOT" + }, + "source": [ + "## Training with best learning rate." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "id": "WMXnIgqLeADQ", + "outputId": "645ed81e-ab8e-41bd-8b8f-b0052bc4263d" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 2/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 89874.45img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0001 Accuracy 18.99 Loss 2.25\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 3/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90893.92img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0002 Accuracy 45.40 Loss 2.04\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 4/10 : 15%|█▌ | 9216/60416 [00:00<00:00, 90936.84img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0003 Accuracy 54.44 Loss 1.40\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 5/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90601.27img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0004 Accuracy 78.53 Loss 0.69\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 6/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 89078.22img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0005 Accuracy 88.43 Loss 0.36\n" + ] + }, + { + 
"name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 7/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90713.66img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0006 Accuracy 92.55 Loss 0.24\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 8/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90831.60img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0007 Accuracy 91.52 Loss 0.23\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 9/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90750.99img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0008 Accuracy 92.81 Loss 0.16\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 10/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 89652.48img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0009 Accuracy 89.35 Loss 0.20\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0010 Accuracy 95.11 Loss 0.15\n", + "Total training time: 15.313448905944824\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r" + ] + } + ], + "source": [ + "model.vars().assign(original_weights) # restoring the initial weights\n", + "opt = objax.optimizer.SGD(model.vars())\n", + "predict = objax.Jit(model)\n", + "\n", + "lr = 0.1\n", + "epochs = 10\n", + "\n", + "def loss(x, y):\n", + " logits = model(x, training=True)\n", + " return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()\n", + "\n", + "gv = objax.GradValues(loss, model.vars())\n", + "\n", + "def train_op(x, y, lr):\n", + " g, v = gv(x, y)\n", + " opt(lr=lr, grads=g)\n", + " return v\n", + "\n", + "train_op = objax.Jit(train_op, model.vars() + opt.vars())\n", + "\n", + "train_start = time()\n", + "loss_hist, accuracy_hist = 
[], []\n", + "\n", + "for epoch in range(epochs):\n", + "    loop = trange(0, train_size, batch,\n", + "                  leave=False, unit='img', unit_scale=batch,\n", + "                  desc='Epoch %d/%d ' % (1 + epoch, epochs))\n", + "    batch_loss = []\n", + "    for it in loop:\n", + "        # select random images from training set\n", + "        sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])\n", + "\n", + "        # feed the batch\n", + "        v = train_op(augment(train.image[sel]), train.label[sel], lr)\n", + "        batch_loss.append(v[0])\n", + "\n", + "    # Mean training loss of this epoch; named epoch_loss so the loss() function\n", + "    # defined earlier in this cell is not rebound to a number.\n", + "    epoch_loss = sum(batch_loss)/len(batch_loss)\n", + "    loss_hist.append(epoch_loss)\n", + "\n", + "    # Eval: count correct predictions over the whole test set, batch by batch\n", + "    accuracy = 0\n", + "    for it in trange(0, test.image.shape[0], test_batch, leave=False, desc='Evaluating'):\n", + "        x = test.image[it: it + test_batch]\n", + "        xl = test.label[it: it + test_batch]\n", + "        accuracy += (np.argmax(predict(x), axis=1) == xl).sum()\n", + "    accuracy /= test.image.shape[0]\n", + "    accuracy_hist.append(accuracy)\n", + "    print(f'Epoch {epoch + 1:04d} Accuracy {100 * accuracy:.2f} Loss {epoch_loss:.2f}')\n", + "\n", + "train_end = time()\n", + "print('Total training time: ', (train_end-train_start))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 606 + }, + "id": "IcivNXQqmdcz", + "outputId": "95a62d8d-e3a5-4582-f1f5-b51277e39914" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "acc_loss_plot(loss_hist, accuracy_hist)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mCjvEEe2XjR2" + }, + "source": [ + "## Training with One Cycle Policy (Cyclical learning rate and Cyclical momentum)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 241 + }, + "id": "MAaZoBFk7vod", + "outputId": "15e47962-4d0b-4125-b651-f740e4aabd5e" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "def cyclical(stepsize, min_val, max_val, inverse=False):\n", + " import math\n", + " scaler = lambda x: 1\n", + " lr_lambda = lambda it: min_val + (max_val - min_val) * relative(it, stepsize, inverse)\n", + "\n", + " def relative(it, stepsize, inverse):\n", + " cycle = math.floor(1 + it / (2 * stepsize))\n", + " x = abs(it / stepsize - 2 * cycle + 1)\n", + " val = max(0, (1 - x)) * scaler(cycle)\n", + " return 1 - val if inverse else val\n", + " \n", + " return lr_lambda\n", + "\n", + "\n", + "# Example\n", + "n_cycles = 2\n", + "cycle_len = 10\n", + "step_size = cycle_len // 2\n", + "cyc_lr = cyclical(stepsize=step_size, min_val=3e-4, max_val=3e-3)\n", + "cyc_m = cyclical(stepsize=step_size, min_val=0.85, max_val=0.95, inverse=True)\n", + "\n", + "lr_list, m_list = [], []\n", + "for epoch in range(n_cycles):\n", + " for i in range(cycle_len):\n", + " iter_num = epoch*cycle_len + i\n", + " lr_list.append(cyc_lr(iter_num))\n", + " m_list.append(cyc_m(iter_num))\n", + "\n", + "lr_mom_plot(lr_list, m_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "id": "jJO5me5Inn-m", + "outputId": "e62012d7-12a2-4b6b-9192-10c5a198565d" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 2/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 89271.56img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0001 Accuracy 42.67 Loss 2.13\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 3/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90507.81img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0002 Accuracy 54.67 Loss 1.35\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 4/10 : 16%|█▌ | 
9728/60416 [00:00<00:00, 90364.89img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0003 Accuracy 81.41 Loss 0.43\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 5/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90104.67img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0004 Accuracy 93.69 Loss 0.18\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 6/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 89351.32img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0005 Accuracy 95.00 Loss 0.13\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 7/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90023.56img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0006 Accuracy 94.41 Loss 0.08\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 8/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 90246.37img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0007 Accuracy 96.68 Loss 0.07\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 9/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 89971.55img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0008 Accuracy 97.94 Loss 0.06\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 10/10 : 16%|█▌ | 9728/60416 [00:00<00:00, 86532.58img/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0009 Accuracy 98.20 Loss 0.05\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0010 Accuracy 98.10 Loss 0.04\n", + "Total training time: 15.809092283248901\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r" + ] + } + ], 
+ "source": [ + "model.vars().assign(original_weights) # restoring the initial weights\n", + "opt = objax.optimizer.Momentum(model.vars(), 0.95)\n", + "predict = objax.Jit(model)\n", + "\n", + "lr = 0.1 # the best learning rate which we found previously\n", + "epochs = 10\n", + "num_iter_per_epoch = train_size//batch\n", + "tot_iter = epochs * num_iter_per_epoch\n", + "step_size = tot_iter // 2\n", + "\n", + "#initialize cyclical generators\n", + "cyc_lr = cyclical(stepsize=step_size, min_val=lr/10, max_val=lr)\n", + "cyc_mom = cyclical(stepsize=step_size, min_val=0.85, max_val=0.95, inverse=True)\n", + "\n", + "def loss(x, y):\n", + " logits = model(x, training=True)\n", + " return objax.functional.loss.cross_entropy_logits_sparse(logits, y).mean()\n", + "\n", + "gv = objax.GradValues(loss, model.vars())\n", + "\n", + "def train_op(x, y, lr, mom):\n", + " g, v = gv(x, y)\n", + " opt(lr=lr, grads=g, momentum=mom)\n", + " return v\n", + "\n", + "train_op = objax.Jit(train_op, model.vars() + opt.vars())\n", + "mom_hist, lr_hist, loss_hist, accuracy_hist = [], [], [], []\n", + "train_start = time()\n", + "\n", + "for epoch in range(epochs):\n", + " loop = trange(0, train_size, batch,\n", + " leave=False, unit='img', unit_scale=batch,\n", + " desc='Epoch %d/%d ' % (1 + epoch, epochs))\n", + " batch_loss = []\n", + " for it in loop:\n", + " iter_num = epoch*num_iter_per_epoch+it/batch\n", + " \n", + " # get new learning rate and momentum\n", + " lr = cyc_lr(iter_num)\n", + " mom = cyc_mom(iter_num)\n", + "\n", + " # select random images from training set\n", + " sel = np.random.randint(size=(batch,), low=0, high=train.image.shape[0])\n", + " \n", + " # feed the batch\n", + " v = train_op(augment(train.image[sel]), train.label[sel], lr, mom)\n", + "\n", + " batch_loss.append(v[0])\n", + " mom_hist.append(mom)\n", + " lr_hist.append(lr)\n", + "\n", + " loss = sum(batch_loss)/len(batch_loss)\n", + " loss_hist.append(loss)\n", + "\n", + " # Eval\n", + " accuracy = 0\n", + " 
for it in trange(0, test.image.shape[0], test_batch, leave=False, desc='Evaluating'):\n", + " x = test.image[it: it + test_batch]\n", + " xl = test.label[it: it + test_batch]\n", + " accuracy += (np.argmax(predict(x), axis=1) == xl).sum()\n", + " accuracy /= test.image.shape[0]\n", + " accuracy_hist.append(accuracy)\n", + " print(f'Epoch {epoch + 1:04d} Accuracy {100 * accuracy:.2f} Loss {loss:.2f}')\n", + "\n", + "train_end = time()\n", + "print('Total training time: ', (train_end-train_start))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 606 + }, + "id": "zSG6FmCq-HX4", + "outputId": "d9c5add0-8b14-438b-bb9c-0eebcb4065ad" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "acc_loss_plot(loss_hist, accuracy_hist)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 241 + }, + "id": "fOTiNEJuQfBA", + "outputId": "7ecdd587-dbb9-4af6-cecf-6a7a85dfd083" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light", + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "lr_mom_plot(lr_hist, mom_hist)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Super Convergence.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "1096b35a43204819b793ae71ca903009": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "36a97e7e6f934026817e68b6daab5e8b": { + 
"model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3d1b718eec564ac6866d732e93c52362": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "559a266f43484b378a1fe4783178274b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + 
"description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3d1b718eec564ac6866d732e93c52362", + "placeholder": "​", + "style": "IPY_MODEL_36a97e7e6f934026817e68b6daab5e8b", + "value": " 4/4 [00:06<00:00, 1.57s/ file]" + } + }, + "7b3f8d1c9ab944c69c0146a0951a8807": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "80beed2bab07454eaf0eb26399ea19b9": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "Dl Completed...: 100%", + "description_tooltip": null, + "layout": 
"IPY_MODEL_7b3f8d1c9ab944c69c0146a0951a8807", + "max": 4, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_ef01d022b37b4806bbbc05f5c2bade19", + "value": 4 + } + }, + "e7b07465fa61490dbb41545aac4fa9b4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_80beed2bab07454eaf0eb26399ea19b9", + "IPY_MODEL_559a266f43484b378a1fe4783178274b" + ], + "layout": "IPY_MODEL_1096b35a43204819b793ae71ca903009" + } + }, + "ef01d022b37b4806bbbc05f5c2bade19": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "initial" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}