diff --git a/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb b/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb index 598b498..b623b49 100644 --- a/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb +++ b/learned_optimization/research/general_lopt/Demo_for_training_a_model_with_a_learned_optimizer.ipynb @@ -255,6 +255,7 @@ " 9:\t\"Ankle boot\"}\n", "\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "\n", "\n", "fig, axs = plt.subplots(5, 5)\n", @@ -264,7 +265,7 @@ " pred = fashion_mnist_label_to_name[int(jax.numpy.argmax(predictions,axis=1)[5*i+j])]\n", " real = fashion_mnist_label_to_name[labels[5*i+j]]\n", " color = 'g' if pred==real else 'r'\n", - " axs[i, j].imshow(batch[\"image\"][5*i+j])\n", + " axs[i, j].imshow(np.squeeze(batch[\"image\"][5*i+j]))\n", " axs[i, j].set_title(f'Prediction: {pred},\\n Real Label: {real}', color=color)\n", " axs[i, j].set_xticks([])\n", " axs[i, j].set_yticks([])\n",