From 57759732c0bbe9c9e5a48adf6d427ebfcfeb1da2 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 10:28:04 -0800 Subject: [PATCH 01/20] move tests --- {tests => sparsecoding}/transforms/test_patch.py | 0 {tests => sparsecoding}/transforms/test_whiten.py | 0 tests/transforms/__init__.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {tests => sparsecoding}/transforms/test_patch.py (100%) rename {tests => sparsecoding}/transforms/test_whiten.py (100%) delete mode 100644 tests/transforms/__init__.py diff --git a/tests/transforms/test_patch.py b/sparsecoding/transforms/test_patch.py similarity index 100% rename from tests/transforms/test_patch.py rename to sparsecoding/transforms/test_patch.py diff --git a/tests/transforms/test_whiten.py b/sparsecoding/transforms/test_whiten.py similarity index 100% rename from tests/transforms/test_whiten.py rename to sparsecoding/transforms/test_whiten.py diff --git a/tests/transforms/__init__.py b/tests/transforms/__init__.py deleted file mode 100644 index e69de29..0000000 From e375013a4ba5b2f5b3bdded08e65b24cb559fcc8 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 10:28:18 -0800 Subject: [PATCH 02/20] import order --- tests/inference/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/inference/common.py b/tests/inference/common.py index 2305ed7..8e47d3c 100644 --- a/tests/inference/common.py +++ b/tests/inference/common.py @@ -1,7 +1,7 @@ import torch -from sparsecoding.priors import L0Prior, SpikeSlabPrior from sparsecoding.datasets import BarsDataset +from sparsecoding.priors import L0Prior, SpikeSlabPrior torch.manual_seed(1997) From 7ad6114d947c6e73ecd559a05be49000db0b07ab Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:03:28 -0800 Subject: [PATCH 03/20] dev requirements --- requirements-dev.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 requirements-dev.txt diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..4191a3f --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +pylint +pytest +pyright \ No newline at end of file From 5e0bd604cfee910ad77ae056c4a68099ab891328 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:04:15 -0800 Subject: [PATCH 04/20] create pytest fixtures --- conftest.py | 9 +++++++ sparsecoding/test_utils/__init__.py | 7 ++++++ sparsecoding/test_utils/asserts.py | 11 +++++++++ sparsecoding/test_utils/asserts_test.py | 26 ++++++++++++++++++++ sparsecoding/test_utils/constant_fixtures.py | 12 +++++++++ sparsecoding/test_utils/dataset_fixtures.py | 25 +++++++++++++++++++ sparsecoding/test_utils/model_fixtures.py | 12 +++++++++ sparsecoding/test_utils/prior_fixtures.py | 24 ++++++++++++++++++ 8 files changed, 126 insertions(+) create mode 100644 conftest.py create mode 100644 sparsecoding/test_utils/__init__.py create mode 100644 sparsecoding/test_utils/asserts.py create mode 100644 sparsecoding/test_utils/asserts_test.py create mode 100644 sparsecoding/test_utils/constant_fixtures.py create mode 100644 sparsecoding/test_utils/dataset_fixtures.py create mode 100644 sparsecoding/test_utils/model_fixtures.py create mode 100644 sparsecoding/test_utils/prior_fixtures.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..18ac5fd --- /dev/null +++ b/conftest.py @@ -0,0 +1,9 @@ +from sparsecoding.test_utils import (bars_datas_fixture, bars_datasets_fixture, + bars_dictionary_fixture, + dataset_size_fixture, patch_size_fixture, + priors_fixture) + +# We import and define all fixtures in this file. +# This allows users to avoid any dependency fixtures. +# NOTE: This means pytest should only be run from this directory. +__all__ = ['dataset_size_fixture', 'patch_size_fixture', 'bars_datas_fixture', 'bars_datasets_fixture', 'bars_dictionary_fixture', 'priors_fixture'] diff --git a/sparsecoding/test_utils/__init__.py b/sparsecoding/test_utils/__init__.py new file mode 100644 index 0000000..7b8371b --- /dev/null +++ b/sparsecoding/test_utils/__init__.py @@ -0,0 +1,7 @@ +from .asserts import assert_allclose, assert_shape_equal +from .constant_fixtures import dataset_size_fixture, patch_size_fixture +from .dataset_fixtures import bars_datas_fixture, bars_datasets_fixture +from .model_fixtures import bars_dictionary_fixture +from .prior_fixtures import priors_fixture + +__all__ = ['assert_allclose', 'assert_shape_equal', 'dataset_size_fixture', 'patch_size_fixture', 'bars_datas_fixture', 'bars_datasets_fixture', 'bars_dictionary_fixture', 'priors_fixture'] diff --git a/sparsecoding/test_utils/asserts.py b/sparsecoding/test_utils/asserts.py new file mode 100644 index 0000000..bc7056a --- /dev/null +++ b/sparsecoding/test_utils/asserts.py @@ -0,0 +1,11 @@ +import numpy as np + +# constants +DEFAULT_ATOL = 1e-6 +DEFAULT_RTOL = 1e-5 + +def assert_allclose(a: np.ndarray, b: np.ndarray, rtol: float = DEFAULT_RTOL, atol: float = DEFAULT_ATOL) -> None: + return np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + +def assert_shape_equal(a: np.ndarray, b: np.ndarray) -> None: + assert a.shape == b.shape \ No newline at end of file diff --git a/sparsecoding/test_utils/asserts_test.py b/sparsecoding/test_utils/asserts_test.py new file mode 100644 index 0000000..912b0ad --- /dev/null +++ b/sparsecoding/test_utils/asserts_test.py @@ -0,0 +1,26 @@ + +import numpy as np +import torch + +from .asserts import assert_allclose, assert_shape_equal + + +def test_pytorch_all_close(): + result = torch.ones([10, 10]) + 1e-10 + expected = torch.ones([10, 10]) + assert_allclose(result, expected) + +def test_np_all_close(): + result = np.ones([100, 100]) + 1e-10 + expected = np.ones([100, 100]) + assert_allclose(result, expected) + +def test_assert_pytorch_shape_equal(): + a = torch.zeros([10, 10]) + b = torch.ones([10, 10]) + assert_shape_equal(a, b) + +def test_assert_np_shape_equal(): + a = np.zeros([100, 100]) + b = np.ones([100, 100]) + assert_shape_equal(a, b) diff --git a/sparsecoding/test_utils/constant_fixtures.py b/sparsecoding/test_utils/constant_fixtures.py new file mode 100644 index 0000000..eac6253 --- /dev/null +++ b/sparsecoding/test_utils/constant_fixtures.py @@ -0,0 +1,12 @@ +import pytest + +PATCH_SIZE = 8 +DATASET_SIZE = 1000 + +@pytest.fixture() +def patch_size_fixture() -> int: + return PATCH_SIZE + +@pytest.fixture() +def dataset_size_fixture() -> int: + return DATASET_SIZE \ No newline at end of file diff --git a/sparsecoding/test_utils/dataset_fixtures.py b/sparsecoding/test_utils/dataset_fixtures.py new file mode 100644 index 0000000..9e3dcc9 --- /dev/null +++ b/sparsecoding/test_utils/dataset_fixtures.py @@ -0,0 +1,25 @@ + +import pytest +import torch + +from sparsecoding.datasets import BarsDataset +from sparsecoding.priors import Prior + + +@pytest.fixture() +def bars_datasets_fixture(patch_size_fixture: int, dataset_size_fixture: int, priors_fixture: list[Prior]) -> list[BarsDataset]: + return [ + BarsDataset( + patch_size=patch_size_fixture, + dataset_size=dataset_size_fixture, + prior=prior, + ) + for prior in priors_fixture + ] + +@pytest.fixture() +def bars_datas_fixture(patch_size_fixture: int, dataset_size_fixture: int, bars_datasets_fixture: list[BarsDataset]) -> list[torch.Tensor]: + return [ + dataset.data.reshape((dataset_size_fixture, patch_size_fixture * patch_size_fixture)) + for dataset in bars_datasets_fixture + ] \ No newline at end of file diff --git a/sparsecoding/test_utils/model_fixtures.py b/sparsecoding/test_utils/model_fixtures.py new file mode 100644 index 0000000..95573f6 --- /dev/null +++ b/sparsecoding/test_utils/model_fixtures.py @@ -0,0 +1,12 @@ + +import pytest +import torch + +from sparsecoding.datasets import BarsDataset + +torch.manual_seed(1997) + +@pytest.fixture() +def bars_dictionary_fixture(patch_size_fixture: int, bars_datasets_fixture: list[BarsDataset]) -> torch.Tensor: + """Return a bars dataset basis reshaped to represent a dictionary.""" + return bars_datasets_fixture[0].basis.reshape((2 * patch_size_fixture, patch_size_fixture * patch_size_fixture)).T \ No newline at end of file diff --git a/sparsecoding/test_utils/prior_fixtures.py b/sparsecoding/test_utils/prior_fixtures.py new file mode 100644 index 0000000..ffa2ef1 --- /dev/null +++ b/sparsecoding/test_utils/prior_fixtures.py @@ -0,0 +1,24 @@ +import pytest +import torch + +from sparsecoding.priors import L0Prior, Prior, SpikeSlabPrior + + +@pytest.fixture() +def priors_fixture(patch_size_fixture: int) -> list[Prior]: + return [ + SpikeSlabPrior( + dim=2 * patch_size_fixture, + p_spike=0.8, + scale=1.0, + positive_only=True, + ), + L0Prior( + prob_distr=( + torch.nn.functional.one_hot( + torch.tensor(1), + num_classes=2 * patch_size_fixture, + ).type(torch.float32) + ), + ), + ] \ No newline at end of file From ede331bd07e3e5cfb1627012ceb82186a78274bf Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:05:01 -0800 Subject: [PATCH 05/20] inference reorganization & tests --- sparsecoding/inference.py | 940 ------------------ sparsecoding/inference/__init__.py | 21 + sparsecoding/inference/iht.py | 83 ++ sparsecoding/inference/inference_method.py | 73 ++ sparsecoding/inference/ista.py | 127 +++ sparsecoding/inference/ista_test.py | 29 + sparsecoding/inference/lca.py | 167 ++++ sparsecoding/inference/lca_test.py | 41 + sparsecoding/inference/lsm.py | 147 +++ sparsecoding/inference/lsm_test.py | 30 + sparsecoding/inference/mp.py | 77 ++ sparsecoding/inference/omp.py | 86 ++ sparsecoding/inference/pytorch_optimizer.py | 81 ++ .../inference/pytorch_optimizer_test.py | 66 ++ sparsecoding/inference/vanilla.py | 125 +++ sparsecoding/inference/vanilla_test.py | 35 + tests/inference/__init__.py | 0 tests/inference/common.py | 41 - tests/inference/test_ISTA.py | 41 - tests/inference/test_LCA.py | 48 - tests/inference/test_LSM.py | 35 - tests/inference/test_PyTorchOptimizer.py | 72 -- tests/inference/test_Vanilla.py | 42 - 23 files changed, 1188 insertions(+), 1219 deletions(-) delete mode 100644 sparsecoding/inference.py create mode 100644 sparsecoding/inference/__init__.py create mode 100644 sparsecoding/inference/iht.py create mode 100644 sparsecoding/inference/inference_method.py create mode 100644 sparsecoding/inference/ista.py create mode 100644 sparsecoding/inference/ista_test.py create mode 100644 sparsecoding/inference/lca.py create mode 100644 sparsecoding/inference/lca_test.py create mode 100644 sparsecoding/inference/lsm.py create mode 100644 sparsecoding/inference/lsm_test.py create mode 100644 sparsecoding/inference/mp.py create mode 100644 sparsecoding/inference/omp.py create mode 100644 sparsecoding/inference/pytorch_optimizer.py create mode 100644 sparsecoding/inference/pytorch_optimizer_test.py create mode 100644 sparsecoding/inference/vanilla.py create mode 100644 sparsecoding/inference/vanilla_test.py delete mode 100644 tests/inference/__init__.py delete mode 100644 tests/inference/common.py delete mode 100644 tests/inference/test_ISTA.py delete mode 100644 tests/inference/test_LCA.py delete mode 100644 tests/inference/test_LSM.py delete mode 100644 tests/inference/test_PyTorchOptimizer.py delete mode 100644 tests/inference/test_Vanilla.py diff --git a/sparsecoding/inference.py b/sparsecoding/inference.py deleted file mode 100644 index 83ed11f..0000000 --- a/sparsecoding/inference.py +++ /dev/null @@ -1,940 +0,0 @@ -import numpy as np -import torch - - -class InferenceMethod: - """Base class for inference method.""" - - def __init__(self, solver): - """ - Parameters - ---------- - """ - self.solver = solver - - def initialize(self, a): - """Define initial coefficients. - - Parameters - ---------- - - Returns - ------- - """ - raise NotImplementedError - - def grad(self): - """Compute the gradient step. - - Parameters - ---------- - - Returns - ------- - """ - raise NotImplementedError - - def infer(self, dictionary, data, coeff_0=None, use_checknan=False): - """Infer the coefficients given a dataset and dictionary. - - Parameters - ---------- - dictionary : array-like, shape [n_features,n_basis] - - data : array-like, shape [n_samples,n_features] - - coeff_0 : array-like, shape [n_samples,n_basis], optional - Initial coefficient values. - use_checknan : bool, default=False - Check for nans in coefficients on each iteration - - Returns - ------- - coefficients : array-like, shape [n_samples,n_basis] - """ - raise NotImplementedError - - @staticmethod - def checknan(data=torch.tensor(0), name="data"): - """Check for nan values in data. - - Parameters - ---------- - data : array-like, optional - Data to check for nans - name : str, default="data" - Name to add to error, if one is thrown - - Raises - ------ - ValueError - If the nan found in data - """ - if torch.isnan(data).any(): - raise ValueError("InferenceMethod error: nan in %s." % (name)) - - -class LCA(InferenceMethod): - def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1, - stop_early=False, epsilon=1e-2, solver=None, - return_all_coefficients="none", nonnegative=False): - """Method implemented according locally competative algorithm (LCA) - with the ideal soft thresholding function. - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run - coeff_lr : float, default=1e-3 - Update rate of coefficient dynamics - threshold : float, default=0.1 - Threshold for non-linearity - stop_early : bool, default=False - Stops dynamics early based on change in coefficents - epsilon : float, default=1e-2 - Only used if stop_early True, specifies criteria to stop dynamics - nonnegative : bool, default=False - Constrain coefficients to be nonnegative - return_all_coefficients : str, {"none", "membrane", "active"}, default="none" - Returns all coefficients during inference procedure if not equal - to "none". If return_all_coefficients=="membrane", membrane - potentials (u) returned. If return_all_coefficients=="active", - active units (a) (output of thresholding function over u) returned. - User beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Rozell, C. J., Johnson, D. H., Baraniuk, R. G., & Olshausen, - B. A. (2008). Sparse coding via thresholding and local competition - in neural circuits. Neural computation, 20(10), 2526-2563. - """ - super().__init__(solver) - self.threshold = threshold - self.coeff_lr = coeff_lr - self.stop_early = stop_early - self.epsilon = epsilon - self.n_iter = n_iter - self.nonnegative = nonnegative - if return_all_coefficients not in ["none", "membrane", "active"]: - raise ValueError("Invalid input for return_all_coefficients. Valid" - "inputs are: \"none\", \"membrane\", \"active\".") - self.return_all_coefficients = return_all_coefficients - - def threshold_nonlinearity(self, u): - """Soft threshhold function - - Parameters - ---------- - u : array-like, shape [batch_size, n_basis] - Membrane potentials - - Returns - ------- - a : array-like, shape [batch_size, n_basis] - Activations - """ - if self.nonnegative: - a = (u - self.threshold).clamp(min=0.) - else: - a = (torch.abs(u) - self.threshold).clamp(min=0.) - a = torch.sign(u)*a - return a - - def grad(self, b, G, u, a): - """Compute the gradient step on membrane potentials - - Parameters - ---------- - b : array-like, shape [batch_size, n_coefficients] - Driver signal for coefficients - G : array-like, shape [n_coefficients, n_coefficients] - Inhibition matrix - a : array-like, shape [batch_size, n_coefficients] - Currently active coefficients - - Returns - ------- - du : array-like, shape [batch_size, n_coefficients] - Gradient of membrane potentials - """ - du = b-u-(G@a.t()).t() - return du - - def infer(self, data, dictionary, coeff_0=None, use_checknan=False): - """Infer coefficients using provided dictionary - - Parameters - ---------- - dictionary : array-like, shape [n_features, n_basis] - - data : array-like, shape [n_samples, n_features] - - coeff_0 : array-like, shape [n_samples, n_basis], optional - Initial coefficient values - use_checknan : bool, default=False - Check for nans in coefficients on each iteration. Setting this to - False can speed up inference time. - - Returns - ------- - coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] - First case occurs if return_all_coefficients == "none". If - return_all_coefficients != "none", returned shape is second case. - Returned dimension along dim 1 can be less than n_iter when - stop_early==True and stopping criteria met. - """ - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # initialize - if coeff_0 is not None: - u = coeff_0.to(device) - else: - u = torch.zeros((batch_size, n_basis)).to(device) - - coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - - b = (dictionary.t()@data.t()).t() - G = dictionary.t()@dictionary-torch.eye(n_basis).to(device) - for i in range(self.n_iter): - # store old membrane potentials to evalute stop early condition - if self.stop_early: - old_u = u.clone().detach() - - # check return all - if self.return_all_coefficients != "none": - if self.return_all_coefficients == "active": - coefficients = torch.concat( - [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) - else: - coefficients = torch.concat( - [coefficients, u.clone().unsqueeze(1)], dim=1) - - # compute new - a = self.threshold_nonlinearity(u) - du = self.grad(b, G, u, a) - u = u + self.coeff_lr*du - - # check stopping condition - if self.stop_early: - relative_change_in_coeff = torch.linalg.norm(old_u - u)/torch.linalg.norm(old_u) - if relative_change_in_coeff < self.epsilon: - break - - if use_checknan: - self.checknan(u, "coefficients") - - # return active units if return_all_coefficients in ["none", "active"] - if self.return_all_coefficients == "membrane": - coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1) - else: - final_coefficients = self.threshold_nonlinearity(u) - coefficients = torch.concat([coefficients, final_coefficients.clone().unsqueeze(1)], dim=1) - - return coefficients.squeeze() - - -class Vanilla(InferenceMethod): - def __init__(self, n_iter=100, coeff_lr=1e-3, sparsity_penalty=0.2, - stop_early=False, epsilon=1e-2, solver=None, - return_all_coefficients=False): - """Gradient descent with Euler's method on model in Olshausen & Field - (1997) with laplace prior over coefficients (corresponding to l-1 norm - penalty). - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run - coeff_lr : float, default=1e-3 - Update rate of coefficient dynamics - sparsity_penalty : float, default=0.2 - - stop_early : bool, default=False - Stops dynamics early based on change in coefficents - epsilon : float, default=1e-2 - Only used if stop_early True, specifies criteria to stop dynamics - return_all_coefficients : str, default=False - Returns all coefficients during inference procedure if True - User beware: If n_iter is large, setting this parameter to True - Can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Olshausen, B. A., & Field, D. J. (1997). Sparse coding with an - overcomplete basis set: A strategy employed by V1?. Vision research, - 37(23), 3311-3325. - """ - super().__init__(solver) - self.coeff_lr = coeff_lr - self.sparsity_penalty = sparsity_penalty - self.stop_early = stop_early - self.epsilon = epsilon - self.n_iter = n_iter - self.return_all_coefficients = return_all_coefficients - - def grad(self, residual, dictionary, a): - """Compute the gradient step on coefficients - - Parameters - ---------- - residual : array-like, shape [batch_size, n_features] - Residual between reconstructed image and original - dictionary : array-like, shape [n_features,n_coefficients] - Dictionary - a : array-like, shape [batch_size, n_coefficients] - Coefficients - - Returns - ------- - da : array-like, shape [batch_size, n_coefficients] - Gradient of membrane potentials - """ - da = (dictionary.t()@residual.t()).t() - \ - self.sparsity_penalty*torch.sign(a) - return da - - def infer(self, data, dictionary, coeff_0=None, use_checknan=False): - """Infer coefficients using provided dictionary - - Parameters - ---------- - dictionary : array-like, shape [n_features, n_basis] - Dictionary - data : array like, shape [n_samples, n_features] - - coeff_0 : array-like, shape [n_samples, n_basis], optional - Initial coefficient values - use_checknan : bool, default=False - check for nans in coefficients on each iteration. Setting this to - False can speed up inference time - - Returns - ------- - coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] - First case occurs if return_all_coefficients == "none". If - return_all_coefficients != "none", returned shape is second case. - Returned dimension along dim 1 can be less than n_iter when - stop_early==True and stopping criteria met. - """ - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # initialize - if coeff_0 is not None: - a = coeff_0.to(device) - else: - a = torch.rand((batch_size, n_basis)).to(device)-0.5 - - coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - - residual = data - (dictionary@a.t()).t() - for i in range(self.n_iter): - - if self.return_all_coefficients: - coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) - - if self.stop_early: - old_a = a.clone().detach() - - da = self.grad(residual, dictionary, a) - a = a + self.coeff_lr*da - - if self.stop_early: - if torch.linalg.norm(old_a - a)/torch.linalg.norm(old_a) < self.epsilon: - break - - residual = data - (dictionary@a.t()).t() - - if use_checknan: - self.checknan(a, "coefficients") - - coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) - return torch.squeeze(coefficients) - - -class ISTA(InferenceMethod): - def __init__(self, n_iter=100, sparsity_penalty=1e-2, stop_early=False, - epsilon=1e-2, solver=None, return_all_coefficients=False): - """Iterative shrinkage-thresholding algorithm for solving LASSO problems. - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run - sparsity_penalty : float, default=0.2 - - stop_early : bool, default=False - Stops dynamics early based on change in coefficents - epsilon : float, default=1e-2 - Only used if stop_early True, specifies criteria to stop dynamics - return_all_coefficients : str, default=False - Returns all coefficients during inference procedure if True - User beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Beck, A., & Teboulle, M. (2009). A fast iterative - shrinkage-thresholding algorithm for linear inverse problems. - SIAM journal on imaging sciences, 2(1), 183-202. - """ - super().__init__(solver) - self.n_iter = n_iter - self.sparsity_penalty = sparsity_penalty - self.stop_early = stop_early - self.epsilon = epsilon - self.coefficients = None - self.return_all_coefficients = return_all_coefficients - - def threshold_nonlinearity(self, u): - """Soft threshhold function - - Parameters - ---------- - u : array-likes, shape [batch_size, n_basis] - Membrane potentials - - Returns - ------- - a : array-like, shape [batch_size, n_basis] - activations - """ - a = (torch.abs(u) - self.threshold).clamp(min=0.) - a = torch.sign(u)*a - return a - - def infer(self, data, dictionary, coeff_0=None, use_checknan=False): - """Infer coefficients for each image in data using dictionary elements. - Uses ISTA (Beck & Taboulle 2009), equations 1.4 and 1.5. - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - - dictionary : array-like, shape [n_features, n_basis] - - coeff_0 : array-like, shape [n_samples, n_basis], optional - Initial coefficient values - use_checknan : bool, default=False - Check for nans in coefficients on each iteration. Setting this to - False can speed up inference time. - Returns - ------- - coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] - First case occurs if return_all_coefficients == "none". If - return_all_coefficients != "none", returned shape is second case. - Returned dimension along dim 1 can be less than n_iter when - stop_early==True and stopping criteria met. - """ - batch_size = data.shape[0] - n_basis = dictionary.shape[1] - device = dictionary.device - - # Calculate stepsize based on largest eigenvalue of - # dictionary.T @ dictionary. - lipschitz_constant = torch.linalg.eigvalsh( - torch.mm(dictionary.T, dictionary))[-1] - stepsize = 1. / lipschitz_constant - self.threshold = stepsize * self.sparsity_penalty - - # Initialize coefficients. - if coeff_0 is not None: - u = coeff_0.to(device) - else: - u = torch.zeros((batch_size, n_basis)).to(device) - coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - self.coefficients = self.threshold_nonlinearity(u) - residual = torch.mm(dictionary, self.coefficients.T).T - data - - for _ in range(self.n_iter): - if self.stop_early: - old_u = u.clone().detach() - - if self.return_all_coefficients: - coefficients = torch.concat([coefficients, - self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) - - u -= stepsize * torch.mm(residual, dictionary) - self.coefficients = self.threshold_nonlinearity(u) - - if self.stop_early: - # Stopping condition is function of change of the coefficients. - a_change = torch.mean( - torch.abs(old_u - u) / stepsize) - if a_change < self.epsilon: - break - - residual = torch.mm(dictionary, self.coefficients.T).T - data - u = self.coefficients - - if use_checknan: - self.checknan(u, "coefficients") - - coefficients = torch.concat([coefficients, self.coefficients.clone().unsqueeze(1)], dim=1) - return torch.squeeze(coefficients) - - -class LSM(InferenceMethod): - def __init__(self, n_iter=100, n_iter_LSM=6, beta=0.01, alpha=80.0, - sigma=0.005, sparse_threshold=10**-2, solver=None, - return_all_coefficients=False): - """Infer latent coefficients generating data given dictionary. - Method implemented according to "Group Sparse Coding with a Laplacian - Scale Mixture Prior" (P. J. Garrigues & B. A. Olshausen, 2010) - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run for an optimizer - n_iter_LSM : int, default=6 - Number of iterations to run the outer loop of LSM - beta : float, default=0.01 - LSM parameter used to update lambdas - alpha : float, default=80.0 - LSM parameter used to update lambdas - sigma : float, default=0.005 - LSM parameter used to compute the loss function - sparse_threshold : float, default=10**-2 - Threshold used to discard smallest coefficients in the final - solution SM parameter used to compute the loss function - return_all_coefficients : bool, default=False - Returns all coefficients during inference procedure if True - User beware: If n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Garrigues, P., & Olshausen, B. (2010). Group sparse coding with - a laplacian scale mixture prior. Advances in neural information - processing systems, 23. - """ - super().__init__(solver) - self.n_iter = n_iter - self.n_iter_LSM = n_iter_LSM - self.beta = beta - self.alpha = alpha - self.sigma = sigma - self.sparse_threshold = sparse_threshold - self.return_all_coefficients = return_all_coefficients - - def lsm_Loss(self, data, dictionary, coefficients, lambdas, sigma): - """Compute LSM loss according to equation (7) in (P. J. Garrigues & - B. A. Olshausen, 2010) - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - Data to be used in sparse coding - dictionary : array-like, shape [n_features, n_basis] - Dictionary to be used - coefficients : array-like, shape [batch_size, n_basis] - The current values of coefficients - lambdas : array-like, shape [batch_size, n_basis] - The current values of regularization coefficient for all basis - sigma : float, default=0.005 - LSM parameter used to compute the loss functions - - Returns - ------- - loss : array-like, shape [batch_size, 1] - Loss values for each data sample - """ - - # Compute loss - preds = torch.mm(dictionary, coefficients.t()).t() - mse_loss = (1/(2*(sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) - sparse_loss = torch.sum(lambdas * torch.abs(coefficients), dim=1, keepdim=True) - loss = mse_loss + sparse_loss - return loss - - def infer(self, data, dictionary): - """Infer coefficients for each image in data using dict elements - dictionary using Laplacian Scale Mixture (LSM) - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - Data to be used in sparse coding - dictionary : array-like, shape [n_features, n_basis] - Dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like, shape [batch_size, n_basis] - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Initialize coefficients for the whole batch - coefficients = torch.zeros(batch_size, n_basis, device=device, requires_grad=True) - - # Set up optimizer - optimizer = torch.optim.Adam([coefficients], lr=1e-1) - - # Outer loop, set sparsity penalties (lambdas). - for i in range(self.n_iter_LSM): - # Compute the initial values of lambdas - lambdas = ( - (self.alpha + 1) - / (self.beta + torch.abs(coefficients.detach())) - ) - - # Inner loop, optimize coefficients w/ current sparsity penalties. - # Exits early if converged before `n_iter`s. - last_loss = None - for t in range(self.n_iter): - # compute LSM loss for the current iteration - loss = self.lsm_Loss( - data=data, - dictionary=dictionary, - coefficients=coefficients, - lambdas=lambdas, - sigma=self.sigma, - ) - loss = torch.sum(loss) - - # Backward pass: compute gradient and update model parameters. - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # Break if coefficients have converged. - if ( - last_loss is not None - and loss > 1.05 * last_loss - ): - break - - last_loss = loss - - # Sparsify the final solution by discarding the small coefficients - coefficients.data[torch.abs(coefficients.data) - < self.sparse_threshold] = 0 - - return coefficients.detach() - - -class PyTorchOptimizer(InferenceMethod): - def __init__(self, optimizer_f, loss_f, n_iter=100, solver=None): - """Infer coefficients using provided loss functional and optimizer - - Parameters - ---------- - optimizer : function handle - Pytorch optimizer handle have single parameter: - (coefficients) - where coefficients is of shape [batch_size, n_basis] - loss_f : function handle - Must have parameters: - (data, dictionary, coefficients) - where data is of shape [batch_size, n_features] - and loss_f must return tensor of size [batch_size,] - n_iter : int, default=100 - Number of iterations to run for an optimizer - solver : default=None - """ - super().__init__(solver) - self.optimizer_f = optimizer_f - self.loss_f = loss_f - self.n_iter = n_iter - - def infer(self, data, dictionary, coeff_0=None): - """Infer coefficients for each image in data using dict elements - dictionary by minimizing provided loss function with provided - optimizer. - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - Data to be used in sparse coding - - dictionary : array-like, shape [n_features, n_basis] - Dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like, shape [batch_size, n_basis] - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Initialize coefficients for the whole batch - - # initialize - if coeff_0 is not None: - coefficients = coeff_0.requires_grad_(True) - else: - coefficients = torch.zeros((batch_size, n_basis), requires_grad=True, device=device) - - optimizer = self.optimizer_f([coefficients]) - - for i in range(self.n_iter): - - # compute LSM loss for the current iteration - loss = self.loss_f( - data=data, - dictionary=dictionary, - coefficients=coefficients, - ) - - optimizer.zero_grad() - - # Backward pass: compute gradient of the loss with respect to - # model parameters - loss.backward(torch.ones((batch_size,), device=device)) - - # Calling the step function on an Optimizer makes an update to its - # parameters - optimizer.step() - - return coefficients.detach() - - -class IHT(InferenceMethod): - """ - Infer coefficients for each image in data using elements dictionary. - Method description can be traced to - "Iterative Hard Thresholding for Compressed Sensing" (T. Blumensath & M. E. Davies, 2009) - """ - - def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False): - ''' - - Parameters - ---------- - sparsity : scalar (1,) - Sparsity of the solution. The number of active coefficients will be set - to ceil(sparsity * data_dim) at the end of each iterative update. - n_iter : scalar (1,) default=100 - number of iterations to run for an inference method - return_all_coefficients : string (1,) default=False - returns all coefficients during inference procedure if True - user beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This function typically used for - debugging - solver : default=None - ''' - super().__init__(solver) - self.n_iter = n_iter - self.sparsity = sparsity - self.return_all_coefficients = return_all_coefficients - - def infer(self, data, dictionary): - """ - Infer coefficients for each image in data using dict elements dictionary using Iterative Hard Thresholding (IHT) - - Parameters - ---------- - data : array-like (batch_size, n_features) - data to be used in sparse coding - dictionary : array-like, (n_features, n_basis) - dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like (batch_size, n_basis) - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) - - # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) - - for _ in range(self.n_iter): - # Compute the prediction given the current coefficients - preds = coefficients @ dictionary.T # [batch_size, n_features] - - # Compute the residual - delta = data - preds # [batch_size, n_features] - - # Compute the similarity between the residual and the atoms in the dictionary - update = delta @ dictionary # [batch_size, n_basis] - coefficients = coefficients + update # [batch_size, n_basis] - - # Apply kWTA nonlinearity - topK_values, indices = torch.topk(torch.abs(coefficients), K, dim=1) - - # Reconstruct coefficients using the output of torch.topk - coefficients = ( - torch.sign(coefficients) - * torch.zeros(batch_size, n_basis, device=device).scatter_(1, indices, topK_values) - ) - - return coefficients.detach() - - -class MP(InferenceMethod): - """ - Infer coefficients for each image in data using elements dictionary. - Method description can be traced - to "Matching Pursuits with Time-Frequency Dictionaries" (S. G. Mallat & Z. Zhang, 1993) - """ - - def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' - - Parameters - ---------- - sparsity : scalar (1,) - sparsity of the solution - return_all_coefficients : string (1,) default=False - returns all coefficients during inference procedure if True - user beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This function typically used for - debugging - solver : default=None - ''' - super().__init__(solver) - self.sparsity = sparsity - self.return_all_coefficients = return_all_coefficients - - def infer(self, data, dictionary): - """ - Infer coefficients for each image in data using dict elements dictionary using Matching Pursuit (MP) - - Parameters - ---------- - data : array-like (batch_size, n_features) - data to be used in sparse coding - dictionary : array-like, (n_features, n_basis) - dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like (batch_size, n_basis) - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) - - # Get dictionary norms in case atoms are not normalized - dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) - - # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) - - residual = data.clone() # [batch_size, n_features] - - for _ in range(K): - # Select which (coefficient, basis function) pair to update using the inner product. - candidate_coefs = residual @ dictionary # [batch_size, n_basis] - top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] - - # Update the coefficient. - top_coefs = candidate_coefs[torch.arange(batch_size), top_coef_idxs] # [batch_size] - coefficients[torch.arange(batch_size), top_coef_idxs] = top_coefs - - # Explain away/subtract the chosen coefficient and corresponding basis from the residual. - top_coef_bases = dictionary[..., top_coef_idxs] # [n_features, batch_size] - residual = residual - top_coefs.reshape(batch_size, 1) * top_coef_bases.T # [batch_size, n_features] - - return coefficients.detach() - - -class OMP(InferenceMethod): - """ - Infer coefficients for each image in data using elements dictionary. - Method description can be traced to: - "Orthogonal Matching Pursuit: Recursive Function Approximation with Application to Wavelet Decomposition" - (Y. Pati & R. Rezaiifar & P. Krishnaprasad, 1993) - """ - - def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' - - Parameters - ---------- - sparsity : scalar (1,) - sparsity of the solution - return_all_coefficients : string (1,) default=False - returns all coefficients during inference procedure if True - user beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This function typically used for - debugging - solver : default=None - ''' - super().__init__(solver) - self.sparsity = sparsity - self.return_all_coefficients = return_all_coefficients - - def infer(self, data, dictionary): - """ - Infer coefficients for each image in data using dict elements dictionary using Orthogonal Matching Pursuit (OMP) - - Parameters - ---------- - data : array-like (batch_size, n_features) - data to be used in sparse coding - dictionary : array-like, (n_features, n_basis) - dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like (batch_size, n_basis) - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) - - # Get dictionary norms in case atoms are not normalized - dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) - - # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) - - residual = data.clone() # [batch_size, n_features] - - # The basis functions that are used to infer the coefficients will be updated each time. - used_basis_fns = torch.zeros((batch_size, n_basis), dtype=torch.bool) - - for t in range(K): - # Select which (coefficient, basis function) pair to update using the inner product. - candidate_coefs = residual @ dictionary # [batch_size, n_basis] - top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] - used_basis_fns[torch.arange(batch_size), top_coef_idxs] = True - - # Update the coefficients - used_dictionary = dictionary[..., used_basis_fns.nonzero()[:, 1]].reshape((n_features, batch_size, t + 1)) - - (used_coefficients, _, _, _) = torch.linalg.lstsq( - used_dictionary.permute((1, 0, 2)), # [batch_size, n_features, t + 1] - data.reshape(batch_size, n_features, 1), - ) # [batch_size, t + 1, 1] - coefficients[used_basis_fns] = used_coefficients.reshape(-1) - - # Update the residual. - residual = data.clone() - coefficients @ dictionary.T # [batch_size, n_features] - - return coefficients.detach() diff --git a/sparsecoding/inference/__init__.py b/sparsecoding/inference/__init__.py new file mode 100644 index 0000000..0cfa46d --- /dev/null +++ b/sparsecoding/inference/__init__.py @@ -0,0 +1,21 @@ +from .iht import IHT +from .inference_method import InferenceMethod +from .ista import ISTA +from .lca import LCA +from .lsm import LSM +from .mp import MP +from .omp import OMP +from .pytorch_optimizer import PyTorchOptimizer +from .vanilla import Vanilla + +__all__ = [ + 'IHT', + 'InferenceMethod', + 'ISTA', + 'LCA', + 'LSM', + 'MP', + 'OMP', + 'PyTorchOptimizer', + 'Vanilla' +] \ No newline at end of file diff --git a/sparsecoding/inference/iht.py b/sparsecoding/inference/iht.py new file mode 100644 index 0000000..e62dacd --- /dev/null +++ b/sparsecoding/inference/iht.py @@ -0,0 +1,83 @@ +import numpy as np +import torch + +from .inference_method import InferenceMethod + + +class IHT(InferenceMethod): + """ + Infer coefficients for each image in data using elements dictionary. + Method description can be traced to + "Iterative Hard Thresholding for Compressed Sensing" (T. Blumensath & M. E. Davies, 2009) + """ + + def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False): + ''' + + Parameters + ---------- + sparsity : scalar (1,) + Sparsity of the solution. The number of active coefficients will be set + to ceil(sparsity * data_dim) at the end of each iterative update. + n_iter : scalar (1,) default=100 + number of iterations to run for an inference method + return_all_coefficients : string (1,) default=False + returns all coefficients during inference procedure if True + user beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This function typically used for + debugging + solver : default=None + ''' + super().__init__(solver) + self.n_iter = n_iter + self.sparsity = sparsity + self.return_all_coefficients = return_all_coefficients + + def infer(self, data, dictionary): + """ + Infer coefficients for each image in data using dict elements dictionary using Iterative Hard Thresholding (IHT) + + Parameters + ---------- + data : array-like (batch_size, n_features) + data to be used in sparse coding + dictionary : array-like, (n_features, n_basis) + dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like (batch_size, n_basis) + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Define signal sparsity + K = np.ceil(self.sparsity*n_basis).astype(int) + + # Initialize coefficients for the whole batch + coefficients = torch.zeros( + batch_size, n_basis, requires_grad=False, device=device) + + for _ in range(self.n_iter): + # Compute the prediction given the current coefficients + preds = coefficients @ dictionary.T # [batch_size, n_features] + + # Compute the residual + delta = data - preds # [batch_size, n_features] + + # Compute the similarity between the residual and the atoms in the dictionary + update = delta @ dictionary # [batch_size, n_basis] + coefficients = coefficients + update # [batch_size, n_basis] + + # Apply kWTA nonlinearity + topK_values, indices = torch.topk(torch.abs(coefficients), K, dim=1) + + # Reconstruct coefficients using the output of torch.topk + coefficients = ( + torch.sign(coefficients) + * torch.zeros(batch_size, n_basis, device=device).scatter_(1, indices, topK_values) + ) + + return coefficients.detach() diff --git a/sparsecoding/inference/inference_method.py b/sparsecoding/inference/inference_method.py new file mode 100644 index 0000000..98ff0b8 --- /dev/null +++ b/sparsecoding/inference/inference_method.py @@ -0,0 +1,73 @@ +import torch + + +class InferenceMethod: + """Base class for inference method.""" + + def __init__(self, solver): + """ + Parameters + ---------- + """ + self.solver = solver + + def initialize(self, a): + """Define initial coefficients. + + Parameters + ---------- + + Returns + ------- + """ + raise NotImplementedError + + def grad(self): + """Compute the gradient step. + + Parameters + ---------- + + Returns + ------- + """ + raise NotImplementedError + + def infer(self, dictionary, data, coeff_0=None, use_checknan=False): + """Infer the coefficients given a dataset and dictionary. + + Parameters + ---------- + dictionary : array-like, shape [n_features,n_basis] + + data : array-like, shape [n_samples,n_features] + + coeff_0 : array-like, shape [n_samples,n_basis], optional + Initial coefficient values. + use_checknan : bool, default=False + Check for nans in coefficients on each iteration + + Returns + ------- + coefficients : array-like, shape [n_samples,n_basis] + """ + raise NotImplementedError + + @staticmethod + def checknan(data=torch.tensor(0), name="data"): + """Check for nan values in data. + + Parameters + ---------- + data : array-like, optional + Data to check for nans + name : str, default="data" + Name to add to error, if one is thrown + + Raises + ------ + ValueError + If the nan found in data + """ + if torch.isnan(data).any(): + raise ValueError("InferenceMethod error: nan in %s." % (name)) \ No newline at end of file diff --git a/sparsecoding/inference/ista.py b/sparsecoding/inference/ista.py new file mode 100644 index 0000000..5c1abfc --- /dev/null +++ b/sparsecoding/inference/ista.py @@ -0,0 +1,127 @@ +import torch + +from .inference_method import InferenceMethod + + +class ISTA(InferenceMethod): + def __init__(self, n_iter=100, sparsity_penalty=1e-2, stop_early=False, + epsilon=1e-2, solver=None, return_all_coefficients=False): + """Iterative shrinkage-thresholding algorithm for solving LASSO problems. + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run + sparsity_penalty : float, default=0.2 + + stop_early : bool, default=False + Stops dynamics early based on change in coefficents + epsilon : float, default=1e-2 + Only used if stop_early True, specifies criteria to stop dynamics + return_all_coefficients : str, default=False + Returns all coefficients during inference procedure if True + User beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Beck, A., & Teboulle, M. (2009). A fast iterative + shrinkage-thresholding algorithm for linear inverse problems. + SIAM journal on imaging sciences, 2(1), 183-202. + """ + super().__init__(solver) + self.n_iter = n_iter + self.sparsity_penalty = sparsity_penalty + self.stop_early = stop_early + self.epsilon = epsilon + self.coefficients = None + self.return_all_coefficients = return_all_coefficients + + def threshold_nonlinearity(self, u): + """Soft threshhold function + + Parameters + ---------- + u : array-likes, shape [batch_size, n_basis] + Membrane potentials + + Returns + ------- + a : array-like, shape [batch_size, n_basis] + activations + """ + a = (torch.abs(u) - self.threshold).clamp(min=0.) + a = torch.sign(u)*a + return a + + def infer(self, data, dictionary, coeff_0=None, use_checknan=False): + """Infer coefficients for each image in data using dictionary elements. + Uses ISTA (Beck & Taboulle 2009), equations 1.4 and 1.5. + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + + dictionary : array-like, shape [n_features, n_basis] + + coeff_0 : array-like, shape [n_samples, n_basis], optional + Initial coefficient values + use_checknan : bool, default=False + Check for nans in coefficients on each iteration. Setting this to + False can speed up inference time. + Returns + ------- + coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] + First case occurs if return_all_coefficients == "none". If + return_all_coefficients != "none", returned shape is second case. + Returned dimension along dim 1 can be less than n_iter when + stop_early==True and stopping criteria met. + """ + batch_size = data.shape[0] + n_basis = dictionary.shape[1] + device = dictionary.device + + # Calculate stepsize based on largest eigenvalue of + # dictionary.T @ dictionary. + lipschitz_constant = torch.linalg.eigvalsh( + torch.mm(dictionary.T, dictionary))[-1] + stepsize = 1. / lipschitz_constant + self.threshold = stepsize * self.sparsity_penalty + + # Initialize coefficients. + if coeff_0 is not None: + u = coeff_0.to(device) + else: + u = torch.zeros((batch_size, n_basis)).to(device) + coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) + self.coefficients = self.threshold_nonlinearity(u) + residual = torch.mm(dictionary, self.coefficients.T).T - data + + for _ in range(self.n_iter): + if self.stop_early: + old_u = u.clone().detach() + + if self.return_all_coefficients: + coefficients = torch.concat([coefficients, + self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) + + u -= stepsize * torch.mm(residual, dictionary) + self.coefficients = self.threshold_nonlinearity(u) + + if self.stop_early: + # Stopping condition is function of change of the coefficients. + a_change = torch.mean( + torch.abs(old_u - u) / stepsize) + if a_change < self.epsilon: + break + + residual = torch.mm(dictionary, self.coefficients.T).T - data + u = self.coefficients + + if use_checknan: + self.checknan(u, "coefficients") + + coefficients = torch.concat([coefficients, self.coefficients.clone().unsqueeze(1)], dim=1) + return torch.squeeze(coefficients) diff --git a/sparsecoding/inference/ista_test.py b/sparsecoding/inference/ista_test.py new file mode 100644 index 0000000..6a99f5a --- /dev/null +++ b/sparsecoding/inference/ista_test.py @@ -0,0 +1,29 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """Test that ISTA inference returns expected shapes.""" + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.ISTA(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + inference_method = inference.ISTA(N_ITER, return_all_coefficients=True) + a = inference_method.infer(data, bars_dictionary_fixture) + assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) + +def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """Test that ISTA inference recovers the correct weights.""" + N_ITER = 5000 + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.ISTA(n_iter=N_ITER) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/sparsecoding/inference/lca.py b/sparsecoding/inference/lca.py new file mode 100644 index 0000000..49ef1a5 --- /dev/null +++ b/sparsecoding/inference/lca.py @@ -0,0 +1,167 @@ +import torch + +from .inference_method import InferenceMethod + + +class LCA(InferenceMethod): + def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1, + stop_early=False, epsilon=1e-2, solver=None, + return_all_coefficients="none", nonnegative=False): + """Method implemented according locally competative algorithm (LCA) + with the ideal soft thresholding function. + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run + coeff_lr : float, default=1e-3 + Update rate of coefficient dynamics + threshold : float, default=0.1 + Threshold for non-linearity + stop_early : bool, default=False + Stops dynamics early based on change in coefficents + epsilon : float, default=1e-2 + Only used if stop_early True, specifies criteria to stop dynamics + nonnegative : bool, default=False + Constrain coefficients to be nonnegative + return_all_coefficients : str, {"none", "membrane", "active"}, default="none" + Returns all coefficients during inference procedure if not equal + to "none". If return_all_coefficients=="membrane", membrane + potentials (u) returned. If return_all_coefficients=="active", + active units (a) (output of thresholding function over u) returned. + User beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Rozell, C. J., Johnson, D. H., Baraniuk, R. G., & Olshausen, + B. A. (2008). Sparse coding via thresholding and local competition + in neural circuits. Neural computation, 20(10), 2526-2563. + """ + super().__init__(solver) + self.threshold = threshold + self.coeff_lr = coeff_lr + self.stop_early = stop_early + self.epsilon = epsilon + self.n_iter = n_iter + self.nonnegative = nonnegative + if return_all_coefficients not in ["none", "membrane", "active"]: + raise ValueError("Invalid input for return_all_coefficients. Valid" + "inputs are: \"none\", \"membrane\", \"active\".") + self.return_all_coefficients = return_all_coefficients + + def threshold_nonlinearity(self, u): + """Soft threshhold function + + Parameters + ---------- + u : array-like, shape [batch_size, n_basis] + Membrane potentials + + Returns + ------- + a : array-like, shape [batch_size, n_basis] + Activations + """ + if self.nonnegative: + a = (u - self.threshold).clamp(min=0.) + else: + a = (torch.abs(u) - self.threshold).clamp(min=0.) + a = torch.sign(u)*a + return a + + def grad(self, b, G, u, a): + """Compute the gradient step on membrane potentials + + Parameters + ---------- + b : array-like, shape [batch_size, n_coefficients] + Driver signal for coefficients + G : array-like, shape [n_coefficients, n_coefficients] + Inhibition matrix + a : array-like, shape [batch_size, n_coefficients] + Currently active coefficients + + Returns + ------- + du : array-like, shape [batch_size, n_coefficients] + Gradient of membrane potentials + """ + du = b-u-(G@a.t()).t() + return du + + def infer(self, data, dictionary, coeff_0=None, use_checknan=False): + """Infer coefficients using provided dictionary + + Parameters + ---------- + dictionary : array-like, shape [n_features, n_basis] + + data : array-like, shape [n_samples, n_features] + + coeff_0 : array-like, shape [n_samples, n_basis], optional + Initial coefficient values + use_checknan : bool, default=False + Check for nans in coefficients on each iteration. Setting this to + False can speed up inference time. + + Returns + ------- + coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] + First case occurs if return_all_coefficients == "none". If + return_all_coefficients != "none", returned shape is second case. + Returned dimension along dim 1 can be less than n_iter when + stop_early==True and stopping criteria met. + """ + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # initialize + if coeff_0 is not None: + u = coeff_0.to(device) + else: + u = torch.zeros((batch_size, n_basis)).to(device) + + coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) + + b = (dictionary.t()@data.t()).t() + G = dictionary.t()@dictionary-torch.eye(n_basis).to(device) + for i in range(self.n_iter): + # store old membrane potentials to evalute stop early condition + if self.stop_early: + old_u = u.clone().detach() + + # check return all + if self.return_all_coefficients != "none": + if self.return_all_coefficients == "active": + coefficients = torch.concat( + [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) + else: + coefficients = torch.concat( + [coefficients, u.clone().unsqueeze(1)], dim=1) + + # compute new + a = self.threshold_nonlinearity(u) + du = self.grad(b, G, u, a) + u = u + self.coeff_lr*du + + # check stopping condition + if self.stop_early: + relative_change_in_coeff = torch.linalg.norm(old_u - u)/torch.linalg.norm(old_u) + if relative_change_in_coeff < self.epsilon: + break + + if use_checknan: + self.checknan(u, "coefficients") + + # return active units if return_all_coefficients in ["none", "active"] + if self.return_all_coefficients == "membrane": + coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1) + else: + final_coefficients = self.threshold_nonlinearity(u) + coefficients = torch.concat([coefficients, final_coefficients.clone().unsqueeze(1)], dim=1) + + return coefficients.squeeze() diff --git a/sparsecoding/inference/lca_test.py b/sparsecoding/inference/lca_test.py new file mode 100644 index 0000000..2fe955f --- /dev/null +++ b/sparsecoding/inference/lca_test.py @@ -0,0 +1,41 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that LCA inference returns expected shapes. + """ + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LCA(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + for retval in ["active", "membrane"]: + inference_method = inference.LCA(N_ITER, return_all_coefficients=retval) + a = inference_method.infer(data, bars_dictionary_fixture) + assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) + +def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that LCA inference recovers the correct weights. + """ + LR = 5e-2 + THRESHOLD = 0.1 + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LCA( + coeff_lr=LR, + threshold=THRESHOLD, + n_iter=N_ITER, + ) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=5e-2) \ No newline at end of file diff --git a/sparsecoding/inference/lsm.py b/sparsecoding/inference/lsm.py new file mode 100644 index 0000000..c3faaf6 --- /dev/null +++ b/sparsecoding/inference/lsm.py @@ -0,0 +1,147 @@ +import torch + +from .inference_method import InferenceMethod + + +class LSM(InferenceMethod): + def __init__(self, n_iter=100, n_iter_LSM=6, beta=0.01, alpha=80.0, + sigma=0.005, sparse_threshold=10**-2, solver=None, + return_all_coefficients=False): + """Infer latent coefficients generating data given dictionary. + Method implemented according to "Group Sparse Coding with a Laplacian + Scale Mixture Prior" (P. J. Garrigues & B. A. Olshausen, 2010) + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run for an optimizer + n_iter_LSM : int, default=6 + Number of iterations to run the outer loop of LSM + beta : float, default=0.01 + LSM parameter used to update lambdas + alpha : float, default=80.0 + LSM parameter used to update lambdas + sigma : float, default=0.005 + LSM parameter used to compute the loss function + sparse_threshold : float, default=10**-2 + Threshold used to discard smallest coefficients in the final + solution SM parameter used to compute the loss function + return_all_coefficients : bool, default=False + Returns all coefficients during inference procedure if True + User beware: If n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Garrigues, P., & Olshausen, B. (2010). Group sparse coding with + a laplacian scale mixture prior. Advances in neural information + processing systems, 23. + """ + super().__init__(solver) + self.n_iter = n_iter + self.n_iter_LSM = n_iter_LSM + self.beta = beta + self.alpha = alpha + self.sigma = sigma + self.sparse_threshold = sparse_threshold + self.return_all_coefficients = return_all_coefficients + + def lsm_Loss(self, data, dictionary, coefficients, lambdas, sigma): + """Compute LSM loss according to equation (7) in (P. J. Garrigues & + B. A. Olshausen, 2010) + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + Data to be used in sparse coding + dictionary : array-like, shape [n_features, n_basis] + Dictionary to be used + coefficients : array-like, shape [batch_size, n_basis] + The current values of coefficients + lambdas : array-like, shape [batch_size, n_basis] + The current values of regularization coefficient for all basis + sigma : float, default=0.005 + LSM parameter used to compute the loss functions + + Returns + ------- + loss : array-like, shape [batch_size, 1] + Loss values for each data sample + """ + + # Compute loss + preds = torch.mm(dictionary, coefficients.t()).t() + mse_loss = (1/(2*(sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) + sparse_loss = torch.sum(lambdas * torch.abs(coefficients), dim=1, keepdim=True) + loss = mse_loss + sparse_loss + return loss + + def infer(self, data, dictionary): + """Infer coefficients for each image in data using dict elements + dictionary using Laplacian Scale Mixture (LSM) + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + Data to be used in sparse coding + dictionary : array-like, shape [n_features, n_basis] + Dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like, shape [batch_size, n_basis] + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Initialize coefficients for the whole batch + coefficients = torch.zeros(batch_size, n_basis, device=device, requires_grad=True) + + # Set up optimizer + optimizer = torch.optim.Adam([coefficients], lr=1e-1) + + # Outer loop, set sparsity penalties (lambdas). + for i in range(self.n_iter_LSM): + # Compute the initial values of lambdas + lambdas = ( + (self.alpha + 1) + / (self.beta + torch.abs(coefficients.detach())) + ) + + # Inner loop, optimize coefficients w/ current sparsity penalties. + # Exits early if converged before `n_iter`s. + last_loss = None + for t in range(self.n_iter): + # compute LSM loss for the current iteration + loss = self.lsm_Loss( + data=data, + dictionary=dictionary, + coefficients=coefficients, + lambdas=lambdas, + sigma=self.sigma, + ) + loss = torch.sum(loss) + + # Backward pass: compute gradient and update model parameters. + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Break if coefficients have converged. + if ( + last_loss is not None + and loss > 1.05 * last_loss + ): + break + + last_loss = loss + + # Sparsify the final solution by discarding the small coefficients + coefficients.data[torch.abs(coefficients.data) + < self.sparse_threshold] = 0 + + return coefficients.detach() diff --git a/sparsecoding/inference/lsm_test.py b/sparsecoding/inference/lsm_test.py new file mode 100644 index 0000000..b7e4b05 --- /dev/null +++ b/sparsecoding/inference/lsm_test.py @@ -0,0 +1,30 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that LSM inference returns expected shapes. + """ + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LSM(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + +def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that LSM inference recovers the correct weights. + """ + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LSM(n_iter=N_ITER) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/sparsecoding/inference/mp.py b/sparsecoding/inference/mp.py new file mode 100644 index 0000000..2d413fc --- /dev/null +++ b/sparsecoding/inference/mp.py @@ -0,0 +1,77 @@ +import numpy as np +import torch + +from .inference_method import InferenceMethod + + +class MP(InferenceMethod): + """ + Infer coefficients for each image in data using elements dictionary. + Method description can be traced + to "Matching Pursuits with Time-Frequency Dictionaries" (S. G. Mallat & Z. Zhang, 1993) + """ + + def __init__(self, sparsity, solver=None, return_all_coefficients=False): + ''' + + Parameters + ---------- + sparsity : scalar (1,) + sparsity of the solution + return_all_coefficients : string (1,) default=False + returns all coefficients during inference procedure if True + user beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This function typically used for + debugging + solver : default=None + ''' + super().__init__(solver) + self.sparsity = sparsity + self.return_all_coefficients = return_all_coefficients + + def infer(self, data, dictionary): + """ + Infer coefficients for each image in data using dict elements dictionary using Matching Pursuit (MP) + + Parameters + ---------- + data : array-like (batch_size, n_features) + data to be used in sparse coding + dictionary : array-like, (n_features, n_basis) + dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like (batch_size, n_basis) + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Define signal sparsity + K = np.ceil(self.sparsity*n_basis).astype(int) + + # Get dictionary norms in case atoms are not normalized + dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) + + # Initialize coefficients for the whole batch + coefficients = torch.zeros( + batch_size, n_basis, requires_grad=False, device=device) + + residual = data.clone() # [batch_size, n_features] + + for _ in range(K): + # Select which (coefficient, basis function) pair to update using the inner product. + candidate_coefs = residual @ dictionary # [batch_size, n_basis] + top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] + + # Update the coefficient. + top_coefs = candidate_coefs[torch.arange(batch_size), top_coef_idxs] # [batch_size] + coefficients[torch.arange(batch_size), top_coef_idxs] = top_coefs + + # Explain away/subtract the chosen coefficient and corresponding basis from the residual. + top_coef_bases = dictionary[..., top_coef_idxs] # [n_features, batch_size] + residual = residual - top_coefs.reshape(batch_size, 1) * top_coef_bases.T # [batch_size, n_features] + + return coefficients.detach() diff --git a/sparsecoding/inference/omp.py b/sparsecoding/inference/omp.py new file mode 100644 index 0000000..33c49d7 --- /dev/null +++ b/sparsecoding/inference/omp.py @@ -0,0 +1,86 @@ +import numpy as np +import torch + +from .inference_method import InferenceMethod + + +class OMP(InferenceMethod): + """ + Infer coefficients for each image in data using elements dictionary. + Method description can be traced to: + "Orthogonal Matching Pursuit: Recursive Function Approximation with Application to Wavelet Decomposition" + (Y. Pati & R. Rezaiifar & P. Krishnaprasad, 1993) + """ + + def __init__(self, sparsity, solver=None, return_all_coefficients=False): + ''' + + Parameters + ---------- + sparsity : scalar (1,) + sparsity of the solution + return_all_coefficients : string (1,) default=False + returns all coefficients during inference procedure if True + user beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This function typically used for + debugging + solver : default=None + ''' + super().__init__(solver) + self.sparsity = sparsity + self.return_all_coefficients = return_all_coefficients + + def infer(self, data, dictionary): + """ + Infer coefficients for each image in data using dict elements dictionary using Orthogonal Matching Pursuit (OMP) + + Parameters + ---------- + data : array-like (batch_size, n_features) + data to be used in sparse coding + dictionary : array-like, (n_features, n_basis) + dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like (batch_size, n_basis) + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Define signal sparsity + K = np.ceil(self.sparsity*n_basis).astype(int) + + # Get dictionary norms in case atoms are not normalized + dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) + + # Initialize coefficients for the whole batch + coefficients = torch.zeros( + batch_size, n_basis, requires_grad=False, device=device) + + residual = data.clone() # [batch_size, n_features] + + # The basis functions that are used to infer the coefficients will be updated each time. + used_basis_fns = torch.zeros((batch_size, n_basis), dtype=torch.bool) + + for t in range(K): + # Select which (coefficient, basis function) pair to update using the inner product. + candidate_coefs = residual @ dictionary # [batch_size, n_basis] + top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] + used_basis_fns[torch.arange(batch_size), top_coef_idxs] = True + + # Update the coefficients + used_dictionary = dictionary[..., used_basis_fns.nonzero()[:, 1]].reshape((n_features, batch_size, t + 1)) + + (used_coefficients, _, _, _) = torch.linalg.lstsq( + used_dictionary.permute((1, 0, 2)), # [batch_size, n_features, t + 1] + data.reshape(batch_size, n_features, 1), + ) # [batch_size, t + 1, 1] + coefficients[used_basis_fns] = used_coefficients.reshape(-1) + + # Update the residual. + residual = data.clone() - coefficients @ dictionary.T # [batch_size, n_features] + + return coefficients.detach() diff --git a/sparsecoding/inference/pytorch_optimizer.py b/sparsecoding/inference/pytorch_optimizer.py new file mode 100644 index 0000000..7c13079 --- /dev/null +++ b/sparsecoding/inference/pytorch_optimizer.py @@ -0,0 +1,81 @@ +import torch + +from .inference_method import InferenceMethod + + +class PyTorchOptimizer(InferenceMethod): + def __init__(self, optimizer_f, loss_f, n_iter=100, solver=None): + """Infer coefficients using provided loss functional and optimizer + + Parameters + ---------- + optimizer : function handle + Pytorch optimizer handle have single parameter: + (coefficients) + where coefficients is of shape [batch_size, n_basis] + loss_f : function handle + Must have parameters: + (data, dictionary, coefficients) + where data is of shape [batch_size, n_features] + and loss_f must return tensor of size [batch_size,] + n_iter : int, default=100 + Number of iterations to run for an optimizer + solver : default=None + """ + super().__init__(solver) + self.optimizer_f = optimizer_f + self.loss_f = loss_f + self.n_iter = n_iter + + def infer(self, data, dictionary, coeff_0=None): + """Infer coefficients for each image in data using dict elements + dictionary by minimizing provided loss function with provided + optimizer. + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + Data to be used in sparse coding + + dictionary : array-like, shape [n_features, n_basis] + Dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like, shape [batch_size, n_basis] + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Initialize coefficients for the whole batch + + # initialize + if coeff_0 is not None: + coefficients = coeff_0.requires_grad_(True) + else: + coefficients = torch.zeros((batch_size, n_basis), requires_grad=True, device=device) + + optimizer = self.optimizer_f([coefficients]) + + for i in range(self.n_iter): + + # compute LSM loss for the current iteration + loss = self.loss_f( + data=data, + dictionary=dictionary, + coefficients=coefficients, + ) + + optimizer.zero_grad() + + # Backward pass: compute gradient of the loss with respect to + # model parameters + loss.backward(torch.ones((batch_size,), device=device)) + + # Calling the step function on an Optimizer makes an update to its + # parameters + optimizer.step() + + return coefficients.detach() diff --git a/sparsecoding/inference/pytorch_optimizer_test.py b/sparsecoding/inference/pytorch_optimizer_test.py new file mode 100644 index 0000000..f90fd66 --- /dev/null +++ b/sparsecoding/inference/pytorch_optimizer_test.py @@ -0,0 +1,66 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def lasso_loss(data, dictionary, coefficients, sparsity_penalty): + """ + Generic MSE + l1-norm loss. + """ + batch_size = data.shape[0] + datahat = (dictionary@coefficients.t()).t() + + mse_loss = torch.linalg.vector_norm(datahat-data, dim=1).square() + sparse_loss = torch.sum(torch.abs(coefficients), axis=1) + + total_loss = (mse_loss + sparsity_penalty*sparse_loss)/batch_size + return total_loss + +def loss_fn(data, dictionary, coefficients): + return lasso_loss( + data, + dictionary, + coefficients, + sparsity_penalty=1., + ) + +def optimizer_fn(coefficients): + return torch.optim.Adam( + coefficients, + lr=0.1, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0, + ) + +def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that PyTorchOptimizer inference returns expected shapes. + """ + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.PyTorchOptimizer( + optimizer_fn, + loss_fn, + n_iter=10, + ) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + +def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that PyTorchOptimizer inference recovers the correct weights. + """ + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.PyTorchOptimizer( + optimizer_fn, + loss_fn, + n_iter=N_ITER, + ) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=1e-1, rtol=1e-1) diff --git a/sparsecoding/inference/vanilla.py b/sparsecoding/inference/vanilla.py new file mode 100644 index 0000000..473fcb8 --- /dev/null +++ b/sparsecoding/inference/vanilla.py @@ -0,0 +1,125 @@ +import torch + +from .inference_method import InferenceMethod + + +class Vanilla(InferenceMethod): + def __init__(self, n_iter=100, coeff_lr=1e-3, sparsity_penalty=0.2, + stop_early=False, epsilon=1e-2, solver=None, + return_all_coefficients=False): + """Gradient descent with Euler's method on model in Olshausen & Field + (1997) with laplace prior over coefficients (corresponding to l-1 norm + penalty). + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run + coeff_lr : float, default=1e-3 + Update rate of coefficient dynamics + sparsity_penalty : float, default=0.2 + + stop_early : bool, default=False + Stops dynamics early based on change in coefficents + epsilon : float, default=1e-2 + Only used if stop_early True, specifies criteria to stop dynamics + return_all_coefficients : str, default=False + Returns all coefficients during inference procedure if True + User beware: If n_iter is large, setting this parameter to True + Can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Olshausen, B. A., & Field, D. J. (1997). Sparse coding with an + overcomplete basis set: A strategy employed by V1?. Vision research, + 37(23), 3311-3325. + """ + super().__init__(solver) + self.coeff_lr = coeff_lr + self.sparsity_penalty = sparsity_penalty + self.stop_early = stop_early + self.epsilon = epsilon + self.n_iter = n_iter + self.return_all_coefficients = return_all_coefficients + + def grad(self, residual, dictionary, a): + """Compute the gradient step on coefficients + + Parameters + ---------- + residual : array-like, shape [batch_size, n_features] + Residual between reconstructed image and original + dictionary : array-like, shape [n_features,n_coefficients] + Dictionary + a : array-like, shape [batch_size, n_coefficients] + Coefficients + + Returns + ------- + da : array-like, shape [batch_size, n_coefficients] + Gradient of membrane potentials + """ + da = (dictionary.t()@residual.t()).t() - \ + self.sparsity_penalty*torch.sign(a) + return da + + def infer(self, data, dictionary, coeff_0=None, use_checknan=False): + """Infer coefficients using provided dictionary + + Parameters + ---------- + dictionary : array-like, shape [n_features, n_basis] + Dictionary + data : array like, shape [n_samples, n_features] + + coeff_0 : array-like, shape [n_samples, n_basis], optional + Initial coefficient values + use_checknan : bool, default=False + check for nans in coefficients on each iteration. Setting this to + False can speed up inference time + + Returns + ------- + coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] + First case occurs if return_all_coefficients == "none". If + return_all_coefficients != "none", returned shape is second case. + Returned dimension along dim 1 can be less than n_iter when + stop_early==True and stopping criteria met. + """ + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # initialize + if coeff_0 is not None: + a = coeff_0.to(device) + else: + a = torch.rand((batch_size, n_basis)).to(device)-0.5 + + coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) + + residual = data - (dictionary@a.t()).t() + for i in range(self.n_iter): + + if self.return_all_coefficients: + coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) + + if self.stop_early: + old_a = a.clone().detach() + + da = self.grad(residual, dictionary, a) + a = a + self.coeff_lr*da + + if self.stop_early: + if torch.linalg.norm(old_a - a)/torch.linalg.norm(old_a) < self.epsilon: + break + + residual = data - (dictionary@a.t()).t() + + if use_checknan: + self.checknan(a, "coefficients") + + coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) + return torch.squeeze(coefficients) diff --git a/sparsecoding/inference/vanilla_test.py b/sparsecoding/inference/vanilla_test.py new file mode 100644 index 0000000..57c2244 --- /dev/null +++ b/sparsecoding/inference/vanilla_test.py @@ -0,0 +1,35 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that Vanilla inference returns expected shapes. + """ + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.Vanilla(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + inference_method = inference.Vanilla(N_ITER, return_all_coefficients=True) + a = inference_method.infer(data, bars_dictionary_fixture) + assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) + +def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + """ + Test that Vanilla inference recovers the correct weights. + """ + LR = 5e-2 + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.Vanilla(coeff_lr=LR, n_iter=N_ITER) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/inference/common.py b/tests/inference/common.py deleted file mode 100644 index 8e47d3c..0000000 --- a/tests/inference/common.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch - -from sparsecoding.datasets import BarsDataset -from sparsecoding.priors import L0Prior, SpikeSlabPrior - -torch.manual_seed(1997) - -PATCH_SIZE = 8 -DATASET_SIZE = 1000 - -PRIORS = [ - SpikeSlabPrior( - dim=2 * PATCH_SIZE, - p_spike=0.8, - scale=1.0, - positive_only=True, - ), - L0Prior( - prob_distr=( - torch.nn.functional.one_hot( - torch.tensor(1), - num_classes=2 * PATCH_SIZE, - ).type(torch.float32) - ), - ), -] - -DATASET = [ - BarsDataset( - patch_size=PATCH_SIZE, - dataset_size=DATASET_SIZE, - prior=prior, - ) - for prior in PRIORS -] - -DATAS = [ - dataset.data.reshape((DATASET_SIZE, PATCH_SIZE * PATCH_SIZE)) - for dataset in DATASET -] -DICTIONARY = DATASET[0].basis.reshape((2 * PATCH_SIZE, PATCH_SIZE * PATCH_SIZE)).T diff --git a/tests/inference/test_ISTA.py b/tests/inference/test_ISTA.py deleted file mode 100644 index 0edb50e..0000000 --- a/tests/inference/test_ISTA.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import ( - DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE -) - - -class TestISTA(TestCase): - def test_shape(self): - """ - Test that ISTA inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.ISTA(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - inference_method = inference.ISTA(N_ITER, return_all_coefficients=True) - a = inference_method.infer(data, DICTIONARY) - self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)) - - def test_inference(self): - """ - Test that ISTA inference recovers the correct weights. - """ - N_ITER = 5000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.ISTA(n_iter=N_ITER) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_LCA.py b/tests/inference/test_LCA.py deleted file mode 100644 index 8628377..0000000 --- a/tests/inference/test_LCA.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import ( - DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE -) - - -class TestLCA(TestCase): - def test_shape(self): - """ - Test that LCA inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LCA(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - for retval in ["active", "membrane"]: - inference_method = inference.LCA(N_ITER, return_all_coefficients=retval) - a = inference_method.infer(data, DICTIONARY) - self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)) - - def test_inference(self): - """ - Test that LCA inference recovers the correct weights. - """ - LR = 5e-2 - THRESHOLD = 0.1 - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LCA( - coeff_lr=LR, - threshold=THRESHOLD, - n_iter=N_ITER, - ) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_LSM.py b/tests/inference/test_LSM.py deleted file mode 100644 index 0084daf..0000000 --- a/tests/inference/test_LSM.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import DATAS, DATASET, DICTIONARY - - -class TestLSM(TestCase): - def test_shape(self): - """ - Test that LSM inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LSM(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - def test_inference(self): - """ - Test that LSM inference recovers the correct weights. - """ - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LSM(n_iter=N_ITER) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_PyTorchOptimizer.py b/tests/inference/test_PyTorchOptimizer.py deleted file mode 100644 index 3c69885..0000000 --- a/tests/inference/test_PyTorchOptimizer.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import DATAS, DATASET, DICTIONARY - - -class TestPyTorchOptimizer(TestCase): - def lasso_loss(data, dictionary, coefficients, sparsity_penalty): - """ - Generic MSE + l1-norm loss. - """ - batch_size = data.shape[0] - datahat = (dictionary@coefficients.t()).t() - - mse_loss = torch.linalg.vector_norm(datahat-data, dim=1).square() - sparse_loss = torch.sum(torch.abs(coefficients), axis=1) - - total_loss = (mse_loss + sparsity_penalty*sparse_loss)/batch_size - return total_loss - - def loss_fn(data, dictionary, coefficients): - return TestPyTorchOptimizer.lasso_loss( - data, - dictionary, - coefficients, - sparsity_penalty=1., - ) - - def optimizer_fn(coefficients): - return torch.optim.Adam( - coefficients, - lr=0.1, - betas=(0.9, 0.999), - eps=1e-08, - weight_decay=0, - ) - - def test_shape(self): - """ - Test that PyTorchOptimizer inference returns expected shapes. - """ - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.PyTorchOptimizer( - TestPyTorchOptimizer.optimizer_fn, - TestPyTorchOptimizer.loss_fn, - n_iter=10, - ) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - def test_inference(self): - """ - Test that PyTorchOptimizer inference recovers the correct weights. - """ - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.PyTorchOptimizer( - TestPyTorchOptimizer.optimizer_fn, - TestPyTorchOptimizer.loss_fn, - n_iter=N_ITER, - ) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=1e-1, rtol=1e-1) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_Vanilla.py b/tests/inference/test_Vanilla.py deleted file mode 100644 index af9ae2f..0000000 --- a/tests/inference/test_Vanilla.py +++ /dev/null @@ -1,42 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import ( - DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE -) - - -class TestVanilla(TestCase): - def test_shape(self): - """ - Test that Vanilla inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.Vanilla(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - inference_method = inference.Vanilla(N_ITER, return_all_coefficients=True) - a = inference_method.infer(data, DICTIONARY) - self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)) - - def test_inference(self): - """ - Test that Vanilla inference recovers the correct weights. - """ - LR = 5e-2 - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.Vanilla(coeff_lr=LR, n_iter=N_ITER) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() From e75aa74492b8afe2d4b9a7f6accc7a1c2aac0c68 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:05:15 -0800 Subject: [PATCH 06/20] priors reorganize & tests --- sparsecoding/priors.py | 167 ------------------- sparsecoding/priors/__init__.py | 9 + sparsecoding/priors/l0_prior.py | 65 ++++++++ sparsecoding/priors/lo_prior_test.py | 35 ++++ sparsecoding/priors/prior.py | 34 ++++ sparsecoding/priors/spike_slab_prior.py | 72 ++++++++ sparsecoding/priors/spike_slab_prior_test.py | 54 ++++++ tests/priors/__init__.py | 0 tests/priors/test_l0.py | 41 ----- tests/priors/test_spike_slab.py | 59 ------- 10 files changed, 269 insertions(+), 267 deletions(-) delete mode 100644 sparsecoding/priors.py create mode 100644 sparsecoding/priors/__init__.py create mode 100644 sparsecoding/priors/l0_prior.py create mode 100644 sparsecoding/priors/lo_prior_test.py create mode 100644 sparsecoding/priors/prior.py create mode 100644 sparsecoding/priors/spike_slab_prior.py create mode 100644 sparsecoding/priors/spike_slab_prior_test.py delete mode 100644 tests/priors/__init__.py delete mode 100644 tests/priors/test_l0.py delete mode 100644 tests/priors/test_spike_slab.py diff --git a/sparsecoding/priors.py b/sparsecoding/priors.py deleted file mode 100644 index 026ac43..0000000 --- a/sparsecoding/priors.py +++ /dev/null @@ -1,167 +0,0 @@ -import torch -from torch.distributions.laplace import Laplace - -from abc import ABC, abstractmethod - - -class Prior(ABC): - """A distribution over weights. - - Parameters - ---------- - weights_dim : int - Number of weights for each sample. - """ - @abstractmethod - def D(self): - """ - Number of weights per sample. - """ - - @abstractmethod - def sample( - self, - num_samples: int = 1, - ): - """Sample weights from the prior. - - Parameters - ---------- - num_samples : int, default=1 - Number of samples. - - Returns - ------- - samples : Tensor, shape [num_samples, self.D] - Sampled weights. - """ - - -class SpikeSlabPrior(Prior): - """Prior where weights are drawn from a "spike-and-slab" distribution. - - The "spike" is at 0 and the "slab" is Laplacian. - - See: - https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf - for a good review of the spike-and-slab model. - - Parameters - ---------- - dim : int - Number of weights per sample. - p_spike : float - The probability of the weight being 0. - scale : float - The "scale" of the Laplacian distribution (larger is wider). - positive_only : bool - Ensure that the weights are positive by taking the absolute value - of weights sampled from the Laplacian. - """ - - def __init__( - self, - dim: int, - p_spike: float, - scale: float, - positive_only: bool = True, - ): - if dim < 0: - raise ValueError(f"`dim` should be nonnegative, got {dim}.") - if p_spike < 0 or p_spike > 1: - raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.") - if scale <= 0: - raise ValueError(f"`scale` must be positive, got {scale}.") - - self.dim = dim - self.p_spike = p_spike - self.scale = scale - self.positive_only = positive_only - - @property - def D(self): - return self.dim - - def sample(self, num_samples: int): - N = num_samples - - zero_weights = torch.zeros((N, self.D), dtype=torch.float32) - slab_weights = Laplace( - loc=zero_weights, - scale=torch.full((N, self.D), self.scale, dtype=torch.float32), - ).sample() # [N, D] - - if self.positive_only: - slab_weights = torch.abs(slab_weights) - - spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike - - weights = torch.where( - spike_over_slab, - zero_weights, - slab_weights, - ) - - return weights - - -class L0Prior(Prior): - """Prior with a distribution over the l0-norm of the weights. - - A class of priors where the weights are binary; - the distribution is over the l0-norm of the weight vector - (how many weights are active). - - Parameters - ---------- - prob_distr : Tensor, shape [D], dtype float32 - Probability distribution over the l0-norm of the weights. - """ - - def __init__( - self, - prob_distr: torch.Tensor, - ): - if prob_distr.dim() != 1: - raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") - if prob_distr.dtype != torch.float32: - raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") - if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)): - raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") - - self.prob_distr = prob_distr - - @property - def D(self): - return self.prob_distr.shape[0] - - def sample( - self, - num_samples: int - ): - N = num_samples - - num_active_weights = 1 + torch.multinomial( - input=self.prob_distr, - num_samples=num_samples, - replacement=True, - ) # [N] - - d_idxs = torch.arange(self.D) - active_idx_mask = ( - d_idxs.reshape(1, self.D) - < num_active_weights.reshape(N, 1) - ) # [N, self.D] - - n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] - # Need to shuffle here so that it's not always the first weights that are active. - shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] - shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] - - # [num_active_weights], [num_active_weights] - active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] - - weights = torch.zeros((N, self.D), dtype=torch.float32) - weights[active_weight_idxs] += 1. - - return weights diff --git a/sparsecoding/priors/__init__.py b/sparsecoding/priors/__init__.py new file mode 100644 index 0000000..94e9994 --- /dev/null +++ b/sparsecoding/priors/__init__.py @@ -0,0 +1,9 @@ +from .l0_prior import L0Prior +from .prior import Prior +from .spike_slab_prior import SpikeSlabPrior + +__all__ = [ + "Prior", + "L0Prior", + "SpikeSlabPrior", +] \ No newline at end of file diff --git a/sparsecoding/priors/l0_prior.py b/sparsecoding/priors/l0_prior.py new file mode 100644 index 0000000..1e41853 --- /dev/null +++ b/sparsecoding/priors/l0_prior.py @@ -0,0 +1,65 @@ +import torch + +from .prior import Prior + + +class L0Prior(Prior): + """Prior with a distribution over the l0-norm of the weights. + + A class of priors where the weights are binary; + the distribution is over the l0-norm of the weight vector + (how many weights are active). + + Parameters + ---------- + prob_distr : Tensor, shape [D], dtype float32 + Probability distribution over the l0-norm of the weights. + """ + + def __init__( + self, + prob_distr: torch.Tensor, + ): + if prob_distr.dim() != 1: + raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") + if prob_distr.dtype != torch.float32: + raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") + if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)): + raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") + + self.prob_distr = prob_distr + + @property + def D(self): + return self.prob_distr.shape[0] + + def sample( + self, + num_samples: int + ): + N = num_samples + + num_active_weights = 1 + torch.multinomial( + input=self.prob_distr, + num_samples=num_samples, + replacement=True, + ) # [N] + + d_idxs = torch.arange(self.D) + active_idx_mask = ( + d_idxs.reshape(1, self.D) + < num_active_weights.reshape(N, 1) + ) # [N, self.D] + + n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] + # Need to shuffle here so that it's not always the first weights that are active. + shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] + shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] + + # [num_active_weights], [num_active_weights] + active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] + + weights = torch.zeros((N, self.D), dtype=torch.float32) + weights[active_weight_idxs] += 1. + + return weights diff --git a/sparsecoding/priors/lo_prior_test.py b/sparsecoding/priors/lo_prior_test.py new file mode 100644 index 0000000..9dbb0e4 --- /dev/null +++ b/sparsecoding/priors/lo_prior_test.py @@ -0,0 +1,35 @@ +import torch + +from sparsecoding.priors import L0Prior + + +def test_l0_prior(): + N = 10000 + prob_distr = torch.tensor([0.5, 0.25, 0, 0.25]) + + torch.manual_seed(1997) + + D = prob_distr.shape[0] + + l0_prior = L0Prior(prob_distr) + weights = l0_prior.sample(N) + + assert weights.shape == (N, D) + + # Check uniform distribution over which weights are active. + per_weight_hist = torch.sum(weights, dim=0) # [D] + normalized_per_weight_hist = per_weight_hist / torch.sum(per_weight_hist) # [D] + assert torch.allclose( + normalized_per_weight_hist, + torch.full((D,), 0.25, dtype=torch.float32), + atol=1e-2, + ) + + # Check the distribution over the l0-norm of the weights. + num_active_per_sample = torch.sum(weights, dim=1) # [N] + for num_active in range(1, 5): + assert torch.allclose( + torch.sum(num_active_per_sample == num_active) / N, + prob_distr[num_active - 1], + atol=1e-2, + ) diff --git a/sparsecoding/priors/prior.py b/sparsecoding/priors/prior.py new file mode 100644 index 0000000..4940661 --- /dev/null +++ b/sparsecoding/priors/prior.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod + + +class Prior(ABC): + """A distribution over weights. + + Parameters + ---------- + weights_dim : int + Number of weights for each sample. + """ + @abstractmethod + def D(self): + """ + Number of weights per sample. + """ + + @abstractmethod + def sample( + self, + num_samples: int = 1, + ): + """Sample weights from the prior. + + Parameters + ---------- + num_samples : int, default=1 + Number of samples. + + Returns + ------- + samples : Tensor, shape [num_samples, self.D] + Sampled weights. + """ diff --git a/sparsecoding/priors/spike_slab_prior.py b/sparsecoding/priors/spike_slab_prior.py new file mode 100644 index 0000000..1fa059a --- /dev/null +++ b/sparsecoding/priors/spike_slab_prior.py @@ -0,0 +1,72 @@ +import torch +from torch.distributions.laplace import Laplace + +from .prior import Prior + + +class SpikeSlabPrior(Prior): + """Prior where weights are drawn from a "spike-and-slab" distribution. + + The "spike" is at 0 and the "slab" is Laplacian. + + See: + https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf + for a good review of the spike-and-slab model. + + Parameters + ---------- + dim : int + Number of weights per sample. + p_spike : float + The probability of the weight being 0. + scale : float + The "scale" of the Laplacian distribution (larger is wider). + positive_only : bool + Ensure that the weights are positive by taking the absolute value + of weights sampled from the Laplacian. + """ + + def __init__( + self, + dim: int, + p_spike: float, + scale: float, + positive_only: bool = True, + ): + if dim < 0: + raise ValueError(f"`dim` should be nonnegative, got {dim}.") + if p_spike < 0 or p_spike > 1: + raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.") + if scale <= 0: + raise ValueError(f"`scale` must be positive, got {scale}.") + + self.dim = dim + self.p_spike = p_spike + self.scale = scale + self.positive_only = positive_only + + @property + def D(self): + return self.dim + + def sample(self, num_samples: int): + N = num_samples + + zero_weights = torch.zeros((N, self.D), dtype=torch.float32) + slab_weights = Laplace( + loc=zero_weights, + scale=torch.full((N, self.D), self.scale, dtype=torch.float32), + ).sample() # [N, D] + + if self.positive_only: + slab_weights = torch.abs(slab_weights) + + spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike + + weights = torch.where( + spike_over_slab, + zero_weights, + slab_weights, + ) + + return weights diff --git a/sparsecoding/priors/spike_slab_prior_test.py b/sparsecoding/priors/spike_slab_prior_test.py new file mode 100644 index 0000000..f1023ac --- /dev/null +++ b/sparsecoding/priors/spike_slab_prior_test.py @@ -0,0 +1,54 @@ +import pytest +import torch + +from sparsecoding.priors import SpikeSlabPrior + + +@pytest.mark.parametrize("positive_only", [True, False]) +def test_spike_slab_prior(positive_only: bool): + N = 10000 + D = 4 + p_spike = 0.5 + scale = 1. + + torch.manual_seed(1997) + + p_slab = 1. - p_spike + + spike_slab_prior = SpikeSlabPrior( + D, + p_spike, + scale, + positive_only, + ) + weights = spike_slab_prior.sample(N) + + assert weights.shape == (N, D) + + # Check spike probability. + assert torch.allclose( + torch.sum(weights == 0.) / (N * D), + torch.tensor(p_spike), + atol=1e-2, + ) + + # Check Laplacian distribution. + N_slab = p_slab * N * D + if positive_only: + assert torch.sum(weights < 0.) == 0 + else: + assert torch.allclose( + torch.sum(weights < 0.) / N_slab, + torch.sum(weights > 0.) / N_slab, + atol=2e-2, + ) + weights = torch.abs(weights) + + laplace_weights = weights[weights > 0.] + for quantile in torch.arange(5) / 5.: + cutoff = -torch.log(1. - quantile) + assert torch.allclose( + torch.sum(laplace_weights < cutoff) / N_slab, + quantile, + atol=1e-2, + ) \ No newline at end of file diff --git a/tests/priors/__init__.py b/tests/priors/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/priors/test_l0.py b/tests/priors/test_l0.py deleted file mode 100644 index 12dcce6..0000000 --- a/tests/priors/test_l0.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -import unittest - -from sparsecoding.priors import L0Prior - - -class TestL0Prior(unittest.TestCase): - def test_l0_prior(self): - N = 10000 - prob_distr = torch.tensor([0.5, 0.25, 0, 0.25]) - - torch.manual_seed(1997) - - D = prob_distr.shape[0] - - l0_prior = L0Prior(prob_distr) - weights = l0_prior.sample(N) - - assert weights.shape == (N, D) - - # Check uniform distribution over which weights are active. - per_weight_hist = torch.sum(weights, dim=0) # [D] - normalized_per_weight_hist = per_weight_hist / torch.sum(per_weight_hist) # [D] - assert torch.allclose( - normalized_per_weight_hist, - torch.full((D,), 0.25, dtype=torch.float32), - atol=1e-2, - ) - - # Check the distribution over the l0-norm of the weights. - num_active_per_sample = torch.sum(weights, dim=1) # [N] - for num_active in range(1, 5): - assert torch.allclose( - torch.sum(num_active_per_sample == num_active) / N, - prob_distr[num_active - 1], - atol=1e-2, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/priors/test_spike_slab.py b/tests/priors/test_spike_slab.py deleted file mode 100644 index 20804fe..0000000 --- a/tests/priors/test_spike_slab.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import unittest - -from sparsecoding.priors import SpikeSlabPrior - - -class TestSpikeSlabPrior(unittest.TestCase): - def test_spike_slab_prior(self): - N = 10000 - D = 4 - p_spike = 0.5 - scale = 1. - - torch.manual_seed(1997) - - p_slab = 1. - p_spike - - for positive_only in [True, False]: - spike_slab_prior = SpikeSlabPrior( - D, - p_spike, - scale, - positive_only, - ) - weights = spike_slab_prior.sample(N) - - assert weights.shape == (N, D) - - # Check spike probability. - assert torch.allclose( - torch.sum(weights == 0.) / (N * D), - torch.tensor(p_spike), - atol=1e-2, - ) - - # Check Laplacian distribution. - N_slab = p_slab * N * D - if positive_only: - assert torch.sum(weights < 0.) == 0 - else: - assert torch.allclose( - torch.sum(weights < 0.) / N_slab, - torch.sum(weights > 0.) / N_slab, - atol=2e-2, - ) - weights = torch.abs(weights) - - laplace_weights = weights[weights > 0.] - for quantile in torch.arange(5) / 5.: - cutoff = -torch.log(1. - quantile) - assert torch.allclose( - torch.sum(laplace_weights < cutoff) / N_slab, - quantile, - atol=1e-2, - ) - - -if __name__ == "__main__": - unittest.main() From 4bc37646b5fb01a7427b654624ca230a9eb6d4b4 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:05:28 -0800 Subject: [PATCH 07/20] not using unittest anymore --- tests/__init__.py | 0 tests/test_test.py | 35 ----------------------------------- tests/testing_utilities.py | 17 ----------------- 3 files changed, 52 deletions(-) delete mode 100644 tests/__init__.py delete mode 100644 tests/test_test.py delete mode 100644 tests/testing_utilities.py diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_test.py b/tests/test_test.py deleted file mode 100644 index ff3e42e..0000000 --- a/tests/test_test.py +++ /dev/null @@ -1,35 +0,0 @@ -from tests.testing_utilities import TestCase -import torch -import numpy as np - - -class TestTestCaseBaseClass(TestCase): - - def test_pytorch_all_close(self): - result = torch.ones([10, 10]) + 1e-10 - expected = torch.ones([10, 10]) - self.assertAllClose(result, expected) - - def test_np_all_close(self): - result = np.ones([100, 100]) + 1e-10 - expected = np.ones([100, 100]) - self.assertAllClose(result, expected) - - def test_assert_true(self): - self.assertTrue(True) - - def test_assert_false(self): - self.assertFalse(False) - - def test_assert_equal(self): - self.assertEqual('sparse coding', 'sparse coding') - - def test_assert_pytorch_shape_equal(self): - a = torch.zeros([10, 10]) - b = torch.ones([10, 10]) - self.assertShapeEqual(a, b) - - def test_assert_np_shape_equal(self): - a = np.zeros([100, 100]) - b = np.ones([100, 100]) - self.assertShapeEqual(a, b) diff --git a/tests/testing_utilities.py b/tests/testing_utilities.py deleted file mode 100644 index 8a3de74..0000000 --- a/tests/testing_utilities.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np -import unittest - - -# constants -default_atol = 1e-6 -default_rtol = 1e-5 - - -class TestCase(unittest.TestCase): - '''Base class for testing''' - - def assertAllClose(self, a, b, rtol=default_rtol, atol=default_atol): - return np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) - - def assertShapeEqual(self, a, b): - assert a.shape == b.shape From bc1e655ffcc04f7dd55dfff8c1473575533a217a Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:09:00 -0800 Subject: [PATCH 08/20] update CI --- .github/workflows/tests.yml | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5fb44d8..225983a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,26 +2,30 @@ name: "Testing" on: push: - paths-ignore: - - 'tutorials/**' - - 'README.md' - - '.gitignore' - - 'examples/**' + branches: + - main + pull_request: jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python all python version - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 architecture: x64 - name: Install dependencies - run: pip install -r requirements.txt + run: | + python -m venv --upgrade-deps .venv + source .venv/bin/activate + pip install -r requirements.txt + pip install -r requirements-dev.txt - name: Run Test - run: python -m unittest discover tests -v + run: | + source .venv/bin/activate + python -m pytest . From 7547e7e96973c92bcc00b0a69b7b723db143c8f1 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:17:18 -0800 Subject: [PATCH 09/20] fixup contributing file --- docs/contributing.md | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/docs/contributing.md b/docs/contributing.md index 804bdba..5a49a47 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,35 +1,41 @@ # Contributing + All contributions are welcome! ## Bug Reporting -If you find a bug, submit a bug report on GitHub Issues. + +If you find a bug, submit a bug report on GitHub Issues. ## Adding Features/Fixing Bugs + If you have identified a new feature or bug that you can fix yourself, please follow the following procedure. -1. Clone `main` branch. -2. Create a new branch to contain your changes. -2. `add`, `commit`, and `push` your changes to this branch. -3. Create a pull request (PR). See more information on submitting a PR request below. +1. Fork `main` branch. +2. Create a new branch to contain your changes. +3. `add`, `commit`, and `push` your changes to this branch. +4. Create a pull request (PR). See more information on submitting a PR request below. ### Submitting a Pull Request -1. If necessary, please **write your own unit tests** and add them to [the tests directory](https://github.com/rctn/sparsecoding/blob/main/docs/contributing.md). -2. Verify that all tests are passed by running `python -m unittest tests/*`. -3. Be sure that your PR follows formatting guidelines, [PEP8](https://peps.python.org/pep-0008/) and [flake8](https://flake8.pycqa.org/en/latest/). -4. Make sure the title of your PR summarizes the features/issues resolved in your branch. -5. Submit your pull request and add reviewers. + +1. If necessary, please **write your own unit tests** and place them near the code being tested. High-level tests, such as integration or example tests can be placed in the top-level "tests" folder. +2. Verify that all tests are passed by running `python -m pytest .`. +3. Be sure that your PR follows formatting guidelines, [PEP8](https://peps.python.org/pep-0008/) and [flake8](https://flake8.pycqa.org/en/latest/). +4. Make sure the title of your PR summarizes the features/issues resolved in your branch. +5. Submit your pull request and add reviewers. ## Coding Style Guidelines -The following are some guidelines on how new code should be written. Of course, there are special cases, and there will be exceptions to these rules. + +The following are some guidelines on how new code should be written. Of course, there are special cases, and there will be exceptions to these rules. 1. Format code in accordance with [flake8](https://flake8.pycqa.org/en/latest/) standard. 2. Use underscores to separate words in non-class names: `n_samples` rather than `nsamples`. 3. Avoid single-character variable names. ## Docstrings + When writing docstrings, please follow the following example. -``` +```py def count_beans(self, baz, use_gpu=False, foo="vector" bar=None): """Write a one-line summary for the method. @@ -69,4 +75,4 @@ def count_beans(self, baz, use_gpu=False, foo="vector" quantum-mechanical description of physical reality be considered complete?. Physical review, 47(10), 777. """ -``` \ No newline at end of file +``` From 6678a5e2131673b377e6af108499004a8625f0e9 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:27:54 -0800 Subject: [PATCH 10/20] formatting --- conftest.py | 25 ++++++++--- sparsecoding/inference/__init__.py | 20 ++++----- sparsecoding/inference/iht.py | 14 +++--- sparsecoding/inference/inference_method.py | 2 +- sparsecoding/inference/ista.py | 29 +++++++----- sparsecoding/inference/ista_test.py | 15 ++++++- sparsecoding/inference/lca.py | 45 ++++++++++++------- sparsecoding/inference/lca_test.py | 17 +++++-- sparsecoding/inference/lsm.py | 29 ++++++------ sparsecoding/inference/lsm_test.py | 15 ++++++- sparsecoding/inference/mp.py | 9 ++-- sparsecoding/inference/omp.py | 9 ++-- .../inference/pytorch_optimizer_test.py | 26 ++++++++--- sparsecoding/inference/vanilla.py | 26 ++++++----- sparsecoding/inference/vanilla_test.py | 15 ++++++- sparsecoding/priors/__init__.py | 8 ++-- sparsecoding/priors/l0_prior.py | 12 ++--- sparsecoding/priors/prior.py | 1 + sparsecoding/priors/spike_slab_prior_test.py | 20 ++++----- sparsecoding/test_utils/__init__.py | 11 ++++- sparsecoding/test_utils/asserts.py | 4 +- sparsecoding/test_utils/asserts_test.py | 4 +- sparsecoding/test_utils/constant_fixtures.py | 6 ++- sparsecoding/test_utils/dataset_fixtures.py | 12 +++-- sparsecoding/test_utils/model_fixtures.py | 6 +-- sparsecoding/test_utils/prior_fixtures.py | 2 +- 26 files changed, 245 insertions(+), 137 deletions(-) diff --git a/conftest.py b/conftest.py index 18ac5fd..31b4f34 100644 --- a/conftest.py +++ b/conftest.py @@ -1,9 +1,24 @@ -from sparsecoding.test_utils import (bars_datas_fixture, bars_datasets_fixture, - bars_dictionary_fixture, - dataset_size_fixture, patch_size_fixture, - priors_fixture) +import torch + +from sparsecoding.test_utils import ( + bars_datas_fixture, + bars_datasets_fixture, + bars_dictionary_fixture, + dataset_size_fixture, + patch_size_fixture, + priors_fixture, +) + +torch.manual_seed(1997) # We import and define all fixtures in this file. # This allows users to avoid any dependency fixtures. # NOTE: This means pytest should only be run from this directory. -__all__ = ['dataset_size_fixture', 'patch_size_fixture', 'bars_datas_fixture', 'bars_datasets_fixture', 'bars_dictionary_fixture', 'priors_fixture'] +__all__ = [ + "dataset_size_fixture", + "patch_size_fixture", + "bars_datas_fixture", + "bars_datasets_fixture", + "bars_dictionary_fixture", + "priors_fixture", +] diff --git a/sparsecoding/inference/__init__.py b/sparsecoding/inference/__init__.py index 0cfa46d..8ac1dec 100644 --- a/sparsecoding/inference/__init__.py +++ b/sparsecoding/inference/__init__.py @@ -9,13 +9,13 @@ from .vanilla import Vanilla __all__ = [ - 'IHT', - 'InferenceMethod', - 'ISTA', - 'LCA', - 'LSM', - 'MP', - 'OMP', - 'PyTorchOptimizer', - 'Vanilla' -] \ No newline at end of file + "IHT", + "InferenceMethod", + "ISTA", + "LCA", + "LSM", + "MP", + "OMP", + "PyTorchOptimizer", + "Vanilla", +] diff --git a/sparsecoding/inference/iht.py b/sparsecoding/inference/iht.py index e62dacd..467cfc5 100644 --- a/sparsecoding/inference/iht.py +++ b/sparsecoding/inference/iht.py @@ -12,7 +12,7 @@ class IHT(InferenceMethod): """ def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False): - ''' + """ Parameters ---------- @@ -27,7 +27,7 @@ def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=Fal can result in large memory usage/potential exhaustion. This function typically used for debugging solver : default=None - ''' + """ super().__init__(solver) self.n_iter = n_iter self.sparsity = sparsity @@ -54,11 +54,10 @@ def infer(self, data, dictionary): device = dictionary.device # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) + K = np.ceil(self.sparsity * n_basis).astype(int) # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) for _ in range(self.n_iter): # Compute the prediction given the current coefficients @@ -75,9 +74,8 @@ def infer(self, data, dictionary): topK_values, indices = torch.topk(torch.abs(coefficients), K, dim=1) # Reconstruct coefficients using the output of torch.topk - coefficients = ( - torch.sign(coefficients) - * torch.zeros(batch_size, n_basis, device=device).scatter_(1, indices, topK_values) + coefficients = torch.sign(coefficients) * torch.zeros(batch_size, n_basis, device=device).scatter_( + 1, indices, topK_values ) return coefficients.detach() diff --git a/sparsecoding/inference/inference_method.py b/sparsecoding/inference/inference_method.py index 98ff0b8..62c97ec 100644 --- a/sparsecoding/inference/inference_method.py +++ b/sparsecoding/inference/inference_method.py @@ -70,4 +70,4 @@ def checknan(data=torch.tensor(0), name="data"): If the nan found in data """ if torch.isnan(data).any(): - raise ValueError("InferenceMethod error: nan in %s." % (name)) \ No newline at end of file + raise ValueError("InferenceMethod error: nan in %s." % (name)) diff --git a/sparsecoding/inference/ista.py b/sparsecoding/inference/ista.py index 5c1abfc..42bd21b 100644 --- a/sparsecoding/inference/ista.py +++ b/sparsecoding/inference/ista.py @@ -4,8 +4,15 @@ class ISTA(InferenceMethod): - def __init__(self, n_iter=100, sparsity_penalty=1e-2, stop_early=False, - epsilon=1e-2, solver=None, return_all_coefficients=False): + def __init__( + self, + n_iter=100, + sparsity_penalty=1e-2, + stop_early=False, + epsilon=1e-2, + solver=None, + return_all_coefficients=False, + ): """Iterative shrinkage-thresholding algorithm for solving LASSO problems. Parameters @@ -52,8 +59,8 @@ def threshold_nonlinearity(self, u): a : array-like, shape [batch_size, n_basis] activations """ - a = (torch.abs(u) - self.threshold).clamp(min=0.) - a = torch.sign(u)*a + a = (torch.abs(u) - self.threshold).clamp(min=0.0) + a = torch.sign(u) * a return a def infer(self, data, dictionary, coeff_0=None, use_checknan=False): @@ -85,9 +92,8 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False): # Calculate stepsize based on largest eigenvalue of # dictionary.T @ dictionary. - lipschitz_constant = torch.linalg.eigvalsh( - torch.mm(dictionary.T, dictionary))[-1] - stepsize = 1. / lipschitz_constant + lipschitz_constant = torch.linalg.eigvalsh(torch.mm(dictionary.T, dictionary))[-1] + stepsize = 1.0 / lipschitz_constant self.threshold = stepsize * self.sparsity_penalty # Initialize coefficients. @@ -104,16 +110,17 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False): old_u = u.clone().detach() if self.return_all_coefficients: - coefficients = torch.concat([coefficients, - self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) + coefficients = torch.concat( + [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], + dim=1, + ) u -= stepsize * torch.mm(residual, dictionary) self.coefficients = self.threshold_nonlinearity(u) if self.stop_early: # Stopping condition is function of change of the coefficients. - a_change = torch.mean( - torch.abs(old_u - u) / stepsize) + a_change = torch.mean(torch.abs(old_u - u) / stepsize) if a_change < self.epsilon: break diff --git a/sparsecoding/inference/ista_test.py b/sparsecoding/inference/ista_test.py index 6a99f5a..397c555 100644 --- a/sparsecoding/inference/ista_test.py +++ b/sparsecoding/inference/ista_test.py @@ -5,7 +5,13 @@ from sparsecoding.test_utils import assert_allclose, assert_shape_equal -def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """Test that ISTA inference returns expected shapes.""" N_ITER = 10 @@ -18,7 +24,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona a = inference_method.infer(data, bars_dictionary_fixture) assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) -def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """Test that ISTA inference recovers the correct weights.""" N_ITER = 5000 for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): diff --git a/sparsecoding/inference/lca.py b/sparsecoding/inference/lca.py index 49ef1a5..ad5023f 100644 --- a/sparsecoding/inference/lca.py +++ b/sparsecoding/inference/lca.py @@ -4,9 +4,17 @@ class LCA(InferenceMethod): - def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1, - stop_early=False, epsilon=1e-2, solver=None, - return_all_coefficients="none", nonnegative=False): + def __init__( + self, + n_iter=100, + coeff_lr=1e-3, + threshold=0.1, + stop_early=False, + epsilon=1e-2, + solver=None, + return_all_coefficients="none", + nonnegative=False, + ): """Method implemented according locally competative algorithm (LCA) with the ideal soft thresholding function. @@ -48,8 +56,9 @@ def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1, self.n_iter = n_iter self.nonnegative = nonnegative if return_all_coefficients not in ["none", "membrane", "active"]: - raise ValueError("Invalid input for return_all_coefficients. Valid" - "inputs are: \"none\", \"membrane\", \"active\".") + raise ValueError( + "Invalid input for return_all_coefficients. Valid" 'inputs are: "none", "membrane", "active".' + ) self.return_all_coefficients = return_all_coefficients def threshold_nonlinearity(self, u): @@ -66,10 +75,10 @@ def threshold_nonlinearity(self, u): Activations """ if self.nonnegative: - a = (u - self.threshold).clamp(min=0.) + a = (u - self.threshold).clamp(min=0.0) else: - a = (torch.abs(u) - self.threshold).clamp(min=0.) - a = torch.sign(u)*a + a = (torch.abs(u) - self.threshold).clamp(min=0.0) + a = torch.sign(u) * a return a def grad(self, b, G, u, a): @@ -89,7 +98,7 @@ def grad(self, b, G, u, a): du : array-like, shape [batch_size, n_coefficients] Gradient of membrane potentials """ - du = b-u-(G@a.t()).t() + du = b - u - (G @ a.t()).t() return du def infer(self, data, dictionary, coeff_0=None, use_checknan=False): @@ -127,8 +136,8 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False): coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - b = (dictionary.t()@data.t()).t() - G = dictionary.t()@dictionary-torch.eye(n_basis).to(device) + b = (dictionary.t() @ data.t()).t() + G = dictionary.t() @ dictionary - torch.eye(n_basis).to(device) for i in range(self.n_iter): # store old membrane potentials to evalute stop early condition if self.stop_early: @@ -138,19 +147,23 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False): if self.return_all_coefficients != "none": if self.return_all_coefficients == "active": coefficients = torch.concat( - [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) + [ + coefficients, + self.threshold_nonlinearity(u).clone().unsqueeze(1), + ], + dim=1, + ) else: - coefficients = torch.concat( - [coefficients, u.clone().unsqueeze(1)], dim=1) + coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1) # compute new a = self.threshold_nonlinearity(u) du = self.grad(b, G, u, a) - u = u + self.coeff_lr*du + u = u + self.coeff_lr * du # check stopping condition if self.stop_early: - relative_change_in_coeff = torch.linalg.norm(old_u - u)/torch.linalg.norm(old_u) + relative_change_in_coeff = torch.linalg.norm(old_u - u) / torch.linalg.norm(old_u) if relative_change_in_coeff < self.epsilon: break diff --git a/sparsecoding/inference/lca_test.py b/sparsecoding/inference/lca_test.py index 2fe955f..0834ddd 100644 --- a/sparsecoding/inference/lca_test.py +++ b/sparsecoding/inference/lca_test.py @@ -5,7 +5,13 @@ from sparsecoding.test_utils import assert_allclose, assert_shape_equal -def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that LCA inference returns expected shapes. """ @@ -21,7 +27,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona a = inference_method.infer(data, bars_dictionary_fixture) assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) -def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that LCA inference recovers the correct weights. """ @@ -38,4 +49,4 @@ def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: li a = inference_method.infer(data, bars_dictionary_fixture) - assert_allclose(a, dataset.weights, atol=5e-2) \ No newline at end of file + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/sparsecoding/inference/lsm.py b/sparsecoding/inference/lsm.py index c3faaf6..83e5ed5 100644 --- a/sparsecoding/inference/lsm.py +++ b/sparsecoding/inference/lsm.py @@ -4,9 +4,17 @@ class LSM(InferenceMethod): - def __init__(self, n_iter=100, n_iter_LSM=6, beta=0.01, alpha=80.0, - sigma=0.005, sparse_threshold=10**-2, solver=None, - return_all_coefficients=False): + def __init__( + self, + n_iter=100, + n_iter_LSM=6, + beta=0.01, + alpha=80.0, + sigma=0.005, + sparse_threshold=10**-2, + solver=None, + return_all_coefficients=False, + ): """Infer latent coefficients generating data given dictionary. Method implemented according to "Group Sparse Coding with a Laplacian Scale Mixture Prior" (P. J. Garrigues & B. A. Olshausen, 2010) @@ -73,7 +81,7 @@ def lsm_Loss(self, data, dictionary, coefficients, lambdas, sigma): # Compute loss preds = torch.mm(dictionary, coefficients.t()).t() - mse_loss = (1/(2*(sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) + mse_loss = (1 / (2 * (sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) sparse_loss = torch.sum(lambdas * torch.abs(coefficients), dim=1, keepdim=True) loss = mse_loss + sparse_loss return loss @@ -107,10 +115,7 @@ def infer(self, data, dictionary): # Outer loop, set sparsity penalties (lambdas). for i in range(self.n_iter_LSM): # Compute the initial values of lambdas - lambdas = ( - (self.alpha + 1) - / (self.beta + torch.abs(coefficients.detach())) - ) + lambdas = (self.alpha + 1) / (self.beta + torch.abs(coefficients.detach())) # Inner loop, optimize coefficients w/ current sparsity penalties. # Exits early if converged before `n_iter`s. @@ -132,16 +137,12 @@ def infer(self, data, dictionary): optimizer.step() # Break if coefficients have converged. - if ( - last_loss is not None - and loss > 1.05 * last_loss - ): + if last_loss is not None and loss > 1.05 * last_loss: break last_loss = loss # Sparsify the final solution by discarding the small coefficients - coefficients.data[torch.abs(coefficients.data) - < self.sparse_threshold] = 0 + coefficients.data[torch.abs(coefficients.data) < self.sparse_threshold] = 0 return coefficients.detach() diff --git a/sparsecoding/inference/lsm_test.py b/sparsecoding/inference/lsm_test.py index b7e4b05..6ac2155 100644 --- a/sparsecoding/inference/lsm_test.py +++ b/sparsecoding/inference/lsm_test.py @@ -5,7 +5,13 @@ from sparsecoding.test_utils import assert_allclose, assert_shape_equal -def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that LSM inference returns expected shapes. """ @@ -16,7 +22,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) -def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that LSM inference recovers the correct weights. """ diff --git a/sparsecoding/inference/mp.py b/sparsecoding/inference/mp.py index 2d413fc..4305976 100644 --- a/sparsecoding/inference/mp.py +++ b/sparsecoding/inference/mp.py @@ -12,7 +12,7 @@ class MP(InferenceMethod): """ def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' + """ Parameters ---------- @@ -24,7 +24,7 @@ def __init__(self, sparsity, solver=None, return_all_coefficients=False): can result in large memory usage/potential exhaustion. This function typically used for debugging solver : default=None - ''' + """ super().__init__(solver) self.sparsity = sparsity self.return_all_coefficients = return_all_coefficients @@ -50,14 +50,13 @@ def infer(self, data, dictionary): device = dictionary.device # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) + K = np.ceil(self.sparsity * n_basis).astype(int) # Get dictionary norms in case atoms are not normalized dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) residual = data.clone() # [batch_size, n_features] diff --git a/sparsecoding/inference/omp.py b/sparsecoding/inference/omp.py index 33c49d7..db99427 100644 --- a/sparsecoding/inference/omp.py +++ b/sparsecoding/inference/omp.py @@ -13,7 +13,7 @@ class OMP(InferenceMethod): """ def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' + """ Parameters ---------- @@ -25,7 +25,7 @@ def __init__(self, sparsity, solver=None, return_all_coefficients=False): can result in large memory usage/potential exhaustion. This function typically used for debugging solver : default=None - ''' + """ super().__init__(solver) self.sparsity = sparsity self.return_all_coefficients = return_all_coefficients @@ -51,14 +51,13 @@ def infer(self, data, dictionary): device = dictionary.device # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) + K = np.ceil(self.sparsity * n_basis).astype(int) # Get dictionary norms in case atoms are not normalized dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) residual = data.clone() # [batch_size, n_features] diff --git a/sparsecoding/inference/pytorch_optimizer_test.py b/sparsecoding/inference/pytorch_optimizer_test.py index f90fd66..e7322bc 100644 --- a/sparsecoding/inference/pytorch_optimizer_test.py +++ b/sparsecoding/inference/pytorch_optimizer_test.py @@ -10,22 +10,24 @@ def lasso_loss(data, dictionary, coefficients, sparsity_penalty): Generic MSE + l1-norm loss. """ batch_size = data.shape[0] - datahat = (dictionary@coefficients.t()).t() + datahat = (dictionary @ coefficients.t()).t() - mse_loss = torch.linalg.vector_norm(datahat-data, dim=1).square() + mse_loss = torch.linalg.vector_norm(datahat - data, dim=1).square() sparse_loss = torch.sum(torch.abs(coefficients), axis=1) - total_loss = (mse_loss + sparsity_penalty*sparse_loss)/batch_size + total_loss = (mse_loss + sparsity_penalty * sparse_loss) / batch_size return total_loss + def loss_fn(data, dictionary, coefficients): return lasso_loss( data, dictionary, coefficients, - sparsity_penalty=1., + sparsity_penalty=1.0, ) + def optimizer_fn(coefficients): return torch.optim.Adam( coefficients, @@ -35,7 +37,14 @@ def optimizer_fn(coefficients): weight_decay=0, ) -def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that PyTorchOptimizer inference returns expected shapes. """ @@ -48,7 +57,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona a = inference_method.infer(data, bars_dictionary_fixture) assert_shape_equal(a, dataset.weights) -def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that PyTorchOptimizer inference recovers the correct weights. """ diff --git a/sparsecoding/inference/vanilla.py b/sparsecoding/inference/vanilla.py index 473fcb8..11bbfa4 100644 --- a/sparsecoding/inference/vanilla.py +++ b/sparsecoding/inference/vanilla.py @@ -4,9 +4,16 @@ class Vanilla(InferenceMethod): - def __init__(self, n_iter=100, coeff_lr=1e-3, sparsity_penalty=0.2, - stop_early=False, epsilon=1e-2, solver=None, - return_all_coefficients=False): + def __init__( + self, + n_iter=100, + coeff_lr=1e-3, + sparsity_penalty=0.2, + stop_early=False, + epsilon=1e-2, + solver=None, + return_all_coefficients=False, + ): """Gradient descent with Euler's method on model in Olshausen & Field (1997) with laplace prior over coefficients (corresponding to l-1 norm penalty). @@ -61,8 +68,7 @@ def grad(self, residual, dictionary, a): da : array-like, shape [batch_size, n_coefficients] Gradient of membrane potentials """ - da = (dictionary.t()@residual.t()).t() - \ - self.sparsity_penalty*torch.sign(a) + da = (dictionary.t() @ residual.t()).t() - self.sparsity_penalty * torch.sign(a) return da def infer(self, data, dictionary, coeff_0=None, use_checknan=False): @@ -96,11 +102,11 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False): if coeff_0 is not None: a = coeff_0.to(device) else: - a = torch.rand((batch_size, n_basis)).to(device)-0.5 + a = torch.rand((batch_size, n_basis)).to(device) - 0.5 coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - residual = data - (dictionary@a.t()).t() + residual = data - (dictionary @ a.t()).t() for i in range(self.n_iter): if self.return_all_coefficients: @@ -110,13 +116,13 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False): old_a = a.clone().detach() da = self.grad(residual, dictionary, a) - a = a + self.coeff_lr*da + a = a + self.coeff_lr * da if self.stop_early: - if torch.linalg.norm(old_a - a)/torch.linalg.norm(old_a) < self.epsilon: + if torch.linalg.norm(old_a - a) / torch.linalg.norm(old_a) < self.epsilon: break - residual = data - (dictionary@a.t()).t() + residual = data - (dictionary @ a.t()).t() if use_checknan: self.checknan(a, "coefficients") diff --git a/sparsecoding/inference/vanilla_test.py b/sparsecoding/inference/vanilla_test.py index 57c2244..9c556e5 100644 --- a/sparsecoding/inference/vanilla_test.py +++ b/sparsecoding/inference/vanilla_test.py @@ -5,7 +5,13 @@ from sparsecoding.test_utils import assert_allclose, assert_shape_equal -def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that Vanilla inference returns expected shapes. """ @@ -20,7 +26,12 @@ def test_shape(patch_size_fixture: int, dataset_size_fixture: int, bars_dictiona a = inference_method.infer(data, bars_dictionary_fixture) assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) -def test_inference(bars_dictionary_fixture: torch.Tensor, bars_datas_fixture: list[torch.Tensor], bars_datasets_fixture: list[BarsDataset]): + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): """ Test that Vanilla inference recovers the correct weights. """ diff --git a/sparsecoding/priors/__init__.py b/sparsecoding/priors/__init__.py index 94e9994..75314d7 100644 --- a/sparsecoding/priors/__init__.py +++ b/sparsecoding/priors/__init__.py @@ -3,7 +3,7 @@ from .spike_slab_prior import SpikeSlabPrior __all__ = [ - "Prior", - "L0Prior", - "SpikeSlabPrior", -] \ No newline at end of file + "Prior", + "L0Prior", + "SpikeSlabPrior", +] diff --git a/sparsecoding/priors/l0_prior.py b/sparsecoding/priors/l0_prior.py index 1e41853..590bcad 100644 --- a/sparsecoding/priors/l0_prior.py +++ b/sparsecoding/priors/l0_prior.py @@ -33,10 +33,7 @@ def __init__( def D(self): return self.prob_distr.shape[0] - def sample( - self, - num_samples: int - ): + def sample(self, num_samples: int): N = num_samples num_active_weights = 1 + torch.multinomial( @@ -46,10 +43,7 @@ def sample( ) # [N] d_idxs = torch.arange(self.D) - active_idx_mask = ( - d_idxs.reshape(1, self.D) - < num_active_weights.reshape(N, 1) - ) # [N, self.D] + active_idx_mask = d_idxs.reshape(1, self.D) < num_active_weights.reshape(N, 1) # [N, self.D] n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] # Need to shuffle here so that it's not always the first weights that are active. @@ -60,6 +54,6 @@ def sample( active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] weights = torch.zeros((N, self.D), dtype=torch.float32) - weights[active_weight_idxs] += 1. + weights[active_weight_idxs] += 1.0 return weights diff --git a/sparsecoding/priors/prior.py b/sparsecoding/priors/prior.py index 4940661..061c62f 100644 --- a/sparsecoding/priors/prior.py +++ b/sparsecoding/priors/prior.py @@ -9,6 +9,7 @@ class Prior(ABC): weights_dim : int Number of weights for each sample. """ + @abstractmethod def D(self): """ diff --git a/sparsecoding/priors/spike_slab_prior_test.py b/sparsecoding/priors/spike_slab_prior_test.py index f1023ac..ed601bc 100644 --- a/sparsecoding/priors/spike_slab_prior_test.py +++ b/sparsecoding/priors/spike_slab_prior_test.py @@ -9,11 +9,11 @@ def test_spike_slab_prior(positive_only: bool): N = 10000 D = 4 p_spike = 0.5 - scale = 1. + scale = 1.0 torch.manual_seed(1997) - p_slab = 1. - p_spike + p_slab = 1.0 - p_spike spike_slab_prior = SpikeSlabPrior( D, @@ -27,7 +27,7 @@ def test_spike_slab_prior(positive_only: bool): # Check spike probability. assert torch.allclose( - torch.sum(weights == 0.) / (N * D), + torch.sum(weights == 0.0) / (N * D), torch.tensor(p_spike), atol=1e-2, ) @@ -35,20 +35,20 @@ def test_spike_slab_prior(positive_only: bool): # Check Laplacian distribution. N_slab = p_slab * N * D if positive_only: - assert torch.sum(weights < 0.) == 0 + assert torch.sum(weights < 0.0) == 0 else: assert torch.allclose( - torch.sum(weights < 0.) / N_slab, - torch.sum(weights > 0.) / N_slab, + torch.sum(weights < 0.0) / N_slab, + torch.sum(weights > 0.0) / N_slab, atol=2e-2, ) weights = torch.abs(weights) - laplace_weights = weights[weights > 0.] - for quantile in torch.arange(5) / 5.: - cutoff = -torch.log(1. - quantile) + laplace_weights = weights[weights > 0.0] + for quantile in torch.arange(5) / 5.0: + cutoff = -torch.log(1.0 - quantile) assert torch.allclose( torch.sum(laplace_weights < cutoff) / N_slab, quantile, atol=1e-2, - ) \ No newline at end of file + ) diff --git a/sparsecoding/test_utils/__init__.py b/sparsecoding/test_utils/__init__.py index 7b8371b..8c18de1 100644 --- a/sparsecoding/test_utils/__init__.py +++ b/sparsecoding/test_utils/__init__.py @@ -4,4 +4,13 @@ from .model_fixtures import bars_dictionary_fixture from .prior_fixtures import priors_fixture -__all__ = ['assert_allclose', 'assert_shape_equal', 'dataset_size_fixture', 'patch_size_fixture', 'bars_datas_fixture', 'bars_datasets_fixture', 'bars_dictionary_fixture', 'priors_fixture'] +__all__ = [ + "assert_allclose", + "assert_shape_equal", + "dataset_size_fixture", + "patch_size_fixture", + "bars_datas_fixture", + "bars_datasets_fixture", + "bars_dictionary_fixture", + "priors_fixture", +] diff --git a/sparsecoding/test_utils/asserts.py b/sparsecoding/test_utils/asserts.py index bc7056a..64db496 100644 --- a/sparsecoding/test_utils/asserts.py +++ b/sparsecoding/test_utils/asserts.py @@ -4,8 +4,10 @@ DEFAULT_ATOL = 1e-6 DEFAULT_RTOL = 1e-5 + def assert_allclose(a: np.ndarray, b: np.ndarray, rtol: float = DEFAULT_RTOL, atol: float = DEFAULT_ATOL) -> None: return np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + def assert_shape_equal(a: np.ndarray, b: np.ndarray) -> None: - assert a.shape == b.shape \ No newline at end of file + assert a.shape == b.shape diff --git a/sparsecoding/test_utils/asserts_test.py b/sparsecoding/test_utils/asserts_test.py index 912b0ad..1af9705 100644 --- a/sparsecoding/test_utils/asserts_test.py +++ b/sparsecoding/test_utils/asserts_test.py @@ -1,4 +1,3 @@ - import numpy as np import torch @@ -10,16 +9,19 @@ def test_pytorch_all_close(): expected = torch.ones([10, 10]) assert_allclose(result, expected) + def test_np_all_close(): result = np.ones([100, 100]) + 1e-10 expected = np.ones([100, 100]) assert_allclose(result, expected) + def test_assert_pytorch_shape_equal(): a = torch.zeros([10, 10]) b = torch.ones([10, 10]) assert_shape_equal(a, b) + def test_assert_np_shape_equal(): a = np.zeros([100, 100]) b = np.ones([100, 100]) diff --git a/sparsecoding/test_utils/constant_fixtures.py b/sparsecoding/test_utils/constant_fixtures.py index eac6253..1c04698 100644 --- a/sparsecoding/test_utils/constant_fixtures.py +++ b/sparsecoding/test_utils/constant_fixtures.py @@ -3,10 +3,12 @@ PATCH_SIZE = 8 DATASET_SIZE = 1000 + @pytest.fixture() def patch_size_fixture() -> int: - return PATCH_SIZE + return PATCH_SIZE + @pytest.fixture() def dataset_size_fixture() -> int: - return DATASET_SIZE \ No newline at end of file + return DATASET_SIZE diff --git a/sparsecoding/test_utils/dataset_fixtures.py b/sparsecoding/test_utils/dataset_fixtures.py index 9e3dcc9..c3f20d4 100644 --- a/sparsecoding/test_utils/dataset_fixtures.py +++ b/sparsecoding/test_utils/dataset_fixtures.py @@ -1,4 +1,3 @@ - import pytest import torch @@ -7,7 +6,9 @@ @pytest.fixture() -def bars_datasets_fixture(patch_size_fixture: int, dataset_size_fixture: int, priors_fixture: list[Prior]) -> list[BarsDataset]: +def bars_datasets_fixture( + patch_size_fixture: int, dataset_size_fixture: int, priors_fixture: list[Prior] +) -> list[BarsDataset]: return [ BarsDataset( patch_size=patch_size_fixture, @@ -17,9 +18,12 @@ def bars_datasets_fixture(patch_size_fixture: int, dataset_size_fixture: int, pr for prior in priors_fixture ] + @pytest.fixture() -def bars_datas_fixture(patch_size_fixture: int, dataset_size_fixture: int, bars_datasets_fixture: list[BarsDataset]) -> list[torch.Tensor]: +def bars_datas_fixture( + patch_size_fixture: int, dataset_size_fixture: int, bars_datasets_fixture: list[BarsDataset] +) -> list[torch.Tensor]: return [ dataset.data.reshape((dataset_size_fixture, patch_size_fixture * patch_size_fixture)) for dataset in bars_datasets_fixture - ] \ No newline at end of file + ] diff --git a/sparsecoding/test_utils/model_fixtures.py b/sparsecoding/test_utils/model_fixtures.py index 95573f6..1a0d895 100644 --- a/sparsecoding/test_utils/model_fixtures.py +++ b/sparsecoding/test_utils/model_fixtures.py @@ -1,12 +1,10 @@ - import pytest import torch from sparsecoding.datasets import BarsDataset -torch.manual_seed(1997) @pytest.fixture() def bars_dictionary_fixture(patch_size_fixture: int, bars_datasets_fixture: list[BarsDataset]) -> torch.Tensor: - """Return a bars dataset basis reshaped to represent a dictionary.""" - return bars_datasets_fixture[0].basis.reshape((2 * patch_size_fixture, patch_size_fixture * patch_size_fixture)).T \ No newline at end of file + """Return a bars dataset basis reshaped to represent a dictionary.""" + return bars_datasets_fixture[0].basis.reshape((2 * patch_size_fixture, patch_size_fixture * patch_size_fixture)).T diff --git a/sparsecoding/test_utils/prior_fixtures.py b/sparsecoding/test_utils/prior_fixtures.py index ffa2ef1..08a3125 100644 --- a/sparsecoding/test_utils/prior_fixtures.py +++ b/sparsecoding/test_utils/prior_fixtures.py @@ -21,4 +21,4 @@ def priors_fixture(patch_size_fixture: int) -> list[Prior]: ).type(torch.float32) ), ), - ] \ No newline at end of file + ] From ea74a84b6c4c74a28f60233d973dc9eb8bf0cbf4 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:33:15 -0800 Subject: [PATCH 11/20] set seed in conftest --- sparsecoding/priors/lo_prior_test.py | 2 -- sparsecoding/priors/spike_slab_prior_test.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/sparsecoding/priors/lo_prior_test.py b/sparsecoding/priors/lo_prior_test.py index 9dbb0e4..7cfa206 100644 --- a/sparsecoding/priors/lo_prior_test.py +++ b/sparsecoding/priors/lo_prior_test.py @@ -7,8 +7,6 @@ def test_l0_prior(): N = 10000 prob_distr = torch.tensor([0.5, 0.25, 0, 0.25]) - torch.manual_seed(1997) - D = prob_distr.shape[0] l0_prior = L0Prior(prob_distr) diff --git a/sparsecoding/priors/spike_slab_prior_test.py b/sparsecoding/priors/spike_slab_prior_test.py index ed601bc..d0054cc 100644 --- a/sparsecoding/priors/spike_slab_prior_test.py +++ b/sparsecoding/priors/spike_slab_prior_test.py @@ -11,8 +11,6 @@ def test_spike_slab_prior(positive_only: bool): p_spike = 0.5 scale = 1.0 - torch.manual_seed(1997) - p_slab = 1.0 - p_spike spike_slab_prior = SpikeSlabPrior( From 087ec50dc14469761fed916279e34f0eea43108e Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:40:22 -0800 Subject: [PATCH 12/20] ignore data --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index d3bd7e6..ee70a21 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# OS specific +.DS_Store + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -130,3 +133,4 @@ dmypy.json # data raw_data +data \ No newline at end of file From b45be9e1fd1737e8cfdf619bd57f59aa34b20129 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:43:49 -0800 Subject: [PATCH 13/20] rename files to match new convention --- sparsecoding/transforms/images_patch_test.py | 48 +++++++++++++ sparsecoding/transforms/test_patch.py | 58 ---------------- sparsecoding/transforms/test_whiten.py | 73 -------------------- sparsecoding/transforms/whiten_test.py | 61 ++++++++++++++++ 4 files changed, 109 insertions(+), 131 deletions(-) create mode 100644 sparsecoding/transforms/images_patch_test.py delete mode 100644 sparsecoding/transforms/test_patch.py delete mode 100644 sparsecoding/transforms/test_whiten.py create mode 100644 sparsecoding/transforms/whiten_test.py diff --git a/sparsecoding/transforms/images_patch_test.py b/sparsecoding/transforms/images_patch_test.py new file mode 100644 index 0000000..f7ff395 --- /dev/null +++ b/sparsecoding/transforms/images_patch_test.py @@ -0,0 +1,48 @@ +import torch + +from sparsecoding.transforms import patchify, quilt, sample_random_patches + + +def test_patchify_quilt_cycle(): + X, Y, Z = 3, 4, 5 + C = 3 + P = 8 + H = 6 * P + W = 8 * P + + images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) + + patches = patchify(P, images) + assert patches.shape == (X, Y, Z, int(H / P) * int(W / P), C, P, P) + + quilted_images = quilt(H, W, patches) + assert torch.allclose( + images, + quilted_images, + ), "Quilted images should be equal to input images." + +def test_sample_random_patches(): + X, Y, Z = 3, 4, 5 + C = 3 + P = 8 + H = 4 * P + W = 8 * P + N = 10 + + images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) + + random_patches = sample_random_patches(P, N, images) + assert random_patches.shape == (N, C, P, P) + + # Check that patches are actually taken from one of the images. + all_patches = torch.nn.functional.unfold( + input=images.reshape(-1, C, H, W), + kernel_size=P, + ) # [prod(*), C*P*P, L] + all_patches = torch.permute(all_patches, (0, 2, 1)) # [prod(*), L, C*P*P] + all_patches = torch.reshape(all_patches, (-1, C*P*P)) + for n in range(N): + patch = random_patches[n].reshape(1, C*P*P) + delta = torch.abs(patch - all_patches) # [-1, C*P*P] + patchwise_delta = torch.sum(delta, dim=1) # [-1] + assert torch.min(patchwise_delta) == 0. diff --git a/sparsecoding/transforms/test_patch.py b/sparsecoding/transforms/test_patch.py deleted file mode 100644 index 22cafbc..0000000 --- a/sparsecoding/transforms/test_patch.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import unittest - -from sparsecoding.transforms import sample_random_patches, patchify, quilt - - -class TestPatcher(unittest.TestCase): - def test_patchify_quilt_cycle(self): - X, Y, Z = 3, 4, 5 - C = 3 - P = 8 - H = 6 * P - W = 8 * P - - torch.manual_seed(1997) - - images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) - - patches = patchify(P, images) - assert patches.shape == (X, Y, Z, int(H / P) * int(W / P), C, P, P) - - quilted_images = quilt(H, W, patches) - assert torch.allclose( - images, - quilted_images, - ), "Quilted images should be equal to input images." - - def test_sample_random_patches(self): - X, Y, Z = 3, 4, 5 - C = 3 - P = 8 - H = 4 * P - W = 8 * P - N = 10 - - torch.manual_seed(1997) - - images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) - - random_patches = sample_random_patches(P, N, images) - assert random_patches.shape == (N, C, P, P) - - # Check that patches are actually taken from one of the images. - all_patches = torch.nn.functional.unfold( - input=images.reshape(-1, C, H, W), - kernel_size=P, - ) # [prod(*), C*P*P, L] - all_patches = torch.permute(all_patches, (0, 2, 1)) # [prod(*), L, C*P*P] - all_patches = torch.reshape(all_patches, (-1, C*P*P)) - for n in range(N): - patch = random_patches[n].reshape(1, C*P*P) - delta = torch.abs(patch - all_patches) # [-1, C*P*P] - patchwise_delta = torch.sum(delta, dim=1) # [-1] - assert torch.min(patchwise_delta) == 0. - - -if __name__ == "__main__": - unittest.main() diff --git a/sparsecoding/transforms/test_whiten.py b/sparsecoding/transforms/test_whiten.py deleted file mode 100644 index 3fa4406..0000000 --- a/sparsecoding/transforms/test_whiten.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import unittest - -from sparsecoding.transforms import whiten - - -class TestWhitener(unittest.TestCase): - def test_zca(self): - N = 5000 - D = 32*32 - - torch.manual_seed(1997) - - X = torch.rand((N, D), dtype=torch.float32) - - X_whitened = whiten(X) - - assert torch.allclose( - torch.mean(X_whitened, dim=0), - torch.zeros(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have zero mean." - assert torch.allclose( - torch.cov(X_whitened.T), - torch.eye(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have unit (identity) covariance." - - def test_pca(self): - N = 5000 - D = 32*32 - - torch.manual_seed(1997) - - X = torch.rand((N, D), dtype=torch.float32) - - X_whitened = whiten(X, algorithm='pca') - - assert torch.allclose( - torch.mean(X_whitened, dim=0), - torch.zeros(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have zero mean." - assert torch.allclose( - torch.cov(X_whitened.T), - torch.eye(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have unit (identity) covariance." - - def test_cholesky(self): - N = 5000 - D = 32*32 - - torch.manual_seed(1997) - - X = torch.rand((N, D), dtype=torch.float32) - - X_whitened = whiten(X, algorithm='cholesky') - - assert torch.allclose( - torch.mean(X_whitened, dim=0), - torch.zeros(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have zero mean." - assert torch.allclose( - torch.cov(X_whitened.T), - torch.eye(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have unit (identity) covariance." - - -if __name__ == "__main__": - unittest.main() diff --git a/sparsecoding/transforms/whiten_test.py b/sparsecoding/transforms/whiten_test.py new file mode 100644 index 0000000..4cdd97d --- /dev/null +++ b/sparsecoding/transforms/whiten_test.py @@ -0,0 +1,61 @@ +import torch + +from sparsecoding.transforms import whiten + + +def test_zca(): + N = 5000 + D = 32*32 + + X = torch.rand((N, D), dtype=torch.float32) + + X_whitened = whiten(X) + + assert torch.allclose( + torch.mean(X_whitened, dim=0), + torch.zeros(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have zero mean." + assert torch.allclose( + torch.cov(X_whitened.T), + torch.eye(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have unit (identity) covariance." + +def test_pca(): + N = 5000 + D = 32*32 + + X = torch.rand((N, D), dtype=torch.float32) + + X_whitened = whiten(X, algorithm='pca') + + assert torch.allclose( + torch.mean(X_whitened, dim=0), + torch.zeros(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have zero mean." + assert torch.allclose( + torch.cov(X_whitened.T), + torch.eye(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have unit (identity) covariance." + +def test_cholesky(): + N = 5000 + D = 32*32 + + X = torch.rand((N, D), dtype=torch.float32) + + X_whitened = whiten(X, algorithm='cholesky') + + assert torch.allclose( + torch.mean(X_whitened, dim=0), + torch.zeros(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have zero mean." + assert torch.allclose( + torch.cov(X_whitened.T), + torch.eye(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have unit (identity) covariance." From a4f3ba99ef4628e9552b43d5bd2234bab4b34732 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 12:48:53 -0800 Subject: [PATCH 14/20] newline --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ee70a21..7a1c20c 100644 --- a/.gitignore +++ b/.gitignore @@ -133,4 +133,4 @@ dmypy.json # data raw_data -data \ No newline at end of file +data From 7cefe0687d32b7ded6091b04e8cf0c5f38f0d9e8 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 13:01:25 -0800 Subject: [PATCH 15/20] formatting --- sparsecoding/transforms/__init__.py | 23 +++- sparsecoding/transforms/images.py | 121 +++++++------------ sparsecoding/transforms/images_patch_test.py | 7 +- sparsecoding/transforms/whiten.py | 48 ++++---- sparsecoding/transforms/whiten_test.py | 12 +- 5 files changed, 93 insertions(+), 118 deletions(-) diff --git a/sparsecoding/transforms/__init__.py b/sparsecoding/transforms/__init__.py index 8b77c4d..f01adca 100644 --- a/sparsecoding/transforms/__init__.py +++ b/sparsecoding/transforms/__init__.py @@ -1,7 +1,20 @@ from .whiten import whiten, compute_whitening_stats -from .images import whiten_images, compute_image_whitening_stats, WhiteningTransform, \ - quilt, patchify, sample_random_patches +from .images import ( + whiten_images, + compute_image_whitening_stats, + WhiteningTransform, + quilt, + patchify, + sample_random_patches, +) -__all__ = ['quilt', 'patchify', 'sample_random_patches', 'whiten', - 'compute_whitening_stats', 'compute_image_whitening_stats', - 'WhiteningTransform', 'whiten_images'] +__all__ = [ + "quilt", + "patchify", + "sample_random_patches", + "whiten", + "compute_whitening_stats", + "compute_image_whitening_stats", + "WhiteningTransform", + "whiten_images", +] diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py index eb1ecf3..9e0f20b 100644 --- a/sparsecoding/transforms/images.py +++ b/sparsecoding/transforms/images.py @@ -6,32 +6,34 @@ from .whiten import whiten, compute_whitening_stats -def check_images(images: torch.Tensor, algorithm: str = 'zca'): - """Verify that tensor is in the shape [N, C, H, W] and C != when using fourier based method - """ +def check_images(images: torch.Tensor, algorithm: str = "zca"): + """Verify that tensor is in the shape [N, C, H, W] and C != when using fourier based method""" if len(images.shape) != 4: - raise ValueError('Images must be in shape [N, C, H, W]') + raise ValueError("Images must be in shape [N, C, H, W]") - if images.shape[1] != 1 and algorithm == 'frequency': - raise ValueError("When using frequency based decorrelation, images must" + - f"be grayscale, received {images.shape[1]} channels") + if images.shape[1] != 1 and algorithm == "frequency": + raise ValueError( + "When using frequency based decorrelation, images must" + + f"be grayscale, received {images.shape[1]} channels" + ) # Running cov based methods on large images can eat memory - if algorithm in ['zca', 'pca', 'cholesky'] and (images.shape[2] > 64 or images.shape[3] > 64): - print(f"WARNING: Running covaraince based whitening for images of size {images.shape[2]}x{images.shape[3]}." + - "It is not recommended to use this for images smaller than 64x64") + if algorithm in ["zca", "pca", "cholesky"] and (images.shape[2] > 64 or images.shape[3] > 64): + print( + f"WARNING: Running covaraince based whitening for images of size {images.shape[2]}x{images.shape[3]}." + + "It is not recommended to use this for images smaller than 64x64" + ) # Running cov based methods on large images can eat memory - if algorithm == 'frequency' and (images.shape[2] <= 64 or images.shape[3] <= 64): - print(f"WARNING: Running frequency based whitening for images of size {images.shape[2]}x{images.shape[3]}." + - "It is recommended to use this for images larger than 64x64") + if algorithm == "frequency" and (images.shape[2] <= 64 or images.shape[3] <= 64): + print( + f"WARNING: Running frequency based whitening for images of size {images.shape[2]}x{images.shape[3]}." + + "It is recommended to use this for images larger than 64x64" + ) -def whiten_images(images: torch.Tensor, - algorithm: str, - stats: Dict = None, - **kwargs) -> torch.Tensor: +def whiten_images(images: torch.Tensor, algorithm: str, stats: Dict = None, **kwargs) -> torch.Tensor: """ Wrapper for all whitening transformations @@ -48,18 +50,20 @@ def whiten_images(images: torch.Tensor, check_images(images, algorithm) - if algorithm == 'frequency': + if algorithm == "frequency": return frequency_whitening(images, **kwargs) - elif algorithm in ['zca', 'pca', 'cholesky']: + elif algorithm in ["zca", "pca", "cholesky"]: N, C, H, W = images.shape flattened_images = images.flatten(start_dim=1) whitened = whiten(flattened_images, algorithm, stats, **kwargs) return whitened.reshape((N, C, H, W)) else: - raise ValueError(f"Unknown whitening algorithm: {algorithm}, \ - must be one of ['frequency', 'pca', 'zca', 'cholesky]") + raise ValueError( + f"Unknown whitening algorithm: {algorithm}, \ + must be one of ['frequency', 'pca', 'zca', 'cholesky]" + ) def compute_image_whitening_stats(images: torch.Tensor) -> Dict: @@ -95,13 +99,13 @@ def create_frequency_filter(image_size: int, f0_factor: float = 0.4) -> torch.Te ---------- torch.Tensor: Frequency domain filter """ - fx = torch.linspace(-image_size/2, image_size/2-1, image_size) - fy = torch.linspace(-image_size/2, image_size/2-1, image_size) - fx, fy = torch.meshgrid(fx, fy, indexing='xy') + fx = torch.linspace(-image_size / 2, image_size / 2 - 1, image_size) + fy = torch.linspace(-image_size / 2, image_size / 2 - 1, image_size) + fx, fy = torch.meshgrid(fx, fy, indexing="xy") rho = torch.sqrt(fx**2 + fy**2) f_0 = f0_factor * image_size - filt = rho * torch.exp(-(rho/f_0)**4) + filt = rho * torch.exp(-((rho / f_0) ** 4)) return fft.fftshift(filt) @@ -123,7 +127,7 @@ def get_cached_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor: return create_frequency_filter(image_size, f0_factor) -def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.) -> torch.Tensor: +def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.0) -> torch.Tensor: """ Normalize the variance of a tensor to a target value. @@ -146,11 +150,7 @@ def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.) -> tor return centered -def whiten_channel( - channel: torch.Tensor, - filt: torch.Tensor, - target_variance: float = 1. -) -> torch.Tensor: +def whiten_channel(channel: torch.Tensor, filt: torch.Tensor, target_variance: float = 1.0) -> torch.Tensor: """ Apply frequency domain whitening to a single channel. @@ -181,11 +181,7 @@ def whiten_channel( return whitened -def frequency_whitening( - images: torch.Tensor, - target_variance: float = 0.1, - f0_factor: float = 0.4 -) -> torch.Tensor: +def frequency_whitening(images: torch.Tensor, target_variance: float = 0.1, f0_factor: float = 0.4) -> torch.Tensor: """ Apply frequency domain decorrelation to batched images. Method used in original sparsenet in Olshausen and Field in Nature @@ -211,9 +207,7 @@ def frequency_whitening( # Process each image in the batch whitened_batch = [] for img in images: - whitened_batch.append( - whiten_channel(img[0], filt, target_variance) - ) + whitened_batch.append(whiten_channel(img[0], filt, target_variance)) return torch.stack(whitened_batch).unsqueeze(1) @@ -223,13 +217,8 @@ class WhiteningTransform(object): A PyTorch transform for image whitening that can be used in a transform pipeline. Supports frequency, PCA, and ZCA whitening methods. """ - def __init__( - self, - algorithm: str = 'zca', - stats: Optional[Dict] = None, - compute_stats: bool = False, - **kwargs - ): + + def __init__(self, algorithm: str = "zca", stats: Optional[Dict] = None, compute_stats: bool = False, **kwargs): """ Initialize whitening transform. @@ -266,12 +255,7 @@ def __call__(self, images: torch.Tensor) -> torch.Tensor: check_images(images) # Apply whitening - whitened = whiten_images( - images, - self.algorithm, - self.stats, - **self.kwargs - ) + whitened = whiten_images(images, self.algorithm, self.stats, **self.kwargs) # Remove batch dimension if input was single image if single_image: @@ -322,18 +306,11 @@ def sample_random_patches( size=(N,), ) - h_patch_idxs, w_patch_idxs = torch.meshgrid( - torch.arange(P), - torch.arange(P), - indexing='ij' - ) + h_patch_idxs, w_patch_idxs = torch.meshgrid(torch.arange(P), torch.arange(P), indexing="ij") h_idxs = h_start_idx.reshape(N, 1, 1) + h_patch_idxs w_idxs = w_start_idx.reshape(N, 1, 1) + w_patch_idxs - leading_idxs = [ - torch.randint(low=0, high=image.shape[d], size=(N, 1, 1)) - for d in range(image.dim() - 3) - ] + leading_idxs = [torch.randint(low=0, high=image.shape[d], size=(N, 1, 1)) for d in range(image.dim() - 3)] idxs = leading_idxs + [slice(None), h_idxs, w_idxs] @@ -379,20 +356,14 @@ def patchify( if stride is None: stride = P - if ( - H % P != 0 - or W % P != 0 - ): + if H % P != 0 or W % P != 0: warnings.warn( f"Image size ({H, W}) not evenly divisible by `patch_size` ({P})," f"parts on the bottom and/or right will be cropped.", UserWarning, ) - N = ( - int((H - P + 1 + stride) // stride) - * int((W - P + 1 + stride) // stride) - ) + N = int((H - P + 1 + stride) // stride) * int((W - P + 1 + stride) // stride) patches = torch.nn.functional.unfold( input=image.reshape(-1, C, H, W), @@ -439,22 +410,16 @@ def quilt( W = width if int(H / P) * int(W / P) != N: - raise ValueError( - f"Expected {N} patches per image, " - f"got int(H/P) * int(W/P) = {int(H / P) * int(W / P)}." - ) + raise ValueError(f"Expected {N} patches per image, " f"got int(H/P) * int(W/P) = {int(H / P) * int(W / P)}.") - if ( - H % P != 0 - or W % P != 0 - ): + if H % P != 0 or W % P != 0: warnings.warn( f"Image size ({H, W}) not evenly divisible by `patch_size` ({P})," f"parts on the bottom and/or right will be zeroed.", UserWarning, ) - patches = patches.reshape(-1, N, C*P*P) # [prod(*), N, C*P*P] + patches = patches.reshape(-1, N, C * P * P) # [prod(*), N, C*P*P] patches = torch.permute(patches, (0, 2, 1)) # [prod(*), C*P*P, N] image = torch.nn.functional.fold( input=patches, diff --git a/sparsecoding/transforms/images_patch_test.py b/sparsecoding/transforms/images_patch_test.py index f7ff395..d7999c5 100644 --- a/sparsecoding/transforms/images_patch_test.py +++ b/sparsecoding/transforms/images_patch_test.py @@ -21,6 +21,7 @@ def test_patchify_quilt_cycle(): quilted_images, ), "Quilted images should be equal to input images." + def test_sample_random_patches(): X, Y, Z = 3, 4, 5 C = 3 @@ -40,9 +41,9 @@ def test_sample_random_patches(): kernel_size=P, ) # [prod(*), C*P*P, L] all_patches = torch.permute(all_patches, (0, 2, 1)) # [prod(*), L, C*P*P] - all_patches = torch.reshape(all_patches, (-1, C*P*P)) + all_patches = torch.reshape(all_patches, (-1, C * P * P)) for n in range(N): - patch = random_patches[n].reshape(1, C*P*P) + patch = random_patches[n].reshape(1, C * P * P) delta = torch.abs(patch - all_patches) # [-1, C*P*P] patchwise_delta = torch.sum(delta, dim=1) # [-1] - assert torch.min(patchwise_delta) == 0. + assert torch.min(patchwise_delta) == 0.0 diff --git a/sparsecoding/transforms/whiten.py b/sparsecoding/transforms/whiten.py index 64641e4..4206adc 100644 --- a/sparsecoding/transforms/whiten.py +++ b/sparsecoding/transforms/whiten.py @@ -26,21 +26,17 @@ def compute_whitening_stats(X: torch.Tensor): eigenvalues = torch.flip(eigenvalues, dims=[0]) eigenvectors = torch.flip(eigenvectors, dims=[1]) - return { - 'mean': mean, - 'eigenvalues': eigenvalues, - 'eigenvectors': eigenvectors, - 'covariance': Sigma - } - - -def whiten(X: torch.Tensor, - algorithm: str = 'zca', - stats: Dict = None, - n_components: float = None, - epsilon: float = 0., - return_W: bool = False - ) -> torch.Tensor: + return {"mean": mean, "eigenvalues": eigenvalues, "eigenvectors": eigenvectors, "covariance": Sigma} + + +def whiten( + X: torch.Tensor, + algorithm: str = "zca", + stats: Dict = None, + n_components: float = None, + epsilon: float = 0.0, + return_W: bool = False, +) -> torch.Tensor: """ Apply whitening transform to data using pre-computed statistics. @@ -73,21 +69,21 @@ def whiten(X: torch.Tensor, if stats is None: stats = compute_whitening_stats(X) - x_centered = X - stats.get('mean') + x_centered = X - stats.get("mean") - if algorithm == 'pca' or algorithm == 'zca': + if algorithm == "pca" or algorithm == "zca": - scaling = 1. / torch.sqrt(stats.get('eigenvalues') + epsilon) + scaling = 1.0 / torch.sqrt(stats.get("eigenvalues") + epsilon) if n_components is not None: if isinstance(n_components, float): if not 0 < n_components <= 1: raise ValueError("If n_components is float, it must be between 0 and 1") - explained_variance_ratio = stats.get('eigenvalues') / torch.sum(stats.get('eigenvalues')) + explained_variance_ratio = stats.get("eigenvalues") / torch.sum(stats.get("eigenvalues")) cumulative_variance_ratio = torch.cumsum(explained_variance_ratio, dim=0) n_components = torch.sum(cumulative_variance_ratio <= n_components) + 1 elif isinstance(n_components, int): - if not 0 < n_components <= len(stats.get('eigenvalues')): + if not 0 < n_components <= len(stats.get("eigenvalues")): raise ValueError(f"n_components must be between 1 and {len(stats.get('eigenvalues'))}") else: raise ValueError("n_components must be int or float") @@ -98,17 +94,15 @@ def whiten(X: torch.Tensor, scaling = torch.diag(scaling) - if algorithm == 'pca': + if algorithm == "pca": # For PCA: project onto eigenvectors and scale - W = scaling @ stats.get('eigenvectors').T + W = scaling @ stats.get("eigenvectors").T else: # For ZCA: project, scale, and rotate back - W = (stats.get('eigenvectors') @ - scaling @ - stats.get('eigenvectors').T) - elif algorithm == 'cholesky': + W = stats.get("eigenvectors") @ scaling @ stats.get("eigenvectors").T + elif algorithm == "cholesky": # Based on Cholesky decomp, related to QR decomp - L = torch.linalg.cholesky(stats.get('covariance')) + L = torch.linalg.cholesky(stats.get("covariance")) Identity = torch.eye(L.shape[0], device=L.device, dtype=L.dtype) # Solve L @ W = I for W, more stable and quicker than inv(L) W = torch.linalg.solve_triangular(L, Identity, upper=False) diff --git a/sparsecoding/transforms/whiten_test.py b/sparsecoding/transforms/whiten_test.py index 4cdd97d..9017da0 100644 --- a/sparsecoding/transforms/whiten_test.py +++ b/sparsecoding/transforms/whiten_test.py @@ -5,7 +5,7 @@ def test_zca(): N = 5000 - D = 32*32 + D = 32 * 32 X = torch.rand((N, D), dtype=torch.float32) @@ -22,13 +22,14 @@ def test_zca(): atol=1e-3, ), "Whitened data should have unit (identity) covariance." + def test_pca(): N = 5000 - D = 32*32 + D = 32 * 32 X = torch.rand((N, D), dtype=torch.float32) - X_whitened = whiten(X, algorithm='pca') + X_whitened = whiten(X, algorithm="pca") assert torch.allclose( torch.mean(X_whitened, dim=0), @@ -41,13 +42,14 @@ def test_pca(): atol=1e-3, ), "Whitened data should have unit (identity) covariance." + def test_cholesky(): N = 5000 - D = 32*32 + D = 32 * 32 X = torch.rand((N, D), dtype=torch.float32) - X_whitened = whiten(X, algorithm='cholesky') + X_whitened = whiten(X, algorithm="cholesky") assert torch.allclose( torch.mean(X_whitened, dim=0), From 7861eb9d63c8c9f933db422804a9bb81e739ace3 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 13:12:49 -0800 Subject: [PATCH 16/20] fix setting seed --- conftest.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/conftest.py b/conftest.py index 31b4f34..e49cd90 100644 --- a/conftest.py +++ b/conftest.py @@ -1,24 +1,25 @@ +import pytest import torch -from sparsecoding.test_utils import ( - bars_datas_fixture, - bars_datasets_fixture, - bars_dictionary_fixture, - dataset_size_fixture, - patch_size_fixture, - priors_fixture, -) +from sparsecoding.test_utils import (bars_datas_fixture, bars_datasets_fixture, + bars_dictionary_fixture, + dataset_size_fixture, patch_size_fixture, + priors_fixture) -torch.manual_seed(1997) + +@pytest.fixture(autouse=True) +def set_seed(): + torch.manual_seed(1997) # We import and define all fixtures in this file. # This allows users to avoid any dependency fixtures. # NOTE: This means pytest should only be run from this directory. __all__ = [ - "dataset_size_fixture", - "patch_size_fixture", "bars_datas_fixture", "bars_datasets_fixture", "bars_dictionary_fixture", + "dataset_size_fixture", + "patch_size_fixture", "priors_fixture", + "set_seed", ] From da4080e80dce5c8c30b80d30a71913296de05f40 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 13:13:03 -0800 Subject: [PATCH 17/20] bump tolerances --- sparsecoding/inference/lsm_test.py | 2 +- sparsecoding/inference/pytorch_optimizer_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sparsecoding/inference/lsm_test.py b/sparsecoding/inference/lsm_test.py index 6ac2155..8cd3470 100644 --- a/sparsecoding/inference/lsm_test.py +++ b/sparsecoding/inference/lsm_test.py @@ -38,4 +38,4 @@ def test_inference( a = inference_method.infer(data, bars_dictionary_fixture) - assert_allclose(a, dataset.weights, atol=5e-2) + assert_allclose(a, dataset.weights, atol=6e-2) diff --git a/sparsecoding/inference/pytorch_optimizer_test.py b/sparsecoding/inference/pytorch_optimizer_test.py index e7322bc..87961af 100644 --- a/sparsecoding/inference/pytorch_optimizer_test.py +++ b/sparsecoding/inference/pytorch_optimizer_test.py @@ -77,4 +77,4 @@ def test_inference( a = inference_method.infer(data, bars_dictionary_fixture) - assert_allclose(a, dataset.weights, atol=1e-1, rtol=1e-1) + assert_allclose(a, dataset.weights, atol=2e-1, rtol=1e-1) From cbc64b516438fcc314fb132c6fc0ff768405f970 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 13:14:58 -0800 Subject: [PATCH 18/20] up python min version to match pytest --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7925c5e..efde059 100644 --- a/setup.py +++ b/setup.py @@ -17,5 +17,5 @@ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires='>=3.8', ) From b441ae44412ad415c456b3ff4c00f985fbbc7632 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 31 Dec 2024 13:17:21 -0800 Subject: [PATCH 19/20] formatting --- conftest.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/conftest.py b/conftest.py index e49cd90..98637e1 100644 --- a/conftest.py +++ b/conftest.py @@ -1,16 +1,21 @@ import pytest import torch -from sparsecoding.test_utils import (bars_datas_fixture, bars_datasets_fixture, - bars_dictionary_fixture, - dataset_size_fixture, patch_size_fixture, - priors_fixture) +from sparsecoding.test_utils import ( + bars_datas_fixture, + bars_datasets_fixture, + bars_dictionary_fixture, + dataset_size_fixture, + patch_size_fixture, + priors_fixture, +) @pytest.fixture(autouse=True) def set_seed(): torch.manual_seed(1997) + # We import and define all fixtures in this file. # This allows users to avoid any dependency fixtures. # NOTE: This means pytest should only be run from this directory. From 43e34c03a62bedcb5eee4679ebfd4adf2e41b916 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 14 Jan 2025 16:29:21 -0800 Subject: [PATCH 20/20] rename file --- sparsecoding/transforms/{images_patch_test.py => images_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sparsecoding/transforms/{images_patch_test.py => images_test.py} (100%) diff --git a/sparsecoding/transforms/images_patch_test.py b/sparsecoding/transforms/images_test.py similarity index 100% rename from sparsecoding/transforms/images_patch_test.py rename to sparsecoding/transforms/images_test.py