diff --git a/docs/user_guide/Getting_Started.ipynb b/docs/user_guide/Getting_Started.ipynb
new file mode 100644
index 0000000..5974e3f
--- /dev/null
+++ b/docs/user_guide/Getting_Started.ipynb
@@ -0,0 +1,374 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Install the required packages\n",
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Bergschaf/visualime_guide/blob/master/Get_Started.ipynb)\n"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
+ "!pip3 install numpy\n",
+ "!pip3 install matplotlib"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Import the required packages"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import torch\n",
+ "import matplotlib.pyplot as plt\n",
+ "from torchvision import datasets, transforms"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Define the transformations to prepare the data"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "transform = transforms.Compose([\n",
+ "    transforms.ToTensor(),\n",
+ "    transforms.Normalize((0.5,), (0.5,)),\n",
+ "])"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "1. ```transforms.ToTensor()``` converts the image to a tensor with values in the range [0, 1]\n",
+ "2. ```transforms.Normalize((0.5,), (0.5,))``` normalizes the single grayscale channel to the range [-1, 1]"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Download the dataset"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "!mkdir -p data\n",
+ "\n",
+ "testset = datasets.MNIST('data', download=True, train=False, transform=transform)\n",
+ "\n",
+ "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Analyze the dataset"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "dataiter = iter(testloader)\n",
+ "images, labels = next(dataiter)\n",
+ "\n",
+ "print(images.shape)\n",
+ "print(labels.shape)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The batch size is 64, each image is 28x28 pixels, and there is a single channel (grayscale).\n",
+ "The labels are the digits shown in the images."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Download the model"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "!wget https://github.com/Bergschaf/visualime_guide/raw/master/models/mnist_model.pt\n",
+ "model = torch.load(\"mnist_model.pt\")"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
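+ {
+ "cell_type": "markdown",
+ "source": [
+ "To verify that the model loaded correctly, we can measure its accuracy on the test set. This is a minimal sketch, assuming (as in the cells below) that the model takes flattened inputs of the shape [batch_size, 784] and returns log-probabilities:"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "correct, total = 0, 0\n",
+ "with torch.no_grad():\n",
+ "    for batch, batch_labels in testloader:\n",
+ "        # flatten each 1x28x28 image into a 784-dimensional vector\n",
+ "        logps = model(batch.view(batch.shape[0], 784))\n",
+ "        # the most probable class is the predicted digit\n",
+ "        preds = torch.argmax(logps, dim=1)\n",
+ "        correct += (preds == batch_labels).sum().item()\n",
+ "        total += batch_labels.shape[0]\n",
+ "print(f\"Accuracy on the test set: {correct / total:.2%}\")"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },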
[ + "# Test the model on a single image" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "img = images[0]\n", + "img = img.view(1, 784)\n", + "with torch.no_grad():\n", + " logps = model(img)\n", + "\n", + "ps = torch.exp(logps)\n", + "probab = list(ps.numpy()[0])\n", + "print(\"Predicted Digit =\", probab.index(max(probab)))\n", + "\n", + "plt.imshow(img.resize_(1, 28, 28).numpy().squeeze(), cmap='Greys_r')" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "# Explain the classification with visualime" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Install and import visuallime" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!pip3 install visualime\n", + "from visualime.explain import explain_classification, render_explanation" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Define helper Functions" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "def to_visualime(image: np.ndarray):\n", + " \"\"\"\n", + " Converts an image of the shape [1,28,28] to the shape [28,28,3]\n", + " \"\"\"\n", + " image = image.squeeze()\n", + " image = np.stack((image, image, image), axis=2)\n", + " return image" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "```to_visualime()``` converts the image to the shape [28,28,3] which is required by visualime" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "def predict(imgs: np.ndarray):\n", + " \"\"\"\n", + " :param image: visualime RGB image of the shape [num_samples, 28,28,3]\n", + " :return:\n", + " \"\"\"\n", + " imgs = imgs[:, :, :, 0]\n", + " predictions = np.zeros((imgs.shape[0], 10))\n", + " for i in range(imgs.shape[0]):\n", + " image = imgs[i]\n", + " # convert the image to a tensor\n", + " image = torch.from_numpy(image)\n", + " image = image.view(1, 784)\n", + " with torch.no_grad():\n", + " prediction = model(image)\n", + " predictions[i] = torch.exp(prediction).numpy()[0]\n", + " return predictions\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "```predict()``` takes an image of the shape [num_samples, 28,28,3] (an array of visualime images) and returns the predictions of the model\n", + "This is required to explain the classification with visualime" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## Explain the classification" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "img = images[5] # Choose an image" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "print(\"The network predicts: \", np.argmax(predict(np.array([to_visualime(img)]))))\n", + "\n", + "segment_mask, segment_weights = explain_classification(image=to_visualime(img), predict_fn=predict, num_of_samples=512)\n", + "\n", + "explanation = render_explanation(\n", + " to_visualime(img),\n", + " segment_mask,\n", + " segment_weights,\n", + " 
positive=\"green\",\n", + " negative=\"red\",\n", + " coverage=0.5,\n", + " opacity=1,\n", + " )\n", + "\n", + "plt.imshow(explanation)\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}