diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..2463912 Binary files /dev/null and b/.DS_Store differ diff --git a/examples/cel0_inference_bars_example.ipynb b/examples/cel0_inference_bars_example.ipynb new file mode 100644 index 0000000..cec360d --- /dev/null +++ b/examples/cel0_inference_bars_example.ipynb @@ -0,0 +1,235 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 25, + "id": "116276f7", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from sparsecoding import inference\n", + "\n", + "from sparsecoding.data.utils import load_bars_dictionary\n", + "from sparsecoding.visualization import plot_dictionary" + ] + }, + { + "cell_type": "markdown", + "id": "7c1e8bb3", + "metadata": {}, + "source": [ + "## Load bar dictionary\n", + "\n", + "A good way of evaluating whether or not an inference method is working correctly is by generating data from a known dictionary. In this notebook, this is done using a dictionary consisting of horizontal/vertical bars. 
This dictionary is provided in a pickle file in this rep" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "a9532c7e", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAFOCAYAAAAWx6x6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHUUlEQVR4nO3dy2rkMBBA0fHQ///LmkWyUQiM1bb7dc9ZN8EILS5F2dnGGOMPAAAZf5/9AAAAPJYABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAEDMbe8Pt2278jkAADho7z94MwEEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBidn8H8H/2fnemwncTAYBXZQIIABAjAAEAYgQgAEDMaTuAdt7W2ZucuUMA8BgmgAAAMQIQACBGAAIAxJy2A8g6O29r7EzO3B8A7mUCCAAQIwABAGIEIABAjB1A3oadt3X2JmfuEMAXE0AAgBgBCAAQIwABAGIEIABAjJdA4IN56WGNl2Zm7g98LhNAAIAYAQgAECMAAQBi7AACfLPzts7e5Mwd4l2YAAIAxAhAAIAYAQgAEGMHEIC72XlbY2dy5v48jwkgAECMAAQAiBGAAAAxdgAB4EHsvK2zNzk76w6ZAAIAxAhAAIAYAQgAECMAAQBivAQCALwsL85cwwQQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAICY21l/aIxx1p/6CNu2PfsRAAB+ZQIIABAjAAEAYgQgAEDMaTuAdt7W2ZucuUMA8BgmgAAAMQIQACBGAAIAxAhAAICY014CYZ2XHtZ4aWbm/gBwLxNAAIAYAQgAECMAAQBi7ADyNuy8rbM3OXOHAL6YAAIAxAhAAIAYAQgAEGMHED6Ynbc1diZn7g98LhNAAIAYAQgAECMAAQBi7AACfLPzts7e5Mwd4l2YAAIAxAhAAIAYAQgAECMAAQBivAQCwN289LDGSzMz9+d5TAABAGIEIABAjAAEAIixAwgAD2LnbZ29ydlZd8gEEAAgRgACAMQIQACAGDuAAMDLsjd5DRNAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMSc9iHo3/5Zs483rvl5hs5vnTM8zhke4/yOc4bHOcNjCk1jAggAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQg
AECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAzDbGGLt+uG1XPwsAAAfszDoTQACAGgEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIOb27Ac422/fv/ENwzU/z9D5rXOGxznDY5zfcc7wOGd4zJVNYwIIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQIwABACIEYAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAxAhAAIEYAAgDECEAAgBgBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxAhAAIAYAQgAECMAAQBiBCAAQMw2xhi7frhtVz8LAAAH7Mw6E0AAgBoBCAAQIwABAGIEIABAjAAEAIgRgAAAMQIQACBGAAIAxNz2/nDvhwUBAHhtJoAAADECEAAgRgACAMQIQACAGAEIABAjAAEAYgQgAECMAAQAiBGAAAAx/wACY3rnm927RgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "patch_size = 16\n", + "n_basis = 2*patch_size\n", + "\n", + "# load bars dictionary \n", + "dictionary = load_bars_dictionary()\n", + "dictionary = dictionary / dictionary.norm(dim=0, keepdim=True) # assume the dictionary is normalized\n", + "patch_size = int(np.sqrt(dictionary.shape[0]))\n", + "n_basis = dictionary.shape[1]\n", + "\n", + "nrow = 8\n", + "fig,ax = plot_dictionary(dictionary,nrow=nrow,size=8)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "190116ab", + "metadata": {}, + "outputs": [], + "source": [ + "dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True).squeeze()[0]\n", + "assert dictionary_norms==1, \"Dictionary must be normalized\"" + ] + }, + { + "cell_type": "markdown", + "id": "291457ef", + "metadata": {}, + "source": [ + "## Generate random data from bars dictionary" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "93b5ed39", + "metadata": {}, + "outputs": [], + "source": [ + "n_samples = 100\n", + "min_coefficient_val = 0.8\n", + "\n", + "# generate coefficients\n", + "coefficients = torch.rand([n_samples,n_basis]) \n", + "coefficients[coefficients < min_coefficient_val] = 0\n", + "\n", + "# generate dataset\n", + "data = (dictionary@coefficients.t()).t()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "ad40de0a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAALSElEQVR4nO3dbU8cZR/GYbZS2gCC4kPiR1RTbIVieJCnImgptAmpfkQTbaoUWisge7812WtOZ7Mzuwv3cbz8ZzJzLS0/J72c2U632+1OAFB0Z9QLABhnIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgSTdQ9cWVkpzg8PDxtbzLBsbW0V569evSrOX758OfA1t7e3a80mJiYmmnoIanl5uZHztGV/f79ntra2NoKVDO7Zs2eNnavT6Qx8jqq/4zs7OwOfe9hWV1eL84ODg4HPXed3zZ0kQCCSAIFIAgQiCRCIJEBQe3ebm+fo6GjUS4gWFhZ6ZuO+5ipN7m5X/V8P/aja3W5i53zY1tfXi/N79+4N5fruJAECkQQIRBIgEEmAwMYNjJkmNm6qHre7iY8l/v3338V5E48lVm1w/Zs7SYBAJAECkQQIRBIgEEmAQCQBApEECEQSIBBJgEAkAQKRBAhqP7u9ublZnE9PTze2mGGpeja26itlP/3004GvWXpGtKmvjgXa404SIBBJgEAkAQKRBAhqb9zs7u4W54eHh40tZtSqNm5evnw58LlL31JX9QLUJl66CjTDnSRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJENR+6S43z+rq6qiXEK2trfXMzs/PR7ASqOZOEiAQSYBAJAECkQQIRBIgsLt9ix0cHIx6CdHs7GzPbNzXXGV/f3/US6Al7iQBApEECEQSIBBJgEAkAQKRBAhEEiAQSYBAJAECkQQIRBIgEEmAQCQBApEECEQSIBBJgEAkAQKRBAhEEiAQSYBAJAECkQQIan+l7NLSUnF+dXXV2GKGpeqzvHr1qjifmppq5Zqnp6cDnxdolztJgEAkAQKRBAhEEiAQSYCg9u720dFRcf78+fPGFjMsc3NzxXnV7vbLly8Hvub8/HzPrOpnd3x8PPD1gGa4kwQIRBIgEEmAQCQBApEECEQSIBBJgEAkAQKRBAhEEiDodLvd7qgXATCu3EkCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgSTdQ/sdDqtLWJ5ebln9uzZs4HPu7q6WpxfXl4W501c88GDB8X5yclJ7XM09bL4J0+eFOfff/997WOXlpaK86Ojo9rrePjwYXH+4sWL2ufo18bGRs9sb2+veGzVZ5yc7P31uLi4KB47Ozvbx+qylZWV4ryfP7c27e/v98zW1tZau97u7m5xvrm5OfC5nz59+p/HuJMECEQSIBBJgEAkAYLaGzfcPFX/sF3auKraADg9PS3O+9m4ef36dXHe5sZNaYOlauPmjz/+KM7v3r3bM6va9Ds+Pu5jddnh4WFxfv/+/drHtunjjz8e6jqqNsWauKaNG4ABiSRAIJIAgUgCBDZubrGqjZvSvOqJqtLTUBMTExMzMzO11/Ho0aPifG5urvY5+lV64ubOnfI9QT9P3FRt3HB7uZMECEQSIBBJgEAkAQKRBAjsbt9if/31V+15P8emeVvn6FcTn7G0u311dVU8dn5+vo/VZVXvQS29s7FqPW0qre/s7Ky161W9q/L9+/etXfPf3EkCBCIJEIgkQCCSAIFIAgS1d7dL35DWlNLzwQsLCwOft99vS2zimlXflvjFF18MfO5+NfHy1uvr6+K8n5fuVv2823zp7r1793pmVZ+xaod4VC/dPTg4KM5Lu+1Vx7bpww8/HOo6Sn9fm7pmna65kwQIRBIgEEmAQCQBgtobN1WPBjWh9G16z549G/i8VY9KVf3jexPX/PXXX4v
zk5OT2ueo2nAChs+dJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQ1H7pLjAc3333XXG+srLSM7u4uGh7OT1K6zs9PR3q9SYmJibevXvX2jX/zZ0kQCCSAIFIAgQiCRCIJEBgdxvGzE8//VScT01N1T62TfPz80Ndx/T0dHHexDV//PHH/zzGnSRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEnt2GMfPo0aPa8/Pz87aXU2sdv//+e2vX+/bbb4vzP//8s7Vr/ps7SYBAJAECkQQIRBIgsHEDY+b58+fF+ezsbO1j2/TZZ58NdR0fffRRcd7ENY+Pj//zGHeSAIFIAgQiCRCIJEAgkgCB3e1bbHNzs/a80+kUj11aWirOZ2Zmaq+j6jG7ubm52ufo18bGRs/szp3yPUHVZ5yc7P31uLy8HGxh3DjuJAECkQQIRBIgEEmAQCQBArvbt9ju7m5x3u12e2ZPnjwpHvv27dvi/OjoqPY63rx5U5y/ePGi9jn6dX193TPb29srHnt2dlac3717t2dWtbtd5xlgbiZ3kgCBSAIEIgkQiCRAUHvj5ptvvmltEaVzN/H418OHD4vzi4uL4ryJay4uLhbnVY/EAePNby5AIJIAgUgCBCIJEIgkQNDplp5RA2BiYsKdJEAkkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQiCRAUPt7tzudTnG+s7PTM9va2uprEU2cY2lpqWd2dHTU1znGRVMvi6/6MxtnVX/upb8jTXnw4EFxfnJyUvscTb7gv+qzbm9vD3zu0s933H+2barz5+ZOEiAQSYBAJAECkQQIRBIgEEmAQCQBApEECEQSIBBJgKD2Y4nAcKyvrxfnl5eXA597Y2OjZ3Z1dTXweassLi4W53Nzc61ds2nuJAECkQQIRBIgEEmAQCQBArvbt9jjx49HvYS+raysFOfv3r1r7ZpVL4a9f/9+a9dMfvjhh77m/Zic7P2Vb+K8Vd68eVOcj8tLd/f29v7zGHeSAIFIAgQiCRCIJEAgkgCB3e1b7PDwcNRL6Nv09HRx3uZnef/+fXHezw7s06dPm1oOY8adJEAgkgCBSAIEIgkQiCRAIJIAgUgCBCIJEIgkQCCSAIFIAgQiCRCIJEAgkgCBSAIEIgkQeOnuLba9vT3qJfRta2urOO90Oq1ds+orZT/55JPWrsnN4U4SIBBJgEAkAQKRBAhs3NxiZ2dno15C387Pz4vzNj/LKK6ZVG1edbvdgc897M28xcXF4nxhYWGo6xiEO0mAQCQBApEECEQSIBBJgMDu9i32zz//jHoJfatac5uf5erqaujXTHZ2dvqat3W9Jrx+/bo4Pzk5ae2a/aiz2+9OEiAQSYBAJAECkQQIRBIgsLt9i7148WLUS+jb/Px8cd7mZ6naxe5nB/b58+dNLYcx404SIBBJgEAkAQKRBAhEEiCovbv9+PHj4nxlZaVnVvWm5ypNnGN5eblndueO/wYAg1ERgEAkAQKRBAhEEiCovXFzeHhYnM/OztY+tkoT57i+vu6ZHR0d9XWOcfH06dNGzvPll182cp5h+vrrr4vz3377rbVrfvXVV8V51ct4+f/iThIgEEmAQCQBApEECEQSIKi9u725uVl73u9XcTZxjqWlpZ7ZzMxMX+e4bX755ZdRL6Fvn3/+eXHe5meZnCz/GvRzzZ9//rmp5TBm3EkCBCIJEIgkQCCSAIFIAgS1d7d3d3eL8w8++KD2sVWaOMfbt297Zjf12e2dnZ1GzrO6utrIeYZpfX29OL+4uGjtmouLi8V56Z0C/P9xJwkQiCRAIJIAgUgCBLU3brh5Dg4ORr2Evk1NTRXnbX6Wqm/mPDk5qX2O/f39ppbDmHEnCRCIJEAgkgCBSAIEIgkQdLrdbnfUiwAYV+4kAQKRBAhEEiAQSYBAJAE
CkQQIRBIgEEmAQCQBgv8BRSl4vmIKgj4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figure = plt.figure(figsize=(4,4))\n", + "cols, rows = 3, 3\n", + "for i in range(1, cols * rows + 1):\n", + " sample_idx = torch.randint(len(data), size=(1,)).item()\n", + " img = (data[sample_idx])\n", + " figure.add_subplot(rows, cols, i)\n", + " plt.axis(\"off\")\n", + " plt.imshow(img.squeeze().reshape([patch_size,patch_size]), cmap=\"gray\")\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ba432bec", + "metadata": {}, + "source": [ + "## Reconstruct the data with CEL0 Algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "cb7fb67e", + "metadata": {}, + "outputs": [], + "source": [ + "# The CEL0 Algorithm\n", + "import time\n", + "\n", + "n_iter = 300\n", + "start = time.time()\n", + "cel0 = inference.CEL0(coeff_lr=1e-1,threshold=0.1,n_iter=n_iter)\n", + "A = cel0.infer(data, dictionary)\n", + "\n", + "reconstruction = (dictionary@A.t()).t()\n", + "\n", + "end = time.time()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "19040166", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running time of 300 iterations: 0.038336992263793945\n" + ] + } + ], + "source": [ + "print(\"Running time of\", n_iter, \"iterations:\", end - start)\n", + "# error = data - reconstruction\n", + "# print(\"Error\", error.norm(dim=0).mean())" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "db67eacc", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
class CEL0(InferenceMethod):
    """Sparse inference with the Continuous Exact L0 (CEL0) penalty.

    Coefficients are found by proximal gradient descent: every iteration
    takes a gradient step on the squared reconstruction error and then
    applies the CEL0 thresholding function. The dictionary must have
    unit-norm columns.
    """

    def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=1e-2, return_all_coefficients="none", solver=None):
        """
        Parameters
        ----------
        n_iter : int, default=100
            Number of iterations to run
        coeff_lr : float, default=1e-3
            Update rate of coefficient dynamics
        threshold : float, default=1e-2
            Threshold for non-linearity (must be non-negative, since its
            square root is taken)
        return_all_coefficients : str, {"none", "active"}, default="none"
            Returns all coefficients during inference procedure if not equal
            to "none". If return_all_coefficients=="active",
            active units (a) (output of thresholding function over u) returned.
            Any other value records the state u itself at every iteration.
            User beware: if n_iter is large, enabling this can result in
            large memory usage/potential exhaustion. This option is
            typically used for debugging.
        solver : default=None

        References
        ----------
        [1] https://arxiv.org/abs/2301.10002
        [2] Soubies, Blanc-Feraud, Aubert, "A Continuous Exact l0 Penalty
            (CEL0) for Least Squares Regularized Problem", SIAM J. Imaging
            Sciences, 2015.
        """
        super().__init__(solver)
        self.threshold = threshold
        self.coeff_lr = coeff_lr
        self.n_iter = n_iter
        self.return_all_coefficients = return_all_coefficients
        # Per-column dictionary norms; populated by infer().
        self.dictionary_norms = None

    def threshold_nonlinearity(self, u, a=1):
        """
        CEL0 thresholding function: A continuous exact l0 penalty

        Note: It is assumed that the dictionary is normalized

        Parameters
        ----------
        u : torch.Tensor, shape [batch_size, n_basis]
        a : scalar, default=1
            The norm of the column of the dictionary

        Returns
        -------
        re : torch.Tensor, same shape as u
        """
        step = self.coeff_lr
        magnitude = torch.abs(u)

        # The branch condition must match the denominator below, i.e. it is
        # a**2 * step (not a * step) that has to stay under 1.
        if a ** 2 * step < 1:
            # Plain-Python square root: self.threshold is a float, and
            # torch.sqrt only accepts tensors.
            shift = (2 * self.threshold) ** 0.5 * a * step
            shrunk = torch.clamp(magnitude - shift, min=0) / (1 - a ** 2 * step)
            return torch.sign(u) * torch.minimum(magnitude, shrunk)

        # a**2 * step >= 1: hard threshold. Entries whose magnitude reaches
        # the cutoff are kept unchanged, everything else is zeroed. Exactly at
        # the cutoff either 0 or u is a valid choice; we keep u.
        # torch.where preserves the input shape, unlike boolean-mask indexing.
        cutoff = (2 * self.threshold * step) ** 0.5
        return torch.where(magnitude >= cutoff, u, torch.zeros_like(u))

    def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
        """Infer coefficients using provided dictionary

        Parameters
        ----------
        dictionary : torch.Tensor, shape [n_features, n_basis]
            Must have unit-norm columns.
        data : torch.Tensor, shape [n_samples, n_features]

        coeff_0 : torch.Tensor, shape [n_samples, n_basis], optional
            Initial coefficient values
        use_checknan : bool, default=False
            Check for nans in coefficients on each iteration. Setting this to
            False can speed up inference time.

        Returns
        -------
        coefficients : torch.Tensor, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis]
            First case occurs if return_all_coefficients == "none". If
            return_all_coefficients != "none", returned shape is second case:
            the state before each iteration followed by the final state.

        Raises
        ------
        ValueError
            If any dictionary column does not have unit norm.
        """
        batch_size, _ = data.shape
        _, n_basis = dictionary.shape
        device = dictionary.device

        # Validate every column (not just the first) and compare with a
        # tolerance: normalised float columns rarely have a norm of exactly 1.
        column_norms = torch.norm(dictionary, dim=0)
        if not torch.allclose(column_norms, torch.ones_like(column_norms), atol=1e-6):
            raise ValueError("Dictionary must be normalized")
        self.dictionary_norms = column_norms

        # initialize
        if coeff_0 is not None:
            u = coeff_0.to(device)
        else:
            # Match the dictionary dtype so the matmuls below do not fail on
            # a float32/float64 mismatch.
            u = torch.zeros((batch_size, n_basis), device=device, dtype=dictionary.dtype)

        record_history = self.return_all_coefficients != "none"
        # Snapshots are collected in a list and stacked once at the end,
        # rather than re-concatenating a growing tensor every iteration.
        history = []

        for _ in range(self.n_iter):
            if record_history:
                if self.return_all_coefficients == "active":
                    history.append(self.threshold_nonlinearity(u).clone())
                else:
                    history.append(u.clone())

            # Step 1: gradient step on the squared reconstruction error
            residual = data - u @ dictionary.T
            u = u + self.coeff_lr * (residual @ dictionary)

            # Step 2: thresholding
            u = self.threshold_nonlinearity(u)

            if use_checknan:
                self.checknan(u, "coefficients")

        if not record_history:
            # clone() so the result never aliases coeff_0 (relevant when
            # n_iter == 0); the shape stays [n_samples, n_basis] even for a
            # single sample.
            return u.clone()

        history.append(u.clone())
        return torch.stack(history, dim=1)
import unittest

from sparsecoding import inference
from tests.testing_utilities import TestCase
from tests.inference.common import (
    DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE
)


class TestCEL0(TestCase):
    """Shape and recovery checks for the CEL0 inference method."""

    def test_shape(self):
        """
        CEL0 should return coefficients with the expected shapes, both for
        the final state only and when every iteration is recorded.
        """
        N_ITER = 10
        full_history_shape = (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)

        for (samples, reference) in zip(DATAS, DATASET):
            # Final coefficients only.
            final_only = inference.CEL0(N_ITER).infer(samples, DICTIONARY)
            self.assertShapeEqual(final_only, reference.weights)

            # One snapshot per iteration plus the final state.
            recorder = inference.CEL0(N_ITER, return_all_coefficients=True)
            all_steps = recorder.infer(samples, DICTIONARY)
            self.assertEqual(all_steps.shape, full_history_shape)

    def test_inference(self):
        """
        CEL0 should recover the ground-truth weights of each dataset.
        """
        N_ITER = 1000

        for (samples, reference) in zip(DATAS, DATASET):
            solver = inference.CEL0(n_iter=N_ITER, coeff_lr=1e-1, threshold=5e-1)
            recovered = solver.infer(samples, DICTIONARY)
            self.assertAllClose(recovered, reference.weights, atol=5e-2, rtol=1e-1)


if __name__ == "__main__":
    unittest.main()