diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..2463912 Binary files /dev/null and b/.DS_Store differ diff --git a/examples/cel0_inference_bars_example.ipynb b/examples/cel0_inference_bars_example.ipynb new file mode 100644 index 0000000..cec360d --- /dev/null +++ b/examples/cel0_inference_bars_example.ipynb @@ -0,0 +1,235 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 25, + "id": "116276f7", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from sparsecoding import inference\n", + "\n", + "from sparsecoding.data.utils import load_bars_dictionary\n", + "from sparsecoding.visualization import plot_dictionary" + ] + }, + { + "cell_type": "markdown", + "id": "7c1e8bb3", + "metadata": {}, + "source": [ + "## Load bar dictionary\n", + "\n", + "A good way of evaluating whether or not a inference method is working correctly is by generating data from a known dictionary. In this notebook, this is done using a dictionary consisting of horizontal/vertical bars. This dictionary is provided in a pickle file in this rep" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "a9532c7e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAFOCAYAAAAWx6x6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHUUlEQVR4nO3dy2rkMBBA0fHQ///LmkWyUQiM1bb7dc9ZN8EILS5F2dnGGOMPAAAZf5/9AAAAPJYABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAEDMbe8Pt2278jkAADho7z94MwEEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBidn8H8H/2fnemwncTAYBXZQIIABAjAAEAYgQgAEDMaTuAdt7W2ZucuUMA8BgmgAAAMQIQACBGAAIAxJy2A8g6O29r7EzO3B8A7mUCCAAQIwABAGIEIABAjB1A3oadt3X2JmfuEMAXE0AAgBgBCAAQIwABAGIEIABAjJdA4IN56WGNl2Zm7g98LhNAAIAYAQgAECMAAQBi7AACfLPzts7e5Mwd4l2YAAIAxAhAAIAYAQgAEGMHEIC72XlbY2dy5v48jwkgAECMAAQAiBGAAAAxdgAB4EHsvK2zNzk76w6ZAAIAxAhAAIAYAQgAECMAAQBivAQCALwsL85cwwQQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAICY21l/aIxx1p/6CNu2PfsRAAB+ZQIIABAjAAEAYgQgAEDMaTuAdt7W2ZucuUMA8BgmgAAAMQIQACBGAAIAxAhAAICY014CYZ2XHtZ4aWbm/gBwLxNAAIAYAQgAECMAAQBi7ADyNuy8rbM3OXOHAL6YAAIAxAhAAIAYAQgAEGMHED6Ynbc1diZn7g98LhNAAIAYAQgAECMAAQBi7AACfLPzts7e5Mwd4l2YAAIAxAhAAIAYAQgAECMAAQBivAQCwN289LDGSzMz9+d5TAABAGIEIABAjAAEAIixAwgAD2LnbZ29ydlZd8gEEAAgRgACAMQIQACAGDuAAMDLsjd5DRNAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMSc9iHo3/5Zs483rvl5hs5vnTM8zhke4/yOc4bHOcNjCk1jAggAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAzDbGGLt+uG1XPwsAAAfszDoTQACAGgEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIOb27Ac422/fv/ENwzU/z9D5rXOGxznDY5zfcc7wOGd4zJVNYwIIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQMw2xhi7frhtVz8LAAAH7Mw6E0AAgBoBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxNz2/nDvhwUBAHhtJoAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAx/wACY3rnm927RgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "patch_size = 16\n", + "n_basis = 2*patch_size\n", + "\n", + "# load bars dictionary \n", + "dictionary = load_bars_dictionary()\n", + "dictionary = dictionary / dictionary.norm(dim=0, keepdim=True) # assume the dictionary is normalized\n", + "patch_size = int(np.sqrt(dictionary.shape[0]))\n", + "n_basis = dictionary.shape[1]\n", + "\n", + "nrow = 8\n", + "fig,ax = plot_dictionary(dictionary,nrow=nrow,size=8)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "190116ab", + "metadata": {}, + "outputs": [], + "source": [ + "dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True).squeeze()[0]\n", + "assert dictionary_norms==1, \"Dictionary must be normalized\"" + ] + }, + { + "cell_type": "markdown", + "id": "291457ef", + "metadata": {}, + "source": [ + "## Generate random data from bars dictionary" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "93b5ed39", + "metadata": {}, + "outputs": [], + "source": [ + "n_samples = 100\n", + "min_coefficient_val = 0.8\n", + "\n", + "# generate coefficients\n", + "coefficients = torch.rand([n_samples,n_basis]) \n", + "coefficients[coefficients < min_coefficient_val] = 0\n", + "\n", + "# generate dataset\n", + "data = (dictionary@coefficients.t()).t()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "ad40de0a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAALSElEQVR4nO3dbU8cZR/GYbZS2gCC4kPiR1RTbIVieJCnImgptAmpfkQTbaoUWisge7812WtOZ7Mzuwv3cbz8ZzJzLS0/J72c2U632+1OAFB0Z9QLABhnIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgSTdQ9cWVkpzg8PDxtbzLBsbW0V569evSrOX758OfA1t7e3a80mJiYmmnoIanl5uZHztGV/f79ntra2NoKVDO7Zs2eNnavT6Qx8jqq/4zs7OwOfe9hWV1eL84ODg4HPXed3zZ0kQCCSAIFIAgQiCRCIJEBQe3ebm+fo6GjUS4gWFhZ6ZuO+5ipN7m5X/V8P/aja3W5i53zY1tfXi/N79+4N5fruJAECkQQIRBIgEEmAwMYNjJkmNm6qHre7iY8l/v3338V5E48lVm1w/Zs7SYBAJAECkQQIRBIgEEmAQCQBApEECEQSIBBJgEAkAQKRBAhqP7u9ublZnE9PTze2mGGpeja26itlP/3004GvWXpGtKmvjgXa404SIBBJgEAkAQKRBAhqb9zs7u4W54eHh40tZtSqNm5evnw58LlL31JX9QLUJl66CjTDnSRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJENR+6S43z+rq6qiXEK2trfXMzs/PR7ASqOZOEiAQSYBAJAECkQQIRBIgsLt9ix0cHIx6CdHs7GzPbNzXXGV/f3/US6Al7iQBApEECEQSIBBJgEAkAQKRBAhEEiAQSYBAJAECkQQIRBIgEEmAQCQBApEECEQSIBBJgEAkAQKRBAhEEiAQSYBAJAECkQQIan+l7NLSUnF+dXXV2GKGpeqzvHr1qjifmppq5Zqnp6cDnxdolztJgEAkAQKRBAhEEiAQSYCg9u720dFRcf78+fPGFjMsc3NzxXnV7vbLly8Hvub8/HzPrOpnd3x8PPD1gGa4kwQIRBIgEEmAQCQBApEECEQSIBBJgEAkAQKRBAhEEiDodLvd7qgXATCu3EkCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgSTdQ/sdDqtLWJ5ebln9uzZs4HPu7q6WpxfXl4W501c88GDB8X5yclJ7XM09bL4J0+eFOfff/997WOXlpaK86Ojo9rrePjwYXH+4sWL2ufo18bGRs9sb2+veGzVZ5yc7P31uLi4KB47Ozvbx+qylZWV4ryfP7c27e/v98zW1tZau97u7m5xvrm5OfC5nz59+p/HuJMECEQSIBBJgEAkAYLaGzfcPFX/sF3auKraADg9PS3O+9m4ef36dXHe5sZNaYOlauPmjz/+KM7v3r3bM6va9Ds+Pu5jddnh4WFxfv/+/drHtunjjz8e6jqqNsWauKaNG4ABiSRAIJIAgUgCBDZubrGqjZvSvOqJqtLTUBMTExMzMzO11/Ho0aPifG5urvY5+lV64ubOnfI9QT9P3FRt3HB7uZMECEQSIBBJgEAkAQKRBAjsbt9if/31V+15P8emeVvn6FcTn7G0u311dVU8dn5+vo/VZVXvQS29s7FqPW0qre/s7Ky161W9q/L9+/etXfPf3EkCBCIJEIgkQCCSAIFIAgS1d7dL35DWlNLzwQsLCwOft99vS2zimlXflvjFF18MfO5+NfHy1uvr6+K8n5fuVv2823zp7r1793pmVZ+xaod4VC/dPTg4KM5Lu+1Vx7bpww8/HOo6Sn9fm7pmna65kwQIRBIgEEmAQCQBgtobN1WPBjWh9G16z549G/i8VY9KVf3jexPX/PXXX4vzk5OT2ueo2nAChs+dJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQ1H7pLjAc3333XXG+srLSM7u4uGh7OT1K6zs9PR3q9SYmJibevXvX2jX/zZ0kQCCSAIFIAgQiCRCIJEBgdxvGzE8//VScT01N1T62TfPz80Ndx/T0dHHexDV//PHH/zzGnSRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEnt2GMfPo0aPa8/Pz87aXU2sdv//+e2vX+/bbb4vzP//8s7Vr/ps7SYBAJAECkQQIRBIgsHEDY+b58+fF+ezsbO1j2/TZZ58NdR0fffRRcd7ENY+Pj//zGHeSAIFIAgQiCRCIJEAgkgCB3e1bbHNzs/a80+kUj11aWirOZ2Zmaq+j6jG7ubm52ufo18bGRs/szp3yPUHVZ5yc7P31uLy8HGxh3DjuJAECkQQIRBIgEEmAQCQBArvbt9ju7m5x3u12e2ZPnjwpHvv27dvi/OjoqPY63rx5U5y/ePGi9jn6dX193TPb29srHnt2dlac3717t2dWtbtd5xlgbiZ3kgCBSAIEIgkQiCRAUHvj5ptvvmltEaVzN/H418OHD4vzi4uL4ryJay4uLhbnVY/EAePNby5AIJIAgUgCBCIJEIgkQNDplp5RA2BiYsKdJEAkkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQiCRAUPt7tzudTnG+s7PTM9va2uprEU2cY2lpqWd2dHTU1znGRVMvi6/6MxtnVX/upb8jTXnw4EFxfnJyUvscTb7gv+qzbm9vD3zu0s933H+2barz5+ZOEiAQSYBAJAECkQQIRBIgEEmAQCQBApEECEQSIBBJgKD2Y4nAcKyvrxfnl5eXA597Y2OjZ3Z1dTXweassLi4W53Nzc61ds2nuJAECkQQIRBIgEEmAQCQBArvbt9jjx49HvYS+raysFOfv3r1r7ZpVL4a9f/9+a9dMfvjhh77m/Zic7P2Vb+K8Vd68eVOcj8tLd/f29v7zGHeSAIFIAgQiCRCIJEAgkgCB3e1b7PDwcNRL6Nv09HRx3uZnef/+fXHezw7s06dPm1oOY8adJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQeOnuLba9vT3qJfRta2urOO90Oq1ds+orZT/55JPWrsnN4U4SIBBJgEAkAQKRBAhs3NxiZ2dno15C387Pz4vzNj/LKK6ZVG1edbvdgc897M28xcXF4nxhYWGo6xiEO0mAQCQBApEECEQSIBBJgMDu9i32zz//jHoJfatac5uf5erqaujXTHZ2dvqat3W9Jrx+/bo4Pzk5ae2a/aiz2+9OEiAQSYBAJAECkQQIRBIgsLt9i7148WLUS+jb/Px8cd7mZ6naxe5nB/b58+dNLYcx404SIBBJgEAkAQKRBAhEEiCovbv9+PHj4nxlZaVnVvWm5ypNnGN5eblndueO/wYAg1ERgEAkAQKRBAhEEiCovXFzeHhYnM/OztY+tkoT57i+vu6ZHR0d9XWOcfH06dNGzvPll182cp5h+vrrr4vz3377rbVrfvXVV8V51ct4+f/iThIgEEmAQCQBApEECEQSIKi9u725uVl73u9XcTZxjqWlpZ7ZzMxMX+e4bX755ZdRL6Fvn3/+eXHe5meZnCz/GvRzzZ9//rmp5TBm3EkCBCIJEIgkQCCSAIFIAgS1d7d3d3eL8w8++KD2sVWaOMfbt297Zjf12e2dnZ1GzrO6utrIeYZpfX29OL+4uGjtmouLi8V56Z0C/P9xJwkQiCRAIJIAgUgCBLU3brh5Dg4ORr2Evk1NTRXnbX6Wqm/mPDk5qX2O/f39ppbDmHEnCRCIJEAgkgCBSAIEIgkQdLrdbnfUiwAYV+4kAQKRBAhEEiAQSYBAJAECkQQIRBIgEEmAQCQBgv8BRSl4vmIKgj4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figure = plt.figure(figsize=(4,4))\n", + "cols, rows = 3, 3\n", + "for i in range(1, cols * rows + 1):\n", + " sample_idx = torch.randint(len(data), size=(1,)).item()\n", + " img = (data[sample_idx])\n", + " figure.add_subplot(rows, cols, i)\n", + " plt.axis(\"off\")\n", + " plt.imshow(img.squeeze().reshape([patch_size,patch_size]), cmap=\"gray\")\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ba432bec", + "metadata": {}, + "source": [ + "## Reconstruct the data with CEL0 Algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "cb7fb67e", + "metadata": {}, + "outputs": [], + "source": [ + "# The CEL0 Algorithm\n", + "import time\n", + "\n", + "n_iter = 300\n", + "start = time.time()\n", + "cel0 = inference.CEL0(coeff_lr=1e-1,threshold=0.1,n_iter=n_iter)\n", + "A = cel0.infer(data, dictionary)\n", + "\n", + "reconstruction = (dictionary@A.t()).t()\n", + "\n", + "end = time.time()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "19040166", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running time of 300 iterations: 0.038336992263793945\n" + ] + } + ], + "source": [ + "print(\"Running time of\", n_iter, \"iterations:\", end - start)\n", + "# error = data - reconstruction\n", + "# print(\"Error\", error.norm(dim=0).mean())" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "db67eacc", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA7wAAAEnCAYAAACKfU+eAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABFmklEQVR4nO3de1yUZf7/8ffIYfCAGCqnQiUPecA0MQsPqbliqKwdLDuYaNpKaK3RQdEKtFbKNmPTwFwr6mea26amrVl881Sp67mDupWFQgWSWmCWiHD//nCZbeQ4MMMMw+v5eNyPmuu+7rk/czXzaT5c932NyTAMQwAAAAAAuJkmzg4AAAAAAABHoOAFAAAAALglCl4AAAAAgFui4AUAAAAAuCUKXgAAAACAW6LgBQAAAAC4JQpeAAAAAIBbouAFAAAAALglCl4AAAAAgFui4HVRGRkZMplM2rNnj7NDcSiTyWS1+fn5aciQIfrXv/5l83Nt375dycnJ+vnnn2sVy5AhQxQeHl6rYwHUXVneK9s8PT0VHBys22+/XV9//bWzw7OrtLQ0ZWRkODWGFStWKDU1tcJ9JpNJycnJ9RoPAMf57LPPNGnSJIWFhcnHx0ctWrRQnz59tGDBAp06dcph5121apV69Oihpk2bymQy6cCBA5KkRYsWqVOnTvL29pbJZNLPP/+siRMnqkOHDjafY8iQIRoyZIhd477YoUOHlJycrKNHjzr0PHAMCl443dixY7Vjxw598sknevHFF5WXl6eYmBibi97t27dr7ty5tS54AbiGV199VTt27ND//d//afr06Vq3bp0GDhyon376ydmh2Y2rF7w7duzQlClT6jcgAA7x97//XREREdq9e7ceeeQRbdy4UWvWrNGtt96qJUuWaPLkyQ45748//qi7775bHTt21MaNG7Vjxw516dJFBw4c0AMPPKChQ4dq06ZN2rFjh3x9ffX4449rzZo1Np8nLS1NaWlpDngF/3Po0CHNnTuXgreB8nR2AEBgYKCuvfZaSVL//v0VGRmpTp06KTU1VaNGjXJydADqW3h4uPr27Svpwl/uS0pKlJSUpLVr12rSpElOjq7+FRcXW2a860tZTgbQsO3YsUP33Xefhg8frrVr18psNlv2DR8+XA899JA2btzokHN/9dVXKi4u1vjx4zV48GBL+8GDByVJ9957r/r162dp79ixY63O071797oFCrfHDG8DMnHiRLVo0UL/+c9/NGLECDVv3lzBwcF6+umnJUk7d+7UwIED1bx5c3Xp0kWvvfaa1fE//vij4uPj1b17d7Vo0UIBAQG6/vrr9dFHH5U713fffaexY8fK19dXrVq10l133aXdu3fLZDKVm5XYs2eP/vjHP8rf318+Pj666qqr9I9//KPWr7Njx45q27atjh07JknKzMzUmDFjdNlll8nHx0edOnXS1KlTdeLECcsxycnJeuSRRyRJYWFhlksit2zZYumzYsUKRUZGqkWLFmrRooV69+6tl19+udz5d+/erUGDBqlZs2a6/PLL9fTTT6u0tLTWrwdA3ZQVv8ePH7e01TTvfP/99/rTn/6k0NBQeXt7KyQkRGPHjrV6ruzsbI0fP14BAQEym83q1q2bnnvuOavP/dGjR2UymfTXv/5VCxcuVFhYmFq0aKHIyEjt3LnT6pzffvutbr/9doWEhMhsNiswMFDDhg2zXMrXoUMHHTx4UFu3brXkqrLL+LZs2SKTyaT/9//+nx566CFdeumlMpvNOnLkiJKTk2Uymcq9xrJLwS+eeagq55XdOnLs2DGry8jLVHRJ8xdffKExY8bokksukY+Pj3r37l3u/zNl8a9cuVJz5sxRSEiIWrZsqT/84Q/68ssvy8UOwLHmz58vk8mkpUuXWhW7Zby9vfXHP/5RklRaWqoFCxaoa9euMpvNCggI0IQJE/Tdd9+VO+7//u//NGzYMLVs2VLNmjXTgAED9OGHH1r2T5w4UQMHDpQkjRs3TiaTyXLp8fjx4yVJ11xzjUwmkyZOnGg55uJLmktLS7Vo0SL17t1bTZs2VatWrXTttddq3bp1lj4VXdJ87tw5PfXUU5bX0rZtW02aNEk//vijVb8OHTpo9OjR2rhxo/r06aOmTZuqa9eueuWVVyx9MjIydOutt0qShg4dasmXZd+H9+/fr9GjR1v+HxISEqJRo0ZVOG5wDmZ4G5ji4mLdfPPNiouL0yOPPKIVK1YoMTFRhYWFevvttzVz5kxddtllWrRokSZOnKjw8HBFRERIkuUejaSkJAUFBemXX37RmjVrNGTIEH344YeWZHHmzBkNHTpUp06d0jPPPKNOnTpp48aNGjduXLl4Nm/erBtuuEHXXHONlixZIj8/P7355psaN26cfv31V0sSs8VPP/2kkydPqnPnzpKkb775RpGRkZoyZYr8/Px09OhRLVy4UAMHDtTnn38uLy8vTZkyRadOndKiRYu0evVqBQcHS/rfX/2eeOIJPfnkk7r55pv10EMPyc/PT1988YWlqC6Tl5enu+66Sw899JCSkpK0Zs0aJSYmKiQkRBMmTLD5tQCou6ysLElSly5dJNU873z//fe6+uqrVVxcrNmzZ+vKK6/UyZMn9f777+unn35SYGCgfvzxR/Xv31/nzp3Tk08+qQ4dOujdd9/Vww8/rG+++abcZXIvvviiunbtarkU+PHHH9fIkSOVlZUlPz8/SdLIkSNVUlKiBQsWqF27djpx4oS2b99uud1izZo1Gjt2rPz8/CzPf/EX0cTEREVGRmrJkiVq0qSJAgICbBqz6nJeWlqa/vSnP+mbb76p0SWEX375pfr376+AgAC98MILat26tZYvX66JEyfq+PHjevTRR636z549WwMGDNCyZctUWFiomTNnKiYmRocPH5aHh4dNrwVA7ZSUlGjTpk2KiIhQaGhotf3vu+8+LV26VNOnT9fo0aN19OhRPf7449qyZYv27dunNm3aSJKWL1+uCRMmaMyYMXrttdfk5eWll156SSNGjND777+vYcOG6fHHH1e/fv00bdo0zZ8/X0OHDlXLli0lSStXrtRTTz2lV199VV27dlXbtm0rjWnixIlavny5Jk+erHnz5snb21v79u2r8tLi0tJSjRkzRh999JEeffRR9e/fX8eOHVNSUpKGDBmiPXv2qGnTppb+n376qR566CHNmjVLgYGBWrZsmSZPnqxOnTrpuuuu06hRozR//nzNnj1bL774ovr06SPpwgTNmTNnNHz4cIWFhenFF19UYGCg8vLytHnzZp0+fbom/5lQHwy4pFdffdWQZOzevdvSFhsba0gy3n77bUtbcXGx0bZtW0OSsW/fPkv7yZMnDQ8PDyMhIaHSc5w/f94oLi42hg0bZtx0002W9hdffNGQZLz33ntW/adOnWpIMl599VVLW9euXY2rrrrKKC4utuo7evRoIzg42CgpKanydUoy4uPjjeLiYuPcuXPG4cOHjejoaEOS8eKLL5brX1paahQXFxvHjh0zJBnvvPOOZd+zzz5rSDKysrKsjvn2228NDw8P46677qoylsGDBxuSjH//+99W7d27dzdGjBhR5bEA6q4s7+3cudMoLi42Tp8+bWzcuNEICgoyrrvuOkueqWneueeeewwvLy/j0KFDlZ5z1qxZFX7u77vvPsNkMhlffvmlYRiGkZWVZUgyevbsaZw/f97Sb9euXYYkY+XKlYZhGMaJEycMSUZqamqVr7VHjx7G4MGDy7Vv3rzZkGRcd9115fYlJSUZFf1vu2zcynJfTXPeqFGjjPbt21e4T5KRlJRkeXz77bcbZrPZyM7OtuoXHR1tNGvWzPj555+t4h85cqRVv3/84x+GJGPHjh1VxgTAfvLy8gxJxu23315t38OHD1u+k/3ev//9b0OSMXv2bMMwDOPMmTOGv7+/ERMTY9WvpKTE6NWrl9GvXz9LW1k+eOutt6z6VvQd1zAufM/9fU7atm2bIcmYM2dOlbEPHjzYKp+uXLmy3PdlwzCM3bt3G5KMtLQ0S1v79u0NHx8f49ixY5a23377zfD39zemTp1qaXvrrbcMScbmzZutnnPPnj2GJGPt2rVVxgjn4pLmBsZkMmnkyJGWx56enurUqZOCg4N11VVXWdr9/f0VEBBQbgZzyZIl6tOnj3x8fOTp6SkvLy99+OGHOnz4sKXP1q1b5evrqxtuuMHq2DvuuMPq8ZEjR/Sf//xHd911lyTp/Pnzlm3kyJHKzc2t0SVsaWlp8vLykre3t7p166bt27dr3rx5io+PlyTl5+crLi5OoaGhlpjbt28vSVZxVyYzM1MlJSWaNm1atX2DgoKs7ieRpCuvvLLcOAJwnGuvvVZeXl6WPHTJJZfonXfekaenp01557333tPQoUPVrVu3Ss+1adMmde/evdznfuLEiTIMQ5s2bbJqHzVqlNUM5ZVXXilJlhzh7++vjh076tlnn9XChQu1f//+Wt0Sccstt9h8TBlbcl5Nbdq0ScOGDSs3SzRx4kT9+uuv2rFjh1V72SWSZS4eJwCuZfPmzZJU7sq8fv36qVu3bpbLlbdv365Tp04pNjbWKv+Wlpbqhhtu0O7du3XmzBm7xPTee+9Jks257N1331WrVq0UExNjFWPv3r0VFBRkdbubJPXu3Vvt2rWzPPbx8VGXLl1qlK86deqkSy65RDNnztSSJUt06NAhm2JF/aDgbWCaNWsmHx8fqzZvb2/5+/uX6+vt7a2zZ89aHi9cuFD33XefrrnmGr399tvauXOndu/erRtuuEG//fabpd/JkycVGBhY7vkubiu7B+7hhx+Wl5eX1VZWrP7+PtvK3Hbbbdq9e7f27NmjL7/8UidPntTjjz8u6cJlKVFRUVq9erUeffRRffjhh9q1a5flnrnfx12Zsvs1Lrvssmr7tm7dulyb2Wyu0XkA2Mfrr7+u3bt3a9OmTZo6daoOHz5s+YObLXnnxx9/rPZzf/LkScstEL8XEhJi2f97F+eIskuRy3KEyWTShx9+qBEjRmjBggXq06eP2rZtqwceeMCmy9sqiqmmbMl5NWXvcQLgeG3atFGzZs0st4VUpewzXNnnvGx/WQ4eO3ZsuRz8zDPPyDAMu/3M0Y8//igPDw8FBQXZdNzx48f1888/y9vbu1yMeXl55b6b1uW7n5+fn7Zu3arevXtr9uzZ6tGjh0JCQpSUlKTi4mKb4objcA9vI7J8+XINGTJE6enpVu0Xfwlr3bq1du3aVe74vLw8q8dl93IkJibq5ptvrvCcV1xxRbVxtW3b1rIozcW++OILffrpp8rIyFBsbKyl/ciRI9U+7++fX7qwEFdN7mEB4FzdunWz5IShQ4eqpKREy5Yt0z//+U/17NlTUs3yTtu2batdNKR169bKzc0t1/7DDz9I+l+es0X79u0ti0N99dVX+sc//qHk5GSdO3dOS5YsqdFzVLQ4VdkfO4uKiqzu+b34y5sjcp4jxgmAY3l4eGjYsGF677339N1331X5R7Cyoi83N7dcvx9++MHyGS/756JFiypdzb2iSZPaaNu2rUpKSpSXl2fTHwHbtGmj1q1bV7r6tK+vr13iK9OzZ0+9+eabMgxDn332mTIyMjRv3jw1bdpUs2bNsuu5UDvM8DYiJpOp3MIon332WblL0QYPHqzTp09bLiUp8+abb1o9vuKKK9S5c2d9+umn6tu3b4VbXZNK2Ze+i+N+6aWXyvWtbAYhKipKHh4e5Qp9AA3DggULdMkll+iJJ55Q586da5x3oqOjtXnz5ipvrRg2bJgOHTqkffv2WbW//vrrMplMGjp0aJ1i79Klix577DH17NnT6hy1uXKkbPXSzz77zKp9/fr1Vo9rmvNsiWHYsGHatGmTpcAt8/rrr6tZs2b8jBHgohITE2UYhu69916dO3eu3P7i4mKtX79e119/vaQLkyO/t3v3bh0+fFjDhg2TJA0YMECtWrXSoUOHKs3B3t7edok9Ojpakmz+/jZ69GidPHlSJSUlFcZXk8mYi9XkKhWTyaRevXrp+eefV6tWrcr9fwXOwwxvIzJ69Gg9+eSTSkpK0uDBg/Xll19q3rx5CgsL0/nz5y39YmNj9fzzz2v8+PF66qmn1KlTJ7333nt6//33JUlNmvzv7yQvvfSSoqOjNWLECE2cOFGXXnqpTp06pcOHD2vfvn1666236hRz165d1bFjR82aNUuGYcjf31/r169XZmZmub5lMz9/+9vfFBsbKy8vL11xxRXq0KGDZs+erSeffFK//fab7rjjDvn5+enQoUM6ceKE5s6dW6cYATjWJZdcosTERD366KNasWJFjfPOvHnz9N577+m6667T7Nmz1bNnT/3888/auHGjEhIS1LVrVz344IN6/fXXNWrUKM2bN0/t27fXv/71L6Wlpem+++6zrAxdU5999pmmT5+uW2+9VZ07d5a3t7c2bdqkzz77zOov/WUzAqtWrdLll18uHx8fSw6rzMiRI+Xv729ZrdTT01MZGRnKycmx6lfTnNezZ0+tXr1a6enpioiIUJMmTSq92iYpKUnvvvuuhg4dqieeeEL+/v5644039K9//UsLFiywrFANwLVERkYqPT1d8fHxioiI0H333acePXqouLhY+/fv19KlSxUeHq41a9boT3/6kxYtWqQmTZooOjraskpzaGioHnzwQUlSixYttGjRIsXGxurUqVMaO3asAgIC9OOPP+rTTz/Vjz/+aLcJhkGDBunuu+/WU089pePHj2v06NEym83av3+/mjVrpvvvv7/C426//Xa98cYbGjlypP785z+rX79+8vLy0nfffafNmzdrzJgxuummm2yKJTw8XJK0dOlS+fr6ysfHR2FhYdqxY4fS0tJ044036vLLL5dhGFq9erV+/vlnDR8+vM5jADtx5opZqFxlqzQ3b968XN/BgwcbPXr0KNfevn17Y9SoUZbHRUVFxsMPP2xceumlho+Pj9GnTx9j7dq15VbFMwzDyM7ONm6++WajRYsWhq+vr3HLLbcYGzZsKLcysmEYxqeffmrcdtttRkBAgOHl5WUEBQUZ119/vbFkyZJqX6ckY9q0aVX2OXTokDF8+HDD19fXuOSSS4xbb73VyM7OLreKqGEYRmJiohESEmI0adKk3Gp6r7/+unH11VcbPj4+RosWLYyrrrrKasXpysaxovEBYH+VrdxpGBdWzWzXrp3RuXNn4/z58zXOOzk5OcY999xjBAUFGV5eXkZISIhx2223GcePH7f0OXbsmHHnnXcarVu3Nry8vIwrrrjCePbZZ61WmS9bpfnZZ58tF9vvc9Hx48eNiRMnGl27djWaN29utGjRwrjyyiuN559/3mp156NHjxpRUVGGr6+vIcmSYypb1bTMrl27jP79+xvNmzc3Lr30UiMpKclYtmxZhSvUV5fzTp06ZYwdO9Zo1aqVYTKZrFaArii/fv7550ZMTIzh5+dneHt7G7169bJ6vqriLxu/i/sDqB8HDhwwYmNjjXbt2hne3t5G8+bNjauuusp44oknjPz8fMMwLqy0/MwzzxhdunQxvLy8jDZt2hjjx483cnJyyj3f1q1bjVGjRhn+/v6Gl5eXcemllxqjRo2y+uzXdZXmspief/55Izw83PD29jb8/PyMyMhIY/369ZY+F6/SbBgXfsXkr3/9q9GrVy9LDuzatasxdepU4+uvv7b0u/i7clXPmZqaaoSFhRkeHh6WfPaf//zHuOOOO4yOHTsaTZs2Nfz8/Ix+/foZGRkZ5f8jwGlMhmEY9V1ko2GaP3++HnvsMWVnZ9t1MRQAAAAAcAQuaUaFFi9eLOnCJcXFxcXatGmTXnjhBY0fP55iFwAAAECDQMGLCjVr1kzPP/+8jh49qqKiIrVr104zZ87UY4895uzQAAAAAKBGuKQZAAAAAOCW+FkiAAAAAIBbouAFAAAAALglCl4AAIBGatu2bYqJiVFISIhMJpPWrl1b7TFbt25VRESEfHx8dPnll2vJkiWODxQAasnlFq0qLS3VDz/8IF9fX5lMJmeHA+C/DMPQ6dOnFRISoiZN+FuZs5AjAdfUUHPkmTNn1KtXL02aNEm33HJLtf2zsrI0cuRI3XvvvVq+fLk++eQTxcfHq23btjU6vgy5DEBd2JJzXW7Rqu+++06hoaHODgNAJXJycvhpKiciRwKurSHnSJPJpDVr1ujGG2+stM/MmTO1bt06HT582NIWFxenTz/9VDt27KjxuchlAOyhJjnX5WZ4fX19JUkDNVKe8rLp2G+fubrW57185u5aH1tbtY3XGbFmP9avVse1e2qXnSOBs5xXsT7WBstnFM5BjqweORLO0Fhy5I4dOxQVFWXVNmLECL388ssqLi6Wl1fFeamoqEhFRUWWx2XzLTk5OWrZsqXjAgbglgoLCxUaGlqjnOtyBW/ZZS2e8pKnybYvc018fGp9XlvPZQ+1jZdY4RT/vRaES8+cixxZPWKFUzSSHJmXl6fAwECrtsDAQJ0/f14nTpxQcHBwhcelpKRo7ty55dpbtmxJwQug1mqScx12k0laWprCwsLk4+OjiIgIffTRR446FQA0KORHAA3ZxV8wy2Zrq/rimZiYqIKCAsuWk5Pj0BgBoIxDCt5Vq1ZpxowZmjNnjvbv369BgwYpOjpa2dnZjjgdADQY5EcADVlQUJDy8vKs2vLz8+Xp6anWrVtXepzZbLbM5jKrC6A+OaTgXbhwoSZPnqwpU6aoW7duSk1NVWhoqNLT0x1xOgBoMMiPABqyyMhIZWZmWrV98MEH6tu3b6X37wKAM9m94D137pz27t1bbkGDqKgobd++vVz/oqIiFRYWWm0A4I5szY8SORKAY/3yyy86cOCADhw4IOnCzw4dOHDActVJYmKiJkyYYOkfFxenY8eOKSEhQYcPH9Yrr7yil19+WQ8//LAzwgdUUmpoxzcn9c6B77Xjm5MqKbX/D9DUxzngOHZftOrEiRMqKSmpcEGDiy+BkSpfxAAA3I2t+VEiRwJwrD179mjo0KGWxwkJCZKk2NhYZWRkKDc31+qWi7CwMG3YsEEPPvigXnzxRYWEhOiFF16w6Td4AXvZ+EWu5q4/pNyCs5a2YD8fJcV01w3hFS+g5orngGM5bJXmihY0qGgxg8TEREtylf63xDQAuKua5keJHAnAsYYMGWJZdKoiGRkZ5doGDx6sffv2OTAqoHobv8jVfcv36eJ3b17BWd23fJ/Sx/epc0FaH+eA49m94G3Tpo08PDwqXNDg4lkN6cIiBmaz2d5hAIDLsTU/SuRIAAAuVlJqaO76Q+UKUenCL4SZJM1df0jDuwfJo0ntfiqsPs6B+mH3e3i9vb0VERFRbkGDzMxM9e/f396nA4AGg/wIAEDd7co6ZXWJ8cUMSbkFZ7Ur65RLnwP1wyGXNCckJOjuu+9W3759FRkZqaVLlyo7O1txcXGOOB0ANBjkRwAA6ib/dOWFaG36OescqB8OKXjHjRunkydPat68ecrNzVV4eLg2bNig9u3bO+J0ANBgkB8BAO6kpNTQrqxTyj99VgG+PuoX5u/wS3wDfH3s2s9Z50D9cNiiVfHx8YqPj3fU0wNAg0V+BAC4A2etYNwvzF/Bfj7KKzhb4T22JklBfheKb1c+B+qH3e/hBQAAAODeylYwvvg+17IVjDd+keuwc3s0MSkpprukC4Xn75U9TorpXqeZ5vo4B+oHBS8AAACAGqtuBWPpwgrGJaWV/+RVXd0QHqz08X0U0NL6lwyC/Hzs9nNB9XEOOJ7DLmkGAAAA4H5sWcE4smNrh8VxQ3iwBnRqo57JH0iSMiZdrUGd29p11rU+zgHHYoYXAAAAQI250grGvy88HbVgVn2cA45DwQsAAACgxljBGA0JBS8AAACAGitbwbiyeU6TLqzWzArGcAUUvAAAAABqjBWM0ZBQ8AIAAACwCSsYo6FglWYAAAAANmMFYzQEzPACAAAAqBVWMIaro+AFAAAAALglLmkGAAAAGqmSUkO7sk4p//RZBfj6MEsLt0PBCwAAADRCG7/I1dz1h5RbcNbSFuzno6SY7iw6BbfBJc0AAABAI7Pxi1zdt3yfVbErSXkFZ3Xf8n3a+EWukyID7IuCFwAAAGhESkoNzV1/SEYF+8ra5q4/pJLSinoADQsFLwAAANCI7Mo6VW5m9/cMSbkFZ7Ur61T9BdWIlZQa2vHNSb1z4Hvt+OYkf2iwM+7hBQAAABqR/NOVF7u16Yfa4z5qx2OGFwAAAGhEAnx97NoPtcN91PXD7gVvSkqKrr76avn6+iogIEA33nijvvzyS3ufBgAaHPIjAMAV9AvzV7Cfjyr78SGTLswy9gvzr8+wGhXuo64/di94t27dqmnTpmnnzp3KzMzU+fPnFRUVpTNnztj7VADQoJAfAQCuwKOJSUkx3SWpXNFb9jgppju/x+tA3Eddf+x+D+/GjRutHr/66qsKCAjQ3r17dd1119n7dADQYJAfAQCu4obwYKWP76OkdQd1vLDI0h7E/aP1gvuo64/DF60qKCiQJPn7V3xJRFFRkYqK/vchKywsdHRIAOASqsuPEjkSAOA4N4QHa0CnNuqZ/IEkKWPS1RrUuS0zu/WA+6jrj0MXrTIMQwkJCRo4cKDCw8Mr7JOSkiI/Pz/LFhoa6siQAMAl1CQ/SuRIAIBj/b647RfmT7FbT7iPuv44tOCdPn26PvvsM61cubLSPomJiSooKLBsOTk5jgwJAFxCTfKjRI4EAMAdcR91/XHYJc3333+/1q1bp23btumyyy6rtJ/ZbJbZbHZUGADgcmqaHyVyJAAA7or7qOuH3QtewzB0//33a82aNdqyZYvCwsLsfQoAaJDIjwAA4Pe4j9rx7F7wTps2TStWrNA777wjX19f5eXlSZL8/PzUtGlTe58OABoM8iMAALgY91E7lt3v4U1PT1dBQYGGDBmi4OBgy7Zq1Sp7nwoAGhTyIwAAQP2ye8FrGEaF28SJE+19KgBoUMiPAFxVWlqawsLC5OPjo4iICH300UdV9n/jjTfUq1cvNWvWTMHBwZo0aZJOnjxZT9ECQM05dJVmAAAAuLZVq1ZpxowZmjNnjvbv369BgwYpOjpa2dnZFfb/+OOPNWHCBE2ePFkHDx7UW2+9pd27d2vKlCn1HDkAVI+CFwAAoBFbuHChJk+erClTpqhbt25KTU1VaGio0tPTK+y/c+dOdejQQQ888IDCwsI0cOBATZ06VXv27KnnyAGgehS8AAAAjdS5c+e0d+9eRUVFWbVHRUVp+/btFR7Tv39/fffdd9qwYYMMw9Dx48f1z3/+U6NGjar0PEVFRSosLLTaAKA+UPACAAA0UidOnFBJSYkCAwOt2gMDAy0ryV+sf//+euONNzRu3Dh5e3srKChIrVq10qJFiyo9T0pKivz8/CxbaGioXV8HAFSGghcAAKCRM5msfwbFMIxybWUOHTqkBx54QE888YT27t2rjRs3KisrS3FxcZU+f2JiogoKCixbTk6OXeMHgMrY/Xd4AQAA0DC0adNGHh4e5WZz8/Pzy836lklJSdGAAQP0yCOPSJKuvPJKNW/eXIMGDdJTTz2l4ODgcseYzWaZzWb7vwAAqAYzvAAAAI2Ut7e3IiIilJmZadWemZmp/v37V3jMr7/+qiZNrL9Cenh4SLowMwwAroSCFwAAoBFLSEjQsmXL9Morr+jw4cN68MEHlZ2dbblEOTExURMmTLD0j4mJ0erVq5Wenq5vv/1Wn3zyiR544AH169dPISEhznoZAFAhLmkGAABoxMaNG6eTJ09q3rx5ys3NVXh4uDZs2KD27dtLknJzc61+k3fixIk6ffq0Fi9erIceekitWrXS9ddfr2eeecZZLwEAKkXBCwAA0MjFx8crPj6+wn0ZGRnl2u6//37df//9Do4KAOqOS5oBAAAAAG6JghcAAAAA4JYoeAEAAAAAbomCFwAAAADglih4AQAAAABuiYIXAAAAAOCWKHgBAAAAAG6JghcAAAAA4JYcXvCmpKTIZDJpxowZjj4VADQo5EcAAADHcmjBu3v3bi1dulRXXnmlI08DAA0O+REAAMDxHFbw/vLLL7rrrrv097//XZdccomjTgMADQ75EQAAoH44rOCdNm2aRo0apT/84Q9V9isqKlJhYaHVBgDurKb5USJHAgAA1IWnI570zTff1L59+7R79+5q+6akpGju3LmOCAMAXI4t+VEiRwIAANSF3Wd4c3Jy9Oc//1nLly+Xj49Ptf0TExNVUFBg2XJycuwdEgC4BFvzo0SOBAAAqAu7z/Du3btX+fn5ioiIsLSVlJRo27ZtWrx4sYqKiuTh4WHZZzabZTab7R0GALgcW/OjRI4EAACoC7sXvMOGDdPnn39u1TZp0iR17dpVM2fOLPdlDgAaC/IjAABA/bJ7wevr66vw8HCrtubNm6t169bl2gGgMSE/AgAA1C+H/g4vAAAAAADO4pBVmi+2ZcuW+jgNADQ45EcAAADHYYYXAAAAAOCWKHgBAAAAAG6JghcAAAAA4JYoeAEAAAAAbomCFwAAAADglih4AQAAAABuiYIXAAAAAOCWKHgBAAAAAG6JghcAAAAA4JYoeAEAAAAAbomCFwAAAADglih4AQAAAABuiYIXAAAAAOCWKHgBAAAaubS0NIWFhcnHx0cRERH66KOPquxfVFSkOXPmqH379jKbzerYsaNeeeWVeooWAGrO09kBAAAAwHlWrVqlGTNmKC0tTQMGDNBLL72k6OhoHTp0SO3atavwmNtuu03Hjx/Xyy+/rE6dOik/P1/nz5+v58gBoHoUvAAAAI3YwoULNXnyZE2ZMkWSlJqaqvfff1/p6elKSUkp13/jxo3aunWrvv32W/n7+0uSOnToUJ8hA0CNcUkzAABAI3Xu3Dnt3btXUVFRVu1RUVHavn17hcesW7dOffv21YIFC3TppZeqS5cuevjhh/Xbb7/VR8gAYBNmeAEAABqpEydOqKSkRIGBgVbtgYGBysvLq/CYb7/9Vh9//LF8fHy0Zs0anThxQvHx8Tp16lSl9/EWFRWpqKjI8riwsNB+LwIAquCQGd7vv/9e48ePV+vWrdWsWTP17t1be/fudcSpAKBBIT8CcEUmk8nqsWEY5drKlJaWymQy6Y033lC/fv00cuRILVy4UBkZGZXO8qakpMjPz8+yhYaG2v01AEBF7F7w/vTTTxowYIC8vLz03nvv6dChQ3ruuefUqlUre58KABoU8iMAV9OmTRt5eHiUm83Nz88vN+tbJjg4WJdeeqn8/Pwsbd26dZNhGPruu+8qPCYxMVEFBQWWLScnx34vAgCqYPdLmp955hmFhobq1VdftbSxkAEAkB8BuB5vb29FREQoMzNTN910k6U9MzNTY8aMqfCYAQMG6K233tIvv/yiFi1aSJK++uorNWnSRJdddlmFx5jNZpnNZvu/AACoht1neMsWMrj11lsVEBCgq666Sn//+98r7V9UVKTCwkKrDQDcka35USJHAnC8hIQELVu2TK+88ooOHz6sBx98UNnZ2YqLi5N0YXZ2woQJlv533nmnWrdurUmTJunQoUPatm2bHnnkEd1zzz1q2rSps14GAFTI7gXvt99+q/T0dHXu3Fnvv/++4uLi9MADD+j111+vsD/3dABoLGzNjxI5EoDjjRs3TqmpqZo3b5569+6tbdu2acOGDWrfvr0kKTc3V9nZ2Zb+LVq0UGZmpn7++Wf17dtXd911l2JiYvTCCy846yUAQKXsfklzaWmp+vbtq/nz50uSrrrqKh08eFDp6elWfx0sk5iYqISEBMvjwsJCvtABcEu25keJHAmgfsTHxys+Pr7CfRkZGeXaunbtqszMTAdHBQB1Z/cZ3uDgYHXv3t2qrVu3blZ/Gfw9s9msli1bWm0A4I5szY8SORIAAKAu7F7wDhgwQF9++aVV21dffWW5LAYAGivyIwAAQP2ye8H74IMPaufOnZo/f76OHDmiFStWaOnSpZo2bZq9TwUADQr5EQAAoH7ZveC9+uqrtWbNGq1cuVLh4eF68sknlZqaqrvuusvepwKABoX8CAAAUL/svmiVJI0ePVqjR492xFMDQINGfgQAAKg/dp/hBQAAAADAFVDwAgAAAADcEgUvAAAAAMAtUfACAAAAANwSBS8AAAAAwC1R8AIAAAAA3BIFLwAAAADALVHwAgAAAADcEgUvAAAAAMAtUfACAAAAANwSBS8AAAAAwC1R8AIAAAAA3BIFLwAAAADALVHwAgAAAADcEgUvAAAAAMAtUfACAAAAANwSBS8AAAAAwC1R8AIAAAAA3JLdC97z58/rscceU1hYmJo2barLL79c8+bNU2lpqb1PBQANCvkRAACgfnna+wmfeeYZLVmyRK+99pp69OihPXv2aNKkSfLz89Of//xne58OABoM8iMAAED9snvBu2PHDo0ZM0ajRo2SJHXo0EErV67Unj177H0qAGhQyI8AAAD1y+6XNA8cOFAffvihvvrqK0nSp59+qo8//lgjR46ssH9RUZEKCwutNgBwR7bmR4kcCQAAUBd2n+GdOXOmCgoK1LVrV3l4eKikpER/+ctfdMcdd1TYPyUlRXPnzrV3GADgcmzNjxI5EgAAoC7sPsO7atUqLV++XCtWrNC+ffv02muv6a9//atee+21CvsnJiaqoKDAsuXk5Ng7JABwCbbmR4kcCQAAUBd2n+F95JFHNGvWLN1+++2SpJ49e+rYsWNKSUlRbGxsuf5ms1lms9neYQCAy7E1P0rkSAAAgLqw+wzvr7/+qiZNrJ/Ww8ODn90A0OiRHwEAAOqX3Wd4Y2Ji9Je//EXt2rVTjx49tH//fi1cuFD33HOPvU8FAA0K+REAAKB+2b3gXbRokR5//HHFx8crPz9fISEhmjp1qp544gl7nwoAGhTyIwAAQP2y+yXNvr6+Sk1N1bFjx/Tbb7/pm2++0VNPPSVvb297nwoAGhTyIwBXlZaWprCwMPn4+CgiIkIfffRRjY775JNP5Onpqd69ezs2QACoJbsXvAAAAGg4Vq1apRkzZmjOnDnav3+/Bg0apOjoaGVnZ1d5XEFBgSZMmKBhw4bVU6QAYDsKXgAAgEZs4cKFmjx5sqZMmaJu3bopNTVVoaGhSk9Pr/K4qVOn6s4771RkZGQ9RQoAtqPgBQAAaKTOnTunvXv3Kioqyqo9KipK27dvr/S4V199Vd98842SkpJqdJ6ioiIVFhZabQBQHyh4AQAAGqkTJ06opKREgYGBVu2BgYHKy8ur8Jivv/5as2bN0htvvCFPz5qtf5qSkiI/Pz/LFhoaWufYAaAmKHgBAAAaOZPJZPXYMIxybZJUUlKiO++8U3PnzlWXLl1q/PyJiYkqKCiwbDk5OXWOGQBqwu4/SwQAAICGoU2bNvLw8Cg3m5ufn19u1leSTp8+rT179mj//v2aPn26JKm0tFSGYcjT01MffPCBrr/++nLHmc1mmc1mx7wIAKgCM7wAAACNlLe3tyIiIpSZmWnVnpmZqf79+5fr37JlS33++ec6cOCAZYuLi9MVV1yhAwcO6Jprrqmv0AGgRpjhBQAAaMQSEhJ09913q2/fvoqMjNTSpUuVnZ2tuLg4SRcuR/7+++/1+uuvq0mTJgoPD7c6PiAgQD4+PuXaAcAVUPACAAA0YuPGjdPJkyc1b9485ebmKjw8XBs2bFD79u0lSbm5udX+Ji8AuCoKXgAAgEYuPj5e8fHxFe7LyMio8tjk5GQlJyfbPygAsAPu4QUAAAAAuCUKXgAAAACAW6LgBQAAAAC4JQpeAAAAAIBbouAFAAAAALglCl4AAAAAgFui4AUAAAAAuCWbC95t27YpJiZGISEhMplMWrt2rdV+wzCUnJyskJAQNW3aVEOGDNHBgwftFS8AuCzyIwAAgGuxueA9c+aMevXqpcWLF1e4f8GCBVq4cKEWL16s3bt3KygoSMOHD9fp06frHCwAuDLyIwAAgGvxtPWA6OhoRUdHV7jPMAylpqZqzpw5uvnmmyVJr732mgIDA7VixQpNnTq1btECgAsjPwIAALgWu97Dm5WVpby8PEVFRVnazGazBg8erO3bt9vzVADQoJAfAQAA6p/NM7xVycvLkyQFBgZatQcGBurYsWMVHlNUVKSioiLL48LCQnuGBAAuoTb5USJHAgAA1IVDVmk2mUxWjw3DKNdWJiUlRX5+fpYtNDTUESEBgEuwJT9K5EgAAIC6sGvBGxQUJOl/Mxll8vPzy81qlElMTFRBQYFly8nJsWdIAOASapMfJXIkAABAXdi14A0LC1NQUJAyMzMtbefOndPWrVvVv3//Co8xm81q2bKl1QYA7qY2+VEiRwIAANSFzffw/vLLLzpy5IjlcVZWlg4cOCB/f3+1a9dOM2bM0Pz589W5c2d17txZ8+fPV7NmzXTnnXfaNXAAcDXkRwAAANdic8G7Z88eDR061PI4ISFBkhQbG6uMjAw9+uij+u233xQfH6+ffvpJ11xzjT744AP5+vraL2oAcEHkRwAAANdic8E7ZMgQGYZR6X6TyaTk5GQlJyfXJS4AaHDIjwAAAK7FIas0AwAAAADgbHb9HV57yn6in5r4+Nh0zL9v+WutzzfgzMO1Pra2ahuvM2J97Y7FtToutsl0O0cCZyk9e1aa946zw8B/kSMrR46EM5AjAcA1McMLAAAAAHBLFLwAAAAAALdEwQsAAAAAcEsUvAAAAAAAt0TBCwAAAABwSxS8AAAAAAC3RMELAAAAAHBLFLwAAAAAALdEwQsAAAAAcEsUvAAAAAAAt0TBCwAAAABwSxS8AAAAAAC3RMELAAAAAHBLns4OoDLt5u2Sp8nLpmOuaf5Qrc/Xac6OWh9bW7WN1xmxxpZOr9VxHR6v/1jhGOeNYn3r7CBgQY6sHDkSztDQc2RaWpqeffZZ5ebmqkePHkpNTdWgQYMq7Lt69Wqlp6frwIEDKioqUo8ePZScnKwRI0bUc9QAUD1meAEAABqxVatWacaMGZozZ47279+vQYMGKTo6WtnZ2RX237Ztm4YPH64NGzZo7969Gjp0qGJiYrR///56jhwAqkfBCwAA0IgtXLhQkydP1pQpU9StWzelpqYqNDRU6enpFfZPTU3Vo48+qquvvlqdO3fW/Pnz1blzZ61fv76eIweA6tlc8G7btk0xMTEKCQmRyWTS2rVrLfuKi4s1c+ZM9ezZU82bN1dISIgmTJigH374wZ4xA4BLIj8CaGjOnTunvXv3Kioqyqo9KipK27dvr9FzlJaW6vTp0/L396+0T1FRkQoLC602AKgPNhe8Z86cUa9evbR48eJy+3799Vft27dPjz/+uPbt26fVq1frq6++0h//+Ee7BAsAroz8CKChOXHihEpKShQYGGjVHhgYqLy8vBo9x3PPPaczZ87otttuq7RPSkqK/Pz8LFtoaGid4gaAmrJ50aro6GhFR0dXuM/Pz0+ZmZlWbYsWLVK/fv2UnZ2tdu3a1S5KAGgAyI8AGiqTyWT12DCMcm0VWblypZKTk/XOO+8oICCg0n6JiYlKSEiwPC4sLKToBVAvHL5Kc0FBgUwmk1q1auXoUwFAg0J+BOBsbdq0kYeHR7nZ3Pz8/HKzvhdbtWqVJk+erLfeekt/+MMfquxrNptlNpvrHC8A2Mqhi1adPXtWs2bN0p133qmWLVtW2Id7OgA0RjXJjxI5EoBjeXt7KyIiotwVKJmZmerfv3+lx61cuVITJ07UihUrNGrUKEeHCQC15rCCt7i4WLfffrtKS0uVlpZWaT/u6QDQ2NQ0P0rkSACOl5CQoGXLlumVV17R4cOH9eCDDyo7O1txcXGSLlyOPGHCBEv/lStXasKECXruued07bXXKi8vT3l5eSooKHDWSwCASjmk4C0uLtZtt92mrKwsZWZmVjl7kZiYqIKCAsuWk5PjiJAAwCXYkh8lciQAxxs3bpxSU1M1b9489e7dW9u2bdOGDRvUvn17SVJubq7Vb/K+9NJLOn/+vKZNm6bg4GDL9uc//9lZLwEAKmX3e3jLvsx9/fXX2rx5s1q3bl1lf+7pANBY2JofJXIkgPoRHx+v+Pj4CvdlZGRYPd6yZYvjAwIAO7G54P3ll1905MgRy+OsrCwdOHBA/v7+CgkJ0dixY7Vv3z69++67KikpsSyC4O/vL29vb/tFDgAuhvwIAADgWmwuePfs2aOhQ4daHpctMR8bG6vk5GStW7dOktS7d2+r4zZv3qwhQ4bUPlIAcHHkRwAAANdic8E7ZMgQGYZR6f6q9gGAOyM/AgAAuBaH/iwRAAAAAADOYvdFq+wla/7VauLjY9Mx39y2pNbn61QSV+tja6u28Toj1vfHPlur40Y0e9jOkcBZSs+elWa/4+ww8F/kyMqRI+EM5EgAcE3M8AIAAAAA3BIFLwAAAADALVHwAgAAAADcEgUvAAAAAMAtUfACAAAAANwSBS8AAAAAwC1R8AIAAAAA3BIFLwAAAADALVHwAgAAAADcEgUvAAAAAMAtUfACAAAAANwSBS8AAAAAwC1R8AIAAAAA3JKnswOoTNjs3fI0edl0TEePuFqfr9NDO2t9bG3VNl5nxDry10dqdVzHx3fYORI4y3mjWMecHQQsyJGVI0fCGciRAOCamOEFAAAAALglCl4AAAAAgFuyueDdtm2bYmJiFBISIpPJpLVr11bad+rUqTKZTEpNTa1DiADQMJAfAQAAXIvNBe+ZM2fUq1cvLV68uMp+a9eu1b///W+FhITUOjgAaEjIjwAAAK7F5kWroqOjFR0dXWWf77//XtOnT9f777+vUaNG1To4AGhIyI8AAACuxe738JaWluruu+/WI488oh49etj76QGgwSI/AgAA1C+7/yzRM888I09PTz3wwAM16l9UVKSioiLL48LCQnuHBAAuwdb8KJEjAQAA6sKuM7x79+7V3/72N2VkZMhkMtXomJSUFPn5+Vm20NBQe4YEAC6hNvlRIkcCAADUhV0L3o8++kj5+flq166dPD095enpqWPHjumhhx5Shw4dKjwmMTFRBQUFli0nJ8eeIQGAS6hNfpTIkQAAAHVh10ua7777bv3hD3+wahsxYoTuvvtuTZo0qcJjzGazzGazPcMAAJdTm/wokSMBAADqwuaC95dfftGRI0csj7OysnTgwAH5+/urXbt2at26tVV/Ly8vBQUF6Yorrqh7tADgwsiPAAAArsXmgnfPnj0aOnSo5XFCQoIkKTY2VhkZGXYLDAAaGvIjAACAa7G54B0yZIgMw6hx/6NHj9p6CgBokMiPAAAArsXuP0sEAACAhiUtLU3PPvuscnNz1aNHD6WmpmrQoEGV9t+6dasSEhJ08OBBhYSE6NFHH1VcXJxDYispNbQr65TyT59VgK+P+oX5y6OJif122l/X8a/r8a4Qv7Nfo6Ofv6HvrysKXgAAgEZs1apVmjFjhtLS0jRgwAC99NJLio6O1qFDh9SuXbty/bOysjRy5Ejde++9Wr58uT755BPFx8erbdu2uuWWW+wa28YvcjV3/SHlFpy1tAX7+SgpprtuCA9mfx3313X863q8K8Tv7Nfo6Odv6PvtwWTYcv1dPSgsLJSfn5+GaIw8TV42HXsk9dpan7fTjJ21Pra2ahuvM2I9+mRkrY7r8PgOO0cCZzlvFGuL3lFBQYFatmzp7HAaLXJk9ciRcIaGnCOvueYa9enTR+np6Za2bt266cYbb1RKSkq5/jNnztS6det0+PBhS1tcXJw+/fRT7dhRs/d0WS6rarw2fpGr+5bv08VfVMvmff50XZiWbstify33p4/voxvCg/XrufPq/sT7kqRD80aomfeF+bDqxr+ux7tC/JKc+hqdPYauvr9sfCpSkxxShhleAACARurcuXPau3evZs2aZdUeFRWl7du3V3jMjh07FBUVZdU2YsQIvfzyyyouLpaXl21/jKtISamhuesPXfgibBgyl5wr1+f1zf+RdxXTNuyvfL9JUsrq/RrWoaVKz5fIfL5IklT6668qPe+pklJDKav3y/u/7fY+3hXin//2Pkkmp71GZ4+hq+4v8vCWYTLJJGnu+kMa3j2ozpc3U/ACAAA0UidOnFBJSYkCAwOt2gMDA5WXl1fhMXl5eRX2P3/+vE6cOKHg4PIzMkVFRSoq+t8X98LCwirj2pV1ynKJo7nknNa+O6dGrwe2ObLiwj/X/vdxzu/GeUk9HF9X9XF+R5/D2WPoam4c/RcVeZplSMotOKtdWacU2bF1tcdVpYl9QgMAAEBDZTJZz6AYhlGurbr+FbWXSUlJkZ+fn2ULDQ2tMp7802er3A+gcbBHLmCGFwAAoJFq06aNPDw8ys3m5ufnl5vFLRMUFFRhf09PT7VuXfFMTGJiouW3yaULM7xVFb0Bvj6Wfy/y8NaNo/9S7WuB7TIm9dM1Yf7l2v+ddUoTX93l8OPrqj7O7+hzOHsMXU2Rh7fV49/ngtqi4AUAAGikvL29FRERoczMTN10002W9szMTI0ZM6bCYyIjI7V+/Xqrtg8++EB9+/at9P5ds9kss9lc47j6hfkr2M9HeQVnZZhMKvIsf2wTk2QYKrfgDfur32+SFOTno37dLlWTCu6P7Netqfxb+10Yfwcc7wrxB7Y0SzLpeKFzXqOzx9DV91vGp4I/BtiKS5oBAAAasYSEBC1btkyvvPKKDh8+rAcffFDZ2dmW39VNTEzUhAkTLP3j4uJ07NgxJSQk6PDhw3rllVf08ssv6+GHH7ZbTB5NTEqK6S7pfyu2ljH9d7t3UBj7a7lfkpJiule6GFB141/X410h/uQ/9lDyH533Gh39/A19v1T1+NiCghcAAKARGzdunFJTUzVv3jz17t1b27Zt04YNG9S+fXtJUm5urrKzsy39w8LCtGHDBm3ZskW9e/fWk08+qRdeeMHuv8F7Q3iw0sf3UZCf9SWNQX4+Sh/fR4kju7O/Dvur+43T6sa/rse7QvzOfo3OHkNX38/v8FaA35h0HH5jEg35NybdCTmyeuRIOAM50ja2/IZmSamhXVmnlH/6rAJ8L1zi+PtZH/bXbX91HH28K8Tv7Nfo6Odv6Psrwu/wAgAAwC14NDFV+bMk7K/b/uo4+nhXiN/Zr9HRz9/Q99cVlzQDAAAAANwSBS8AAAAAwC253CXNZbcUn1dx5WtYV6L0bO1/mPi8UVzrY2urtvESK5zhvC78t3Sx2/4bHXJk9YgVzkCOtE3ZOBUWFjo5EgANUVnuqEnOdblFq7777rsqf4gcgHPl5OTosssuc3YYjRY5EnBt5MiaIZcBsIea5FyXK3hLS0v1ww8/yNfXVyZT+dW5CgsLFRoaqpycHFZBrADjUzXGp2pVjY9hGDp9+rRCQkLUpAl3QzhLVTmS93fVGJ+qMT5Vq258yJG2qe773sV4f9YN41d3jGHd2Hv8bMm5LndJc5MmTWr0l9GWLVvyZqsC41M1xqdqlY2Pn5+fE6LB79UkR/L+rhrjUzXGp2pVjQ85suZq+n3vYrw/64bxqzvGsG7sOX41zbn8CRIAAAAA4JYoeAEAAAAAbqnBFbxms1lJSUkym83ODsUlMT5VY3yqxvg0bPz3qxrjUzXGp2qMj3Mx/nXD+NUdY1g3zhw/l1u0CgAAAAAAe2hwM7wAAAAAANQEBS8AAAAAwC1R8AIAAAAA3BIFLwAAAADALTWogjctLU1hYWHy8fFRRESEPvroI2eH5DKSk5NlMpmstqCgIGeH5TTbtm1TTEyMQkJCZDKZtHbtWqv9hmEoOTlZISEhatq0qYYMGaKDBw86J1gnqG58Jk6cWO79dO211zonWNQI+bFy5Edr5MeqkR9dEzmuZvh8101KSoquvvpq+fr6KiAgQDfeeKO+/PJLqz6MYeXS09N15ZVXqmXLlmrZsqUiIyP13nvvWfY7a+waTMG7atUqzZgxQ3PmzNH+/fs1aNAgRUdHKzs729mhuYwePXooNzfXsn3++efODslpzpw5o169emnx4sUV7l+wYIEWLlyoxYsXa/fu3QoKCtLw4cN1+vTpeo7UOaobH0m64YYbrN5PGzZsqMcIYQvyY/XIj/9Dfqwa+dH1kONqjs933WzdulXTpk3Tzp07lZmZqfPnzysqKkpnzpyx9GEMK3fZZZfp6aef1p49e7Rnzx5df/31GjNmjKWoddrYGQ1Ev379jLi4OKu2rl27GrNmzXJSRK4lKSnJ6NWrl7PDcEmSjDVr1lgel5aWGkFBQcbTTz9taTt79qzh5+dnLFmyxAkROtfF42MYhhEbG2uMGTPGKfHAduTHqpEfK0d+rBr50TWQ42qHz3fd5efnG5KMrVu3GobBGNbGJZdcYixbtsypY9cgZnjPnTunvXv3Kioqyqo9KipK27dvd1JUrufrr79WSEiIwsLCdPvtt+vbb791dkguKSsrS3l5eVbvJ7PZrMGDB/N++p0tW7YoICBAXbp00b333qv8/Hxnh4QKkB9rhvxYM+THmiE/1h9ynP3w+bZdQUGBJMnf318SY2iLkpISvfnmmzpz5owiIyOdOnYNouA9ceKESkpKFBgYaNUeGBiovLw8J0XlWq655hq9/vrrev/99/X3v/9deXl56t+/v06ePOns0FxO2XuG91PloqOj9cYbb2jTpk167rnntHv3bl1//fUqKipydmi4CPmxeuTHmiM/Vo/8WL/IcfbD59s2hmEoISFBAwcOVHh4uCTGsCY+//xztWjRQmazWXFxcVqzZo26d+/u1LHzdOiz25nJZLJ6bBhGubbGKjo62vLvPXv2VGRkpDp27KjXXntNCQkJTozMdfF+qty4ceMs/x4eHq6+ffuqffv2+te//qWbb77ZiZGhMryfK0d+tB3vp8qRH52D96T9MJY1M336dH322Wf6+OOPy+1jDCt3xRVX6MCBA/r555/19ttvKzY2Vlu3brXsd8bYNYgZ3jZt2sjDw6Nc9Z+fn1/urwS4oHnz5urZs6e+/vprZ4ficspWZ+X9VHPBwcFq37497ycXRH60HfmxcuRH25EfHYscZz98vmvu/vvv17p167R582ZddtlllnbGsHre3t7q1KmT+vbtq5SUFPXq1Ut/+9vfnDp2DaLg9fb2VkREhDIzM63aMzMz1b9/fydF5dqKiop0+PBhBQcHOzsUlxMWFqagoCCr99O5c+e0detW3k+VOHnypHJycng/uSDyo+3Ij5UjP9qO/OhY5Dj74fNdPcMwNH36dK1evVqbNm1SWFiY1X7G0HaGYaioqMipY9dgLmlOSEjQ3Xffrb59+yoyMlJLly5Vdna24uLinB2aS3j44YcVExOjdu3aKT8/X0899ZQKCwsVGxvr7NCc4pdfftGRI0csj7OysnTgwAH5+/urXbt2mjFjhubPn6/OnTurc+fOmj9/vpo1a6Y777zTiVHXn6rGx9/fX8nJybrlllsUHByso0ePavbs2WrTpo1uuukmJ0aNypAfq0Z+tEZ+rBr50fWQ42qOz3fdTJs2TStWrNA777wjX19fy2ykn5+fmjZtKpPJxBhWYfbs2YqOjlZoaKhOnz6tN998U1u2bNHGjRudO3YOXQPazl588UWjffv2hre3t9GnTx/LEuEwjHHjxhnBwcGGl5eXERISYtx8883GwYMHnR2W02zevNmQVG6LjY01DOPCsvJJSUlGUFCQYTabjeuuu874/PPPnRt0PapqfH799VcjKirKaNu2reHl5WW0a9fOiI2NNbKzs50dNqpAfqwc+dEa+bFq5EfXRI6rGT7fdVPR2EkyXn31VUsfxrBy99xzj+Vz2rZtW2PYsGHGBx98YNnvrLEzGYZhOLakBgAAAACg/jWIe3gBAAAAALAVBS8AAAAAwC1R8AIAAAAA3BIFLwAAAADALVHwAgAAAADcEgUvAAAAAMAtUfACAAAAANwSBS8AAAAAwC1R8AIAAAAA3BIFLwAAAADALVHwAgAAAADcEgUvAAAAAMAt/X8SAc88VeiXWgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ind = 0\n", + "fig,(ax1,ax2,ax3)= plt.subplots(figsize=(12,3),nrows=1, ncols=3)\n", + "ax1.imshow(data[ind,:].detach().reshape(patch_size,patch_size))\n", + "ax1.set_title('Image Patch')\n", + "\n", + "ax2.imshow(reconstruction[ind,:].detach().reshape(patch_size,patch_size))\n", + "ax2.set_title('Reconstruction')\n", + "\n", + "ax3.stem(A[ind,:].reshape(-1))\n", + "ax3.set_title('Coefficients')\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sparsecoding", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "vscode": { + "interpreter": { + "hash": "f8dee6493e897117ffdd342ef2978c68dd4ef475114648de71e6f497d2bb56cc" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sparsecoding/inference.py b/sparsecoding/inference.py index b3c3cb2..d056111 100644 --- a/sparsecoding/inference.py +++ b/sparsecoding/inference.py @@ -932,3 +932,134 @@ def infer(self, data, dictionary): residual = data.clone() - coefficients @ dictionary.T # [batch_size, n_features] return coefficients.detach() + + +class CEL0(InferenceMethod): + def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=1e-2, return_all_coefficients="none", solver=None): + """ + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run + coeff_lr : float, default=1e-3 + Update rate of coefficient dynamics + threshold : float, default=1e-2 + Threshold for non-linearity + return_all_coefficients : str, {"none", "active"}, default="none" + Returns all coefficients during inference procedure if not equal + to "none". If return_all_coefficients=="active", + active units (a) (output of thresholding function over u) returned. + User beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] https://arxiv.org/abs/2301.10002 + """ + super().__init__(solver) + self.threshold = threshold + self.coeff_lr = coeff_lr + self.n_iter = n_iter + self.return_all_coefficients = return_all_coefficients + self.dictionary_norms = None + + def threshold_nonlinearity(self, u, a=1): + ''' + CEL0 thresholding function: A continuous exact l0 penalty + + Note: It is assumed that the dictionary is normalized + + Parameters + ---------- + u : array-like, shape [batch_size, n_basis] + a : the norm of the column of the dictionary, default=1 + + Returns + ------- + re : array-like, shape [batch_size, n_basis] + + ''' + if a * self.coeff_lr < 1: + num = (np.abs(u) - torch.sqrt(2 * self.threshold) * a * self.coeff_lr) + num[num < 0] = 0 + den = 1 - a ** 2 * self.coeff_lr + re = np.sign(u) * np.minimum(np.abs(u), np.divide(num, den)) # * (a ** 2 * self.coeff_lr < 1) + return re + else: + # TODO: This is not the same as the paper + larger = u[np.abs(u) < torch.sqrt(2 * self.threshold * self.coeff_lr)] + equal = u[np.abs(u) == torch.sqrt(2 * self.threshold * self.coeff_lr)] + re = larger + equal + return re + + def infer(self, data, dictionary, coeff_0=None, use_checknan=False): + """Infer coefficients using provided dictionary + + Parameters + ---------- + dictionary : array-like, shape [n_features, n_basis] + + data : array-like, shape [n_samples, n_features] + + coeff_0 : array-like, shape [n_samples, n_basis], optional + Initial coefficient values + use_checknan : bool, default=False + Check for nans in coefficients on each iteration. Setting this to + False can speed up inference time. + + Returns + ------- + coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] + First case occurs if return_all_coefficients == "none". If + return_all_coefficients != "none", returned shape is second case. + Returned dimension along dim 1 can be less than n_iter when + stop_early==True and stopping criteria met. + """ + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # initialize + if coeff_0 is not None: + u = coeff_0.to(device) + else: + u = torch.zeros((batch_size, n_basis)).to(device) + + coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) + + self.dictionary_norms = torch.norm(dictionary, dim=0, keepdim=True).squeeze()[0] + assert self.dictionary_norms == 1, "Dictionary must be normalized" + + for i in range(self.n_iter): + # check return all + if self.return_all_coefficients != "none": + if self.return_all_coefficients == "active": + coefficients = torch.concat( + [coefficients, self.CEL0Thresholding(u).clone().unsqueeze(1)], dim=1) + else: + coefficients = torch.concat( + [coefficients, u.clone().unsqueeze(1)], dim=1) + + # compute new + # Step 1: Gradient descent on u + recon = u @ dictionary.T + residual = data - recon + dLda = residual @ dictionary + u = u + self.coeff_lr * dLda + + # Step 2: Thresholding + u = self.threshold_nonlinearity(u) + + if use_checknan: + self.checknan(u, "coefficients") + + # return active units if return_all_coefficients in ["none", "active"] + if self.return_all_coefficients == "active": + coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1) + else: + final_coefficients = u + coefficients = torch.concat([coefficients, final_coefficients.clone().unsqueeze(1)], dim=1) + + return coefficients.squeeze() diff --git a/tests/inference/test_CEL0.py b/tests/inference/test_CEL0.py new file mode 100644 index 0000000..8f18939 --- /dev/null +++ b/tests/inference/test_CEL0.py @@ -0,0 +1,41 @@ +import unittest + +from sparsecoding import inference +from tests.testing_utilities import TestCase +from tests.inference.common import ( + DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE +) + + +class TestCEL0(TestCase): + def test_shape(self): + """ + Test that CEL0 inference returns expected shapes. + """ + N_ITER = 10 + + for (data, dataset) in zip(DATAS, DATASET): + inference_method = inference.CEL0(N_ITER) + a = inference_method.infer(data, DICTIONARY) + self.assertShapeEqual(a, dataset.weights) + + inference_method = inference.CEL0(N_ITER, return_all_coefficients=True) + a = inference_method.infer(data, DICTIONARY) + self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)) + + def test_inference(self): + """ + Test that CEL0 inference recovers the correct weights. + """ + N_ITER = 1000 + + for (data, dataset) in zip(DATAS, DATASET): + inference_method = inference.CEL0(n_iter=N_ITER, coeff_lr=1e-1, threshold=5e-1) + + a = inference_method.infer(data, DICTIONARY) + + self.assertAllClose(a, dataset.weights, atol=5e-2, rtol=1e-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tutorials/.DS_Store b/tutorials/.DS_Store new file mode 100644 index 0000000..3619f3f Binary files /dev/null and b/tutorials/.DS_Store differ