374 changes: 374 additions & 0 deletions docs/user_guide/Getting_Started.ipynb
@@ -0,0 +1,374 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Install the required packages\n",
"[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Bergschaf/visualime_guide/blob/master/Get_Started.ipynb)\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
"!pip3 install numpy\n",
"!pip3 install matplotlib"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# Import the required packages"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"from torchvision import datasets, transforms"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# Define the transformations to prepare the data"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"transform = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((0.5,), (0.5,)),\n",
" ])"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"1. ```transforms.ToTensor()``` converts the image to a tensor\n",
"2. ```transforms.Normalize((0.5,), (0.5,))``` normalizes the image"
],
"metadata": {
"collapsed": false
}
},
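{
"cell_type": "markdown",
"source": [
"As a quick sanity check, the following minimal sketch shows the effect of the normalization (a random tensor stands in for a real ```ToTensor()``` output here):"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# ToTensor() scales pixel values to [0, 1]; Normalize((0.5,), (0.5,)) then maps\n",
"# them to [-1, 1] via (x - 0.5) / 0.5\n",
"x = torch.rand(1, 28, 28)  # stand-in for a ToTensor() output\n",
"y = transforms.Normalize((0.5,), (0.5,))(x)\n",
"print(x.min().item(), x.max().item())  # roughly 0 and 1\n",
"print(y.min().item(), y.max().item())  # roughly -1 and 1"
],
"metadata": {
"collapsed": false
}
},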
{
"cell_type": "markdown",
"source": [
"# Download the dataset"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!mkdir -p data\n",
"\n",
"testset = datasets.MNIST('data', download=True, train=False, transform=transform)\n",
"\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"# Analyze the dataset"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"dataiter = iter(testloader)\n",
"images, labels = next(dataiter)\n",
"images, labels = next(dataiter)\n",
"\n",
"print(images.shape)\n",
"print(labels.shape)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"The batch size is 64 and the image size is 28x28 and the number of channels is 1 (grayscale)\n",
"The labels are the corresponding numbers for the images"
],
"metadata": {
"collapsed": false
}
},
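{
"cell_type": "markdown",
"source": [
"To get a feel for the data, we can plot the first few images of the batch together with their labels:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# show the first six images of the batch with their labels\n",
"fig, axes = plt.subplots(1, 6, figsize=(10, 2))\n",
"for ax, image, label in zip(axes, images, labels):\n",
"    ax.imshow(image.squeeze().numpy(), cmap='Greys_r')\n",
"    ax.set_title(int(label))\n",
"    ax.axis('off')\n",
"plt.show()"
],
"metadata": {
"collapsed": false
}
},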
{
"cell_type": "markdown",
"source": [
"# Download the model"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!wget https://github.com/Bergschaf/visualime_guide/raw/master/models/mnist_model.pt\n",
"model = torch.load(\"mnist_model.pt\")"
],
"metadata": {
"collapsed": false
}
},
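{
"cell_type": "markdown",
"source": [
"Printing the loaded model shows its architecture:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"print(model)"
],
"metadata": {
"collapsed": false
}
},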
{
"cell_type": "markdown",
"source": [
"# Test the model on a single image"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"img = images[0]\n",
"img = img.view(1, 784)\n",
"with torch.no_grad():\n",
" logps = model(img)\n",
"\n",
"ps = torch.exp(logps)\n",
"probab = list(ps.numpy()[0])\n",
"print(\"Predicted Digit =\", probab.index(max(probab)))\n",
"\n",
"plt.imshow(img.resize_(1, 28, 28).numpy().squeeze(), cmap='Greys_r')"
],
"metadata": {
"collapsed": false
}
},
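{
"cell_type": "markdown",
"source": [
"To see how confident the model is, we can also plot the full probability distribution over the ten digits:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"plt.bar(range(10), probab)\n",
"plt.xticks(range(10))\n",
"plt.xlabel(\"digit\")\n",
"plt.ylabel(\"probability\")\n",
"plt.show()"
],
"metadata": {
"collapsed": false
}
},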
{
"cell_type": "markdown",
"source": [
"# Explain the classification with visualime"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## Install and import visuallime"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!pip3 install visualime\n",
"from visualime.explain import explain_classification, render_explanation"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"## Define helper Functions"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def to_visualime(image: np.ndarray):\n",
" \"\"\"\n",
" Converts an image of the shape [1,28,28] to the shape [28,28,3]\n",
" \"\"\"\n",
" image = image.squeeze()\n",
" image = np.stack((image, image, image), axis=2)\n",
" return image"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"```to_visualime()``` converts the image to the shape [28,28,3] which is required by visualime"
],
"metadata": {
"collapsed": false
}
},
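{
"cell_type": "markdown",
"source": [
"A quick shape check (```images``` is the batch loaded above):"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"sample = to_visualime(images[0].numpy())\n",
"print(sample.shape)  # (28, 28, 3)"
],
"metadata": {
"collapsed": false
}
},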
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"def predict(imgs: np.ndarray):\n",
" \"\"\"\n",
" :param image: visualime RGB image of the shape [num_samples, 28,28,3]\n",
" :return:\n",
" \"\"\"\n",
" imgs = imgs[:, :, :, 0]\n",
" predictions = np.zeros((imgs.shape[0], 10))\n",
" for i in range(imgs.shape[0]):\n",
" image = imgs[i]\n",
" # convert the image to a tensor\n",
" image = torch.from_numpy(image)\n",
" image = image.view(1, 784)\n",
" with torch.no_grad():\n",
" prediction = model(image)\n",
" predictions[i] = torch.exp(prediction).numpy()[0]\n",
" return predictions\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"```predict()``` takes an image of the shape [num_samples, 28,28,3] (an array of visualime images) and returns the predictions of the model\n",
"This is required to explain the classification with visualime"
],
"metadata": {
"collapsed": false
}
},
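{
"cell_type": "markdown",
"source": [
"As a sanity check: assuming the model ends in a log-softmax layer (which the ```torch.exp()``` call suggests), each row returned by ```predict()``` should be a probability distribution over the ten digits:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"batch = np.stack([to_visualime(images[i].numpy()) for i in range(4)])\n",
"probs = predict(batch)\n",
"print(probs.shape)        # (4, 10)\n",
"print(probs.sum(axis=1))  # each row should sum to roughly 1"
],
"metadata": {
"collapsed": false
}
},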
{
"cell_type": "markdown",
"source": [
"## Explain the classification"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"img = images[5] # Choose an image"
],
"metadata": {
"collapsed": false
}
},
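{
"cell_type": "markdown",
"source": [
"Display the image that will be explained:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"plt.imshow(img.squeeze().numpy(), cmap='Greys_r')"
],
"metadata": {
"collapsed": false
}
},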
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"print(\"The network predicts: \", np.argmax(predict(np.array([to_visualime(img)]))))\n",
"\n",
"segment_mask, segment_weights = explain_classification(image=to_visualime(img), predict_fn=predict, num_of_samples=512)\n",
"\n",
"explanation = render_explanation(\n",
" to_visualime(img),\n",
" segment_mask,\n",
" segment_weights,\n",
" positive=\"green\",\n",
" negative=\"red\",\n",
" coverage=0.5,\n",
" opacity=1,\n",
" )\n",
"\n",
"plt.imshow(explanation)\n"
],
"metadata": {
"collapsed": false
}
},
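{
"cell_type": "markdown",
"source": [
"You can tweak how the explanation is rendered; for example (illustrative values), cover a larger share of the segment weights at a lower opacity:"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# re-render the same explanation with different settings\n",
"explanation = render_explanation(\n",
"    to_visualime(img),\n",
"    segment_mask,\n",
"    segment_weights,\n",
"    positive=\"green\",\n",
"    negative=\"red\",\n",
"    coverage=0.8,\n",
"    opacity=0.5,\n",
")\n",
"plt.imshow(explanation)"
],
"metadata": {
"collapsed": false
}
}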
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}