From 87c4b4a8e90ea124901ec74934d2c509a9e21527 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Morten=20Gr=C3=B8ftehauge?= Date: Wed, 18 Jan 2023 14:11:04 +0100 Subject: [PATCH] Reshape images to 2D so imshow can show them The images were shaped as (8, 8, 1) so `np.squeeze` removes the degenerate dimension and allows matplotlib to plot the image. --- .../Demo_for_training_a_model_with_a_learned_optimizer.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb b/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb index 598b498..b623b49 100644 --- a/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb +++ b/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb @@ -255,6 +255,7 @@ " 9:\t\"Ankle boot\"}\n", "\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "\n", "fig, axs = plt.subplots(5, 5)\n", @@ -264,7 +265,7 @@ " pred = fashion_mnist_label_to_name[int(jax.numpy.argmax(predictions,axis=1)[5*i+j])]\n", " real = fashion_mnist_label_to_name[labels[5*i+j]]\n", " color = 'g' if pred==real else 'r'\n", - " axs[i, j].imshow(batch[\"image\"][5*i+j])\n", + " axs[i, j].imshow(np.squeeze(batch[\"image\"][5*i+j]))\n", " axs[i, j].set_title(f'Prediction: {pred},\\n Real Label: {real}', color=color)\n", " axs[i, j].set_xticks([])\n", " axs[i, j].set_yticks([])\n",