diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5fb44d8..225983a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,26 +2,30 @@ name: "Testing" on: push: - paths-ignore: - - 'tutorials/**' - - 'README.md' - - '.gitignore' - - 'examples/**' + branches: + - main + pull_request: jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python all python version - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11 architecture: x64 - name: Install dependencies - run: pip install -r requirements.txt + run: | + python -m venv --upgrade-deps .venv + source .venv/bin/activate + pip install -r requirements.txt + pip install -r requirements-dev.txt - name: Run Test - run: python -m unittest discover tests -v + run: | + source .venv/bin/activate + python -m pytest . diff --git a/.gitignore b/.gitignore index d3bd7e6..7a1c20c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# OS specific +.DS_Store + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -130,3 +133,4 @@ dmypy.json # data raw_data +data diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..98637e1 --- /dev/null +++ b/conftest.py @@ -0,0 +1,30 @@ +import pytest +import torch + +from sparsecoding.test_utils import ( + bars_datas_fixture, + bars_datasets_fixture, + bars_dictionary_fixture, + dataset_size_fixture, + patch_size_fixture, + priors_fixture, +) + + +@pytest.fixture(autouse=True) +def set_seed(): + torch.manual_seed(1997) + + +# We import and define all fixtures in this file. +# This allows users to avoid any dependency fixtures. +# NOTE: This means pytest should only be run from this directory. +__all__ = [ + "bars_datas_fixture", + "bars_datasets_fixture", + "bars_dictionary_fixture", + "dataset_size_fixture", + "patch_size_fixture", + "priors_fixture", + "set_seed", +] diff --git a/docs/contributing.md b/docs/contributing.md index 804bdba..5a49a47 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,35 +1,41 @@ # Contributing + All contributions are welcome! ## Bug Reporting -If you find a bug, submit a bug report on GitHub Issues. + +If you find a bug, submit a bug report on GitHub Issues. ## Adding Features/Fixing Bugs + If you have identified a new feature or bug that you can fix yourself, please follow the following procedure. -1. Clone `main` branch. -2. Create a new branch to contain your changes. -2. `add`, `commit`, and `push` your changes to this branch. -3. Create a pull request (PR). See more information on submitting a PR request below. +1. Fork `main` branch. +2. Create a new branch to contain your changes. +3. `add`, `commit`, and `push` your changes to this branch. +4. Create a pull request (PR). See more information on submitting a PR request below. ### Submitting a Pull Request -1. If necessary, please **write your own unit tests** and add them to [the tests directory](https://github.com/rctn/sparsecoding/blob/main/docs/contributing.md). -2. Verify that all tests are passed by running `python -m unittest tests/*`. -3. Be sure that your PR follows formatting guidelines, [PEP8](https://peps.python.org/pep-0008/) and [flake8](https://flake8.pycqa.org/en/latest/). -4. Make sure the title of your PR summarizes the features/issues resolved in your branch. -5. Submit your pull request and add reviewers. + +1. If necessary, please **write your own unit tests** and place them near the code being tested. High-level tests, such as integration or example tests can be placed in the top-level "tests" folder. +2. Verify that all tests are passed by running `python -m pytest .`. +3. Be sure that your PR follows formatting guidelines, [PEP8](https://peps.python.org/pep-0008/) and [flake8](https://flake8.pycqa.org/en/latest/). +4. Make sure the title of your PR summarizes the features/issues resolved in your branch. +5. Submit your pull request and add reviewers. ## Coding Style Guidelines -The following are some guidelines on how new code should be written. Of course, there are special cases, and there will be exceptions to these rules. + +The following are some guidelines on how new code should be written. Of course, there are special cases, and there will be exceptions to these rules. 1. Format code in accordance with [flake8](https://flake8.pycqa.org/en/latest/) standard. 2. Use underscores to separate words in non-class names: `n_samples` rather than `nsamples`. 3. Avoid single-character variable names. ## Docstrings + When writing docstrings, please follow the following example. -``` +```py def count_beans(self, baz, use_gpu=False, foo="vector" bar=None): """Write a one-line summary for the method. @@ -69,4 +75,4 @@ def count_beans(self, baz, use_gpu=False, foo="vector" quantum-mechanical description of physical reality be considered complete?. Physical review, 47(10), 777. """ -``` \ No newline at end of file +``` diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..4191a3f --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +pylint +pytest +pyright \ No newline at end of file diff --git a/setup.py b/setup.py index 7925c5e..efde059 100644 --- a/setup.py +++ b/setup.py @@ -17,5 +17,5 @@ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ], - python_requires='>=3.6', + python_requires='>=3.8', ) diff --git a/sparsecoding/inference.py b/sparsecoding/inference.py deleted file mode 100644 index 83ed11f..0000000 --- a/sparsecoding/inference.py +++ /dev/null @@ -1,940 +0,0 @@ -import numpy as np -import torch - - -class InferenceMethod: - """Base class for inference method.""" - - def __init__(self, solver): - """ - Parameters - ---------- - """ - self.solver = solver - - def initialize(self, a): - """Define initial coefficients. - - Parameters - ---------- - - Returns - ------- - """ - raise NotImplementedError - - def grad(self): - """Compute the gradient step. - - Parameters - ---------- - - Returns - ------- - """ - raise NotImplementedError - - def infer(self, dictionary, data, coeff_0=None, use_checknan=False): - """Infer the coefficients given a dataset and dictionary. - - Parameters - ---------- - dictionary : array-like, shape [n_features,n_basis] - - data : array-like, shape [n_samples,n_features] - - coeff_0 : array-like, shape [n_samples,n_basis], optional - Initial coefficient values. - use_checknan : bool, default=False - Check for nans in coefficients on each iteration - - Returns - ------- - coefficients : array-like, shape [n_samples,n_basis] - """ - raise NotImplementedError - - @staticmethod - def checknan(data=torch.tensor(0), name="data"): - """Check for nan values in data. - - Parameters - ---------- - data : array-like, optional - Data to check for nans - name : str, default="data" - Name to add to error, if one is thrown - - Raises - ------ - ValueError - If the nan found in data - """ - if torch.isnan(data).any(): - raise ValueError("InferenceMethod error: nan in %s." % (name)) - - -class LCA(InferenceMethod): - def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=0.1, - stop_early=False, epsilon=1e-2, solver=None, - return_all_coefficients="none", nonnegative=False): - """Method implemented according locally competative algorithm (LCA) - with the ideal soft thresholding function. - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run - coeff_lr : float, default=1e-3 - Update rate of coefficient dynamics - threshold : float, default=0.1 - Threshold for non-linearity - stop_early : bool, default=False - Stops dynamics early based on change in coefficents - epsilon : float, default=1e-2 - Only used if stop_early True, specifies criteria to stop dynamics - nonnegative : bool, default=False - Constrain coefficients to be nonnegative - return_all_coefficients : str, {"none", "membrane", "active"}, default="none" - Returns all coefficients during inference procedure if not equal - to "none". If return_all_coefficients=="membrane", membrane - potentials (u) returned. If return_all_coefficients=="active", - active units (a) (output of thresholding function over u) returned. - User beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Rozell, C. J., Johnson, D. H., Baraniuk, R. G., & Olshausen, - B. A. (2008). Sparse coding via thresholding and local competition - in neural circuits. Neural computation, 20(10), 2526-2563. - """ - super().__init__(solver) - self.threshold = threshold - self.coeff_lr = coeff_lr - self.stop_early = stop_early - self.epsilon = epsilon - self.n_iter = n_iter - self.nonnegative = nonnegative - if return_all_coefficients not in ["none", "membrane", "active"]: - raise ValueError("Invalid input for return_all_coefficients. Valid" - "inputs are: \"none\", \"membrane\", \"active\".") - self.return_all_coefficients = return_all_coefficients - - def threshold_nonlinearity(self, u): - """Soft threshhold function - - Parameters - ---------- - u : array-like, shape [batch_size, n_basis] - Membrane potentials - - Returns - ------- - a : array-like, shape [batch_size, n_basis] - Activations - """ - if self.nonnegative: - a = (u - self.threshold).clamp(min=0.) - else: - a = (torch.abs(u) - self.threshold).clamp(min=0.) - a = torch.sign(u)*a - return a - - def grad(self, b, G, u, a): - """Compute the gradient step on membrane potentials - - Parameters - ---------- - b : array-like, shape [batch_size, n_coefficients] - Driver signal for coefficients - G : array-like, shape [n_coefficients, n_coefficients] - Inhibition matrix - a : array-like, shape [batch_size, n_coefficients] - Currently active coefficients - - Returns - ------- - du : array-like, shape [batch_size, n_coefficients] - Gradient of membrane potentials - """ - du = b-u-(G@a.t()).t() - return du - - def infer(self, data, dictionary, coeff_0=None, use_checknan=False): - """Infer coefficients using provided dictionary - - Parameters - ---------- - dictionary : array-like, shape [n_features, n_basis] - - data : array-like, shape [n_samples, n_features] - - coeff_0 : array-like, shape [n_samples, n_basis], optional - Initial coefficient values - use_checknan : bool, default=False - Check for nans in coefficients on each iteration. Setting this to - False can speed up inference time. - - Returns - ------- - coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] - First case occurs if return_all_coefficients == "none". If - return_all_coefficients != "none", returned shape is second case. - Returned dimension along dim 1 can be less than n_iter when - stop_early==True and stopping criteria met. - """ - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # initialize - if coeff_0 is not None: - u = coeff_0.to(device) - else: - u = torch.zeros((batch_size, n_basis)).to(device) - - coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - - b = (dictionary.t()@data.t()).t() - G = dictionary.t()@dictionary-torch.eye(n_basis).to(device) - for i in range(self.n_iter): - # store old membrane potentials to evalute stop early condition - if self.stop_early: - old_u = u.clone().detach() - - # check return all - if self.return_all_coefficients != "none": - if self.return_all_coefficients == "active": - coefficients = torch.concat( - [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) - else: - coefficients = torch.concat( - [coefficients, u.clone().unsqueeze(1)], dim=1) - - # compute new - a = self.threshold_nonlinearity(u) - du = self.grad(b, G, u, a) - u = u + self.coeff_lr*du - - # check stopping condition - if self.stop_early: - relative_change_in_coeff = torch.linalg.norm(old_u - u)/torch.linalg.norm(old_u) - if relative_change_in_coeff < self.epsilon: - break - - if use_checknan: - self.checknan(u, "coefficients") - - # return active units if return_all_coefficients in ["none", "active"] - if self.return_all_coefficients == "membrane": - coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1) - else: - final_coefficients = self.threshold_nonlinearity(u) - coefficients = torch.concat([coefficients, final_coefficients.clone().unsqueeze(1)], dim=1) - - return coefficients.squeeze() - - -class Vanilla(InferenceMethod): - def __init__(self, n_iter=100, coeff_lr=1e-3, sparsity_penalty=0.2, - stop_early=False, epsilon=1e-2, solver=None, - return_all_coefficients=False): - """Gradient descent with Euler's method on model in Olshausen & Field - (1997) with laplace prior over coefficients (corresponding to l-1 norm - penalty). - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run - coeff_lr : float, default=1e-3 - Update rate of coefficient dynamics - sparsity_penalty : float, default=0.2 - - stop_early : bool, default=False - Stops dynamics early based on change in coefficents - epsilon : float, default=1e-2 - Only used if stop_early True, specifies criteria to stop dynamics - return_all_coefficients : str, default=False - Returns all coefficients during inference procedure if True - User beware: If n_iter is large, setting this parameter to True - Can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Olshausen, B. A., & Field, D. J. (1997). Sparse coding with an - overcomplete basis set: A strategy employed by V1?. Vision research, - 37(23), 3311-3325. - """ - super().__init__(solver) - self.coeff_lr = coeff_lr - self.sparsity_penalty = sparsity_penalty - self.stop_early = stop_early - self.epsilon = epsilon - self.n_iter = n_iter - self.return_all_coefficients = return_all_coefficients - - def grad(self, residual, dictionary, a): - """Compute the gradient step on coefficients - - Parameters - ---------- - residual : array-like, shape [batch_size, n_features] - Residual between reconstructed image and original - dictionary : array-like, shape [n_features,n_coefficients] - Dictionary - a : array-like, shape [batch_size, n_coefficients] - Coefficients - - Returns - ------- - da : array-like, shape [batch_size, n_coefficients] - Gradient of membrane potentials - """ - da = (dictionary.t()@residual.t()).t() - \ - self.sparsity_penalty*torch.sign(a) - return da - - def infer(self, data, dictionary, coeff_0=None, use_checknan=False): - """Infer coefficients using provided dictionary - - Parameters - ---------- - dictionary : array-like, shape [n_features, n_basis] - Dictionary - data : array like, shape [n_samples, n_features] - - coeff_0 : array-like, shape [n_samples, n_basis], optional - Initial coefficient values - use_checknan : bool, default=False - check for nans in coefficients on each iteration. Setting this to - False can speed up inference time - - Returns - ------- - coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] - First case occurs if return_all_coefficients == "none". If - return_all_coefficients != "none", returned shape is second case. - Returned dimension along dim 1 can be less than n_iter when - stop_early==True and stopping criteria met. - """ - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # initialize - if coeff_0 is not None: - a = coeff_0.to(device) - else: - a = torch.rand((batch_size, n_basis)).to(device)-0.5 - - coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - - residual = data - (dictionary@a.t()).t() - for i in range(self.n_iter): - - if self.return_all_coefficients: - coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) - - if self.stop_early: - old_a = a.clone().detach() - - da = self.grad(residual, dictionary, a) - a = a + self.coeff_lr*da - - if self.stop_early: - if torch.linalg.norm(old_a - a)/torch.linalg.norm(old_a) < self.epsilon: - break - - residual = data - (dictionary@a.t()).t() - - if use_checknan: - self.checknan(a, "coefficients") - - coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) - return torch.squeeze(coefficients) - - -class ISTA(InferenceMethod): - def __init__(self, n_iter=100, sparsity_penalty=1e-2, stop_early=False, - epsilon=1e-2, solver=None, return_all_coefficients=False): - """Iterative shrinkage-thresholding algorithm for solving LASSO problems. - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run - sparsity_penalty : float, default=0.2 - - stop_early : bool, default=False - Stops dynamics early based on change in coefficents - epsilon : float, default=1e-2 - Only used if stop_early True, specifies criteria to stop dynamics - return_all_coefficients : str, default=False - Returns all coefficients during inference procedure if True - User beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Beck, A., & Teboulle, M. (2009). A fast iterative - shrinkage-thresholding algorithm for linear inverse problems. - SIAM journal on imaging sciences, 2(1), 183-202. - """ - super().__init__(solver) - self.n_iter = n_iter - self.sparsity_penalty = sparsity_penalty - self.stop_early = stop_early - self.epsilon = epsilon - self.coefficients = None - self.return_all_coefficients = return_all_coefficients - - def threshold_nonlinearity(self, u): - """Soft threshhold function - - Parameters - ---------- - u : array-likes, shape [batch_size, n_basis] - Membrane potentials - - Returns - ------- - a : array-like, shape [batch_size, n_basis] - activations - """ - a = (torch.abs(u) - self.threshold).clamp(min=0.) - a = torch.sign(u)*a - return a - - def infer(self, data, dictionary, coeff_0=None, use_checknan=False): - """Infer coefficients for each image in data using dictionary elements. - Uses ISTA (Beck & Taboulle 2009), equations 1.4 and 1.5. - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - - dictionary : array-like, shape [n_features, n_basis] - - coeff_0 : array-like, shape [n_samples, n_basis], optional - Initial coefficient values - use_checknan : bool, default=False - Check for nans in coefficients on each iteration. Setting this to - False can speed up inference time. - Returns - ------- - coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] - First case occurs if return_all_coefficients == "none". If - return_all_coefficients != "none", returned shape is second case. - Returned dimension along dim 1 can be less than n_iter when - stop_early==True and stopping criteria met. - """ - batch_size = data.shape[0] - n_basis = dictionary.shape[1] - device = dictionary.device - - # Calculate stepsize based on largest eigenvalue of - # dictionary.T @ dictionary. - lipschitz_constant = torch.linalg.eigvalsh( - torch.mm(dictionary.T, dictionary))[-1] - stepsize = 1. / lipschitz_constant - self.threshold = stepsize * self.sparsity_penalty - - # Initialize coefficients. - if coeff_0 is not None: - u = coeff_0.to(device) - else: - u = torch.zeros((batch_size, n_basis)).to(device) - coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) - self.coefficients = self.threshold_nonlinearity(u) - residual = torch.mm(dictionary, self.coefficients.T).T - data - - for _ in range(self.n_iter): - if self.stop_early: - old_u = u.clone().detach() - - if self.return_all_coefficients: - coefficients = torch.concat([coefficients, - self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1) - - u -= stepsize * torch.mm(residual, dictionary) - self.coefficients = self.threshold_nonlinearity(u) - - if self.stop_early: - # Stopping condition is function of change of the coefficients. - a_change = torch.mean( - torch.abs(old_u - u) / stepsize) - if a_change < self.epsilon: - break - - residual = torch.mm(dictionary, self.coefficients.T).T - data - u = self.coefficients - - if use_checknan: - self.checknan(u, "coefficients") - - coefficients = torch.concat([coefficients, self.coefficients.clone().unsqueeze(1)], dim=1) - return torch.squeeze(coefficients) - - -class LSM(InferenceMethod): - def __init__(self, n_iter=100, n_iter_LSM=6, beta=0.01, alpha=80.0, - sigma=0.005, sparse_threshold=10**-2, solver=None, - return_all_coefficients=False): - """Infer latent coefficients generating data given dictionary. - Method implemented according to "Group Sparse Coding with a Laplacian - Scale Mixture Prior" (P. J. Garrigues & B. A. Olshausen, 2010) - - Parameters - ---------- - n_iter : int, default=100 - Number of iterations to run for an optimizer - n_iter_LSM : int, default=6 - Number of iterations to run the outer loop of LSM - beta : float, default=0.01 - LSM parameter used to update lambdas - alpha : float, default=80.0 - LSM parameter used to update lambdas - sigma : float, default=0.005 - LSM parameter used to compute the loss function - sparse_threshold : float, default=10**-2 - Threshold used to discard smallest coefficients in the final - solution SM parameter used to compute the loss function - return_all_coefficients : bool, default=False - Returns all coefficients during inference procedure if True - User beware: If n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This - function typically used for debugging. - solver : default=None - - References - ---------- - [1] Garrigues, P., & Olshausen, B. (2010). Group sparse coding with - a laplacian scale mixture prior. Advances in neural information - processing systems, 23. - """ - super().__init__(solver) - self.n_iter = n_iter - self.n_iter_LSM = n_iter_LSM - self.beta = beta - self.alpha = alpha - self.sigma = sigma - self.sparse_threshold = sparse_threshold - self.return_all_coefficients = return_all_coefficients - - def lsm_Loss(self, data, dictionary, coefficients, lambdas, sigma): - """Compute LSM loss according to equation (7) in (P. J. Garrigues & - B. A. Olshausen, 2010) - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - Data to be used in sparse coding - dictionary : array-like, shape [n_features, n_basis] - Dictionary to be used - coefficients : array-like, shape [batch_size, n_basis] - The current values of coefficients - lambdas : array-like, shape [batch_size, n_basis] - The current values of regularization coefficient for all basis - sigma : float, default=0.005 - LSM parameter used to compute the loss functions - - Returns - ------- - loss : array-like, shape [batch_size, 1] - Loss values for each data sample - """ - - # Compute loss - preds = torch.mm(dictionary, coefficients.t()).t() - mse_loss = (1/(2*(sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) - sparse_loss = torch.sum(lambdas * torch.abs(coefficients), dim=1, keepdim=True) - loss = mse_loss + sparse_loss - return loss - - def infer(self, data, dictionary): - """Infer coefficients for each image in data using dict elements - dictionary using Laplacian Scale Mixture (LSM) - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - Data to be used in sparse coding - dictionary : array-like, shape [n_features, n_basis] - Dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like, shape [batch_size, n_basis] - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Initialize coefficients for the whole batch - coefficients = torch.zeros(batch_size, n_basis, device=device, requires_grad=True) - - # Set up optimizer - optimizer = torch.optim.Adam([coefficients], lr=1e-1) - - # Outer loop, set sparsity penalties (lambdas). - for i in range(self.n_iter_LSM): - # Compute the initial values of lambdas - lambdas = ( - (self.alpha + 1) - / (self.beta + torch.abs(coefficients.detach())) - ) - - # Inner loop, optimize coefficients w/ current sparsity penalties. - # Exits early if converged before `n_iter`s. - last_loss = None - for t in range(self.n_iter): - # compute LSM loss for the current iteration - loss = self.lsm_Loss( - data=data, - dictionary=dictionary, - coefficients=coefficients, - lambdas=lambdas, - sigma=self.sigma, - ) - loss = torch.sum(loss) - - # Backward pass: compute gradient and update model parameters. - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # Break if coefficients have converged. - if ( - last_loss is not None - and loss > 1.05 * last_loss - ): - break - - last_loss = loss - - # Sparsify the final solution by discarding the small coefficients - coefficients.data[torch.abs(coefficients.data) - < self.sparse_threshold] = 0 - - return coefficients.detach() - - -class PyTorchOptimizer(InferenceMethod): - def __init__(self, optimizer_f, loss_f, n_iter=100, solver=None): - """Infer coefficients using provided loss functional and optimizer - - Parameters - ---------- - optimizer : function handle - Pytorch optimizer handle have single parameter: - (coefficients) - where coefficients is of shape [batch_size, n_basis] - loss_f : function handle - Must have parameters: - (data, dictionary, coefficients) - where data is of shape [batch_size, n_features] - and loss_f must return tensor of size [batch_size,] - n_iter : int, default=100 - Number of iterations to run for an optimizer - solver : default=None - """ - super().__init__(solver) - self.optimizer_f = optimizer_f - self.loss_f = loss_f - self.n_iter = n_iter - - def infer(self, data, dictionary, coeff_0=None): - """Infer coefficients for each image in data using dict elements - dictionary by minimizing provided loss function with provided - optimizer. - - Parameters - ---------- - data : array-like, shape [batch_size, n_features] - Data to be used in sparse coding - - dictionary : array-like, shape [n_features, n_basis] - Dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like, shape [batch_size, n_basis] - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Initialize coefficients for the whole batch - - # initialize - if coeff_0 is not None: - coefficients = coeff_0.requires_grad_(True) - else: - coefficients = torch.zeros((batch_size, n_basis), requires_grad=True, device=device) - - optimizer = self.optimizer_f([coefficients]) - - for i in range(self.n_iter): - - # compute LSM loss for the current iteration - loss = self.loss_f( - data=data, - dictionary=dictionary, - coefficients=coefficients, - ) - - optimizer.zero_grad() - - # Backward pass: compute gradient of the loss with respect to - # model parameters - loss.backward(torch.ones((batch_size,), device=device)) - - # Calling the step function on an Optimizer makes an update to its - # parameters - optimizer.step() - - return coefficients.detach() - - -class IHT(InferenceMethod): - """ - Infer coefficients for each image in data using elements dictionary. - Method description can be traced to - "Iterative Hard Thresholding for Compressed Sensing" (T. Blumensath & M. E. Davies, 2009) - """ - - def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False): - ''' - - Parameters - ---------- - sparsity : scalar (1,) - Sparsity of the solution. The number of active coefficients will be set - to ceil(sparsity * data_dim) at the end of each iterative update. - n_iter : scalar (1,) default=100 - number of iterations to run for an inference method - return_all_coefficients : string (1,) default=False - returns all coefficients during inference procedure if True - user beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This function typically used for - debugging - solver : default=None - ''' - super().__init__(solver) - self.n_iter = n_iter - self.sparsity = sparsity - self.return_all_coefficients = return_all_coefficients - - def infer(self, data, dictionary): - """ - Infer coefficients for each image in data using dict elements dictionary using Iterative Hard Thresholding (IHT) - - Parameters - ---------- - data : array-like (batch_size, n_features) - data to be used in sparse coding - dictionary : array-like, (n_features, n_basis) - dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like (batch_size, n_basis) - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) - - # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) - - for _ in range(self.n_iter): - # Compute the prediction given the current coefficients - preds = coefficients @ dictionary.T # [batch_size, n_features] - - # Compute the residual - delta = data - preds # [batch_size, n_features] - - # Compute the similarity between the residual and the atoms in the dictionary - update = delta @ dictionary # [batch_size, n_basis] - coefficients = coefficients + update # [batch_size, n_basis] - - # Apply kWTA nonlinearity - topK_values, indices = torch.topk(torch.abs(coefficients), K, dim=1) - - # Reconstruct coefficients using the output of torch.topk - coefficients = ( - torch.sign(coefficients) - * torch.zeros(batch_size, n_basis, device=device).scatter_(1, indices, topK_values) - ) - - return coefficients.detach() - - -class MP(InferenceMethod): - """ - Infer coefficients for each image in data using elements dictionary. - Method description can be traced - to "Matching Pursuits with Time-Frequency Dictionaries" (S. G. Mallat & Z. Zhang, 1993) - """ - - def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' - - Parameters - ---------- - sparsity : scalar (1,) - sparsity of the solution - return_all_coefficients : string (1,) default=False - returns all coefficients during inference procedure if True - user beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This function typically used for - debugging - solver : default=None - ''' - super().__init__(solver) - self.sparsity = sparsity - self.return_all_coefficients = return_all_coefficients - - def infer(self, data, dictionary): - """ - Infer coefficients for each image in data using dict elements dictionary using Matching Pursuit (MP) - - Parameters - ---------- - data : array-like (batch_size, n_features) - data to be used in sparse coding - dictionary : array-like, (n_features, n_basis) - dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like (batch_size, n_basis) - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) - - # Get dictionary norms in case atoms are not normalized - dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) - - # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) - - residual = data.clone() # [batch_size, n_features] - - for _ in range(K): - # Select which (coefficient, basis function) pair to update using the inner product. - candidate_coefs = residual @ dictionary # [batch_size, n_basis] - top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] - - # Update the coefficient. - top_coefs = candidate_coefs[torch.arange(batch_size), top_coef_idxs] # [batch_size] - coefficients[torch.arange(batch_size), top_coef_idxs] = top_coefs - - # Explain away/subtract the chosen coefficient and corresponding basis from the residual. - top_coef_bases = dictionary[..., top_coef_idxs] # [n_features, batch_size] - residual = residual - top_coefs.reshape(batch_size, 1) * top_coef_bases.T # [batch_size, n_features] - - return coefficients.detach() - - -class OMP(InferenceMethod): - """ - Infer coefficients for each image in data using elements dictionary. - Method description can be traced to: - "Orthogonal Matching Pursuit: Recursive Function Approximation with Application to Wavelet Decomposition" - (Y. Pati & R. Rezaiifar & P. Krishnaprasad, 1993) - """ - - def __init__(self, sparsity, solver=None, return_all_coefficients=False): - ''' - - Parameters - ---------- - sparsity : scalar (1,) - sparsity of the solution - return_all_coefficients : string (1,) default=False - returns all coefficients during inference procedure if True - user beware: if n_iter is large, setting this parameter to True - can result in large memory usage/potential exhaustion. This function typically used for - debugging - solver : default=None - ''' - super().__init__(solver) - self.sparsity = sparsity - self.return_all_coefficients = return_all_coefficients - - def infer(self, data, dictionary): - """ - Infer coefficients for each image in data using dict elements dictionary using Orthogonal Matching Pursuit (OMP) - - Parameters - ---------- - data : array-like (batch_size, n_features) - data to be used in sparse coding - dictionary : array-like, (n_features, n_basis) - dictionary to be used to get the coefficients - - Returns - ------- - coefficients : array-like (batch_size, n_basis) - """ - # Get input characteristics - batch_size, n_features = data.shape - n_features, n_basis = dictionary.shape - device = dictionary.device - - # Define signal sparsity - K = np.ceil(self.sparsity*n_basis).astype(int) - - # Get dictionary norms in case atoms are not normalized - dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) - - # Initialize coefficients for the whole batch - coefficients = torch.zeros( - batch_size, n_basis, requires_grad=False, device=device) - - residual = data.clone() # [batch_size, n_features] - - # The basis functions that are used to infer the coefficients will be updated each time. - used_basis_fns = torch.zeros((batch_size, n_basis), dtype=torch.bool) - - for t in range(K): - # Select which (coefficient, basis function) pair to update using the inner product. - candidate_coefs = residual @ dictionary # [batch_size, n_basis] - top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] - used_basis_fns[torch.arange(batch_size), top_coef_idxs] = True - - # Update the coefficients - used_dictionary = dictionary[..., used_basis_fns.nonzero()[:, 1]].reshape((n_features, batch_size, t + 1)) - - (used_coefficients, _, _, _) = torch.linalg.lstsq( - used_dictionary.permute((1, 0, 2)), # [batch_size, n_features, t + 1] - data.reshape(batch_size, n_features, 1), - ) # [batch_size, t + 1, 1] - coefficients[used_basis_fns] = used_coefficients.reshape(-1) - - # Update the residual. - residual = data.clone() - coefficients @ dictionary.T # [batch_size, n_features] - - return coefficients.detach() diff --git a/sparsecoding/inference/__init__.py b/sparsecoding/inference/__init__.py new file mode 100644 index 0000000..8ac1dec --- /dev/null +++ b/sparsecoding/inference/__init__.py @@ -0,0 +1,21 @@ +from .iht import IHT +from .inference_method import InferenceMethod +from .ista import ISTA +from .lca import LCA +from .lsm import LSM +from .mp import MP +from .omp import OMP +from .pytorch_optimizer import PyTorchOptimizer +from .vanilla import Vanilla + +__all__ = [ + "IHT", + "InferenceMethod", + "ISTA", + "LCA", + "LSM", + "MP", + "OMP", + "PyTorchOptimizer", + "Vanilla", +] diff --git a/sparsecoding/inference/iht.py b/sparsecoding/inference/iht.py new file mode 100644 index 0000000..467cfc5 --- /dev/null +++ b/sparsecoding/inference/iht.py @@ -0,0 +1,81 @@ +import numpy as np +import torch + +from .inference_method import InferenceMethod + + +class IHT(InferenceMethod): + """ + Infer coefficients for each image in data using elements dictionary. + Method description can be traced to + "Iterative Hard Thresholding for Compressed Sensing" (T. Blumensath & M. E. Davies, 2009) + """ + + def __init__(self, sparsity, n_iter=10, solver=None, return_all_coefficients=False): + """ + + Parameters + ---------- + sparsity : scalar (1,) + Sparsity of the solution. The number of active coefficients will be set + to ceil(sparsity * data_dim) at the end of each iterative update. + n_iter : scalar (1,) default=100 + number of iterations to run for an inference method + return_all_coefficients : string (1,) default=False + returns all coefficients during inference procedure if True + user beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This function typically used for + debugging + solver : default=None + """ + super().__init__(solver) + self.n_iter = n_iter + self.sparsity = sparsity + self.return_all_coefficients = return_all_coefficients + + def infer(self, data, dictionary): + """ + Infer coefficients for each image in data using dict elements dictionary using Iterative Hard Thresholding (IHT) + + Parameters + ---------- + data : array-like (batch_size, n_features) + data to be used in sparse coding + dictionary : array-like, (n_features, n_basis) + dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like (batch_size, n_basis) + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Define signal sparsity + K = np.ceil(self.sparsity * n_basis).astype(int) + + # Initialize coefficients for the whole batch + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) + + for _ in range(self.n_iter): + # Compute the prediction given the current coefficients + preds = coefficients @ dictionary.T # [batch_size, n_features] + + # Compute the residual + delta = data - preds # [batch_size, n_features] + + # Compute the similarity between the residual and the atoms in the dictionary + update = delta @ dictionary # [batch_size, n_basis] + coefficients = coefficients + update # [batch_size, n_basis] + + # Apply kWTA nonlinearity + topK_values, indices = torch.topk(torch.abs(coefficients), K, dim=1) + + # Reconstruct coefficients using the output of torch.topk + coefficients = torch.sign(coefficients) * torch.zeros(batch_size, n_basis, device=device).scatter_( + 1, indices, topK_values + ) + + return coefficients.detach() diff --git a/sparsecoding/inference/inference_method.py b/sparsecoding/inference/inference_method.py new file mode 100644 index 0000000..62c97ec --- /dev/null +++ b/sparsecoding/inference/inference_method.py @@ -0,0 +1,73 @@ +import torch + + +class InferenceMethod: + """Base class for inference method.""" + + def __init__(self, solver): + """ + Parameters + ---------- + """ + self.solver = solver + + def initialize(self, a): + """Define initial coefficients. + + Parameters + ---------- + + Returns + ------- + """ + raise NotImplementedError + + def grad(self): + """Compute the gradient step. + + Parameters + ---------- + + Returns + ------- + """ + raise NotImplementedError + + def infer(self, dictionary, data, coeff_0=None, use_checknan=False): + """Infer the coefficients given a dataset and dictionary. + + Parameters + ---------- + dictionary : array-like, shape [n_features,n_basis] + + data : array-like, shape [n_samples,n_features] + + coeff_0 : array-like, shape [n_samples,n_basis], optional + Initial coefficient values. + use_checknan : bool, default=False + Check for nans in coefficients on each iteration + + Returns + ------- + coefficients : array-like, shape [n_samples,n_basis] + """ + raise NotImplementedError + + @staticmethod + def checknan(data=torch.tensor(0), name="data"): + """Check for nan values in data. + + Parameters + ---------- + data : array-like, optional + Data to check for nans + name : str, default="data" + Name to add to error, if one is thrown + + Raises + ------ + ValueError + If the nan found in data + """ + if torch.isnan(data).any(): + raise ValueError("InferenceMethod error: nan in %s." % (name)) diff --git a/sparsecoding/inference/ista.py b/sparsecoding/inference/ista.py new file mode 100644 index 0000000..42bd21b --- /dev/null +++ b/sparsecoding/inference/ista.py @@ -0,0 +1,134 @@ +import torch + +from .inference_method import InferenceMethod + + +class ISTA(InferenceMethod): + def __init__( + self, + n_iter=100, + sparsity_penalty=1e-2, + stop_early=False, + epsilon=1e-2, + solver=None, + return_all_coefficients=False, + ): + """Iterative shrinkage-thresholding algorithm for solving LASSO problems. + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run + sparsity_penalty : float, default=0.2 + + stop_early : bool, default=False + Stops dynamics early based on change in coefficents + epsilon : float, default=1e-2 + Only used if stop_early True, specifies criteria to stop dynamics + return_all_coefficients : str, default=False + Returns all coefficients during inference procedure if True + User beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Beck, A., & Teboulle, M. (2009). A fast iterative + shrinkage-thresholding algorithm for linear inverse problems. + SIAM journal on imaging sciences, 2(1), 183-202. + """ + super().__init__(solver) + self.n_iter = n_iter + self.sparsity_penalty = sparsity_penalty + self.stop_early = stop_early + self.epsilon = epsilon + self.coefficients = None + self.return_all_coefficients = return_all_coefficients + + def threshold_nonlinearity(self, u): + """Soft threshhold function + + Parameters + ---------- + u : array-likes, shape [batch_size, n_basis] + Membrane potentials + + Returns + ------- + a : array-like, shape [batch_size, n_basis] + activations + """ + a = (torch.abs(u) - self.threshold).clamp(min=0.0) + a = torch.sign(u) * a + return a + + def infer(self, data, dictionary, coeff_0=None, use_checknan=False): + """Infer coefficients for each image in data using dictionary elements. + Uses ISTA (Beck & Taboulle 2009), equations 1.4 and 1.5. + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + + dictionary : array-like, shape [n_features, n_basis] + + coeff_0 : array-like, shape [n_samples, n_basis], optional + Initial coefficient values + use_checknan : bool, default=False + Check for nans in coefficients on each iteration. Setting this to + False can speed up inference time. + Returns + ------- + coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] + First case occurs if return_all_coefficients == "none". If + return_all_coefficients != "none", returned shape is second case. + Returned dimension along dim 1 can be less than n_iter when + stop_early==True and stopping criteria met. + """ + batch_size = data.shape[0] + n_basis = dictionary.shape[1] + device = dictionary.device + + # Calculate stepsize based on largest eigenvalue of + # dictionary.T @ dictionary. + lipschitz_constant = torch.linalg.eigvalsh(torch.mm(dictionary.T, dictionary))[-1] + stepsize = 1.0 / lipschitz_constant + self.threshold = stepsize * self.sparsity_penalty + + # Initialize coefficients. + if coeff_0 is not None: + u = coeff_0.to(device) + else: + u = torch.zeros((batch_size, n_basis)).to(device) + coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) + self.coefficients = self.threshold_nonlinearity(u) + residual = torch.mm(dictionary, self.coefficients.T).T - data + + for _ in range(self.n_iter): + if self.stop_early: + old_u = u.clone().detach() + + if self.return_all_coefficients: + coefficients = torch.concat( + [coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], + dim=1, + ) + + u -= stepsize * torch.mm(residual, dictionary) + self.coefficients = self.threshold_nonlinearity(u) + + if self.stop_early: + # Stopping condition is function of change of the coefficients. + a_change = torch.mean(torch.abs(old_u - u) / stepsize) + if a_change < self.epsilon: + break + + residual = torch.mm(dictionary, self.coefficients.T).T - data + u = self.coefficients + + if use_checknan: + self.checknan(u, "coefficients") + + coefficients = torch.concat([coefficients, self.coefficients.clone().unsqueeze(1)], dim=1) + return torch.squeeze(coefficients) diff --git a/sparsecoding/inference/ista_test.py b/sparsecoding/inference/ista_test.py new file mode 100644 index 0000000..397c555 --- /dev/null +++ b/sparsecoding/inference/ista_test.py @@ -0,0 +1,40 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """Test that ISTA inference returns expected shapes.""" + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.ISTA(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + inference_method = inference.ISTA(N_ITER, return_all_coefficients=True) + a = inference_method.infer(data, bars_dictionary_fixture) + assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) + + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """Test that ISTA inference recovers the correct weights.""" + N_ITER = 5000 + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.ISTA(n_iter=N_ITER) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/sparsecoding/inference/lca.py b/sparsecoding/inference/lca.py new file mode 100644 index 0000000..ad5023f --- /dev/null +++ b/sparsecoding/inference/lca.py @@ -0,0 +1,180 @@ +import torch + +from .inference_method import InferenceMethod + + +class LCA(InferenceMethod): + def __init__( + self, + n_iter=100, + coeff_lr=1e-3, + threshold=0.1, + stop_early=False, + epsilon=1e-2, + solver=None, + return_all_coefficients="none", + nonnegative=False, + ): + """Method implemented according locally competative algorithm (LCA) + with the ideal soft thresholding function. + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run + coeff_lr : float, default=1e-3 + Update rate of coefficient dynamics + threshold : float, default=0.1 + Threshold for non-linearity + stop_early : bool, default=False + Stops dynamics early based on change in coefficents + epsilon : float, default=1e-2 + Only used if stop_early True, specifies criteria to stop dynamics + nonnegative : bool, default=False + Constrain coefficients to be nonnegative + return_all_coefficients : str, {"none", "membrane", "active"}, default="none" + Returns all coefficients during inference procedure if not equal + to "none". If return_all_coefficients=="membrane", membrane + potentials (u) returned. If return_all_coefficients=="active", + active units (a) (output of thresholding function over u) returned. + User beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Rozell, C. J., Johnson, D. H., Baraniuk, R. G., & Olshausen, + B. A. (2008). Sparse coding via thresholding and local competition + in neural circuits. Neural computation, 20(10), 2526-2563. + """ + super().__init__(solver) + self.threshold = threshold + self.coeff_lr = coeff_lr + self.stop_early = stop_early + self.epsilon = epsilon + self.n_iter = n_iter + self.nonnegative = nonnegative + if return_all_coefficients not in ["none", "membrane", "active"]: + raise ValueError( + "Invalid input for return_all_coefficients. Valid" 'inputs are: "none", "membrane", "active".' + ) + self.return_all_coefficients = return_all_coefficients + + def threshold_nonlinearity(self, u): + """Soft threshhold function + + Parameters + ---------- + u : array-like, shape [batch_size, n_basis] + Membrane potentials + + Returns + ------- + a : array-like, shape [batch_size, n_basis] + Activations + """ + if self.nonnegative: + a = (u - self.threshold).clamp(min=0.0) + else: + a = (torch.abs(u) - self.threshold).clamp(min=0.0) + a = torch.sign(u) * a + return a + + def grad(self, b, G, u, a): + """Compute the gradient step on membrane potentials + + Parameters + ---------- + b : array-like, shape [batch_size, n_coefficients] + Driver signal for coefficients + G : array-like, shape [n_coefficients, n_coefficients] + Inhibition matrix + a : array-like, shape [batch_size, n_coefficients] + Currently active coefficients + + Returns + ------- + du : array-like, shape [batch_size, n_coefficients] + Gradient of membrane potentials + """ + du = b - u - (G @ a.t()).t() + return du + + def infer(self, data, dictionary, coeff_0=None, use_checknan=False): + """Infer coefficients using provided dictionary + + Parameters + ---------- + dictionary : array-like, shape [n_features, n_basis] + + data : array-like, shape [n_samples, n_features] + + coeff_0 : array-like, shape [n_samples, n_basis], optional + Initial coefficient values + use_checknan : bool, default=False + Check for nans in coefficients on each iteration. Setting this to + False can speed up inference time. + + Returns + ------- + coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] + First case occurs if return_all_coefficients == "none". If + return_all_coefficients != "none", returned shape is second case. + Returned dimension along dim 1 can be less than n_iter when + stop_early==True and stopping criteria met. + """ + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # initialize + if coeff_0 is not None: + u = coeff_0.to(device) + else: + u = torch.zeros((batch_size, n_basis)).to(device) + + coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) + + b = (dictionary.t() @ data.t()).t() + G = dictionary.t() @ dictionary - torch.eye(n_basis).to(device) + for i in range(self.n_iter): + # store old membrane potentials to evalute stop early condition + if self.stop_early: + old_u = u.clone().detach() + + # check return all + if self.return_all_coefficients != "none": + if self.return_all_coefficients == "active": + coefficients = torch.concat( + [ + coefficients, + self.threshold_nonlinearity(u).clone().unsqueeze(1), + ], + dim=1, + ) + else: + coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1) + + # compute new + a = self.threshold_nonlinearity(u) + du = self.grad(b, G, u, a) + u = u + self.coeff_lr * du + + # check stopping condition + if self.stop_early: + relative_change_in_coeff = torch.linalg.norm(old_u - u) / torch.linalg.norm(old_u) + if relative_change_in_coeff < self.epsilon: + break + + if use_checknan: + self.checknan(u, "coefficients") + + # return active units if return_all_coefficients in ["none", "active"] + if self.return_all_coefficients == "membrane": + coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1) + else: + final_coefficients = self.threshold_nonlinearity(u) + coefficients = torch.concat([coefficients, final_coefficients.clone().unsqueeze(1)], dim=1) + + return coefficients.squeeze() diff --git a/sparsecoding/inference/lca_test.py b/sparsecoding/inference/lca_test.py new file mode 100644 index 0000000..0834ddd --- /dev/null +++ b/sparsecoding/inference/lca_test.py @@ -0,0 +1,52 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that LCA inference returns expected shapes. + """ + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LCA(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + for retval in ["active", "membrane"]: + inference_method = inference.LCA(N_ITER, return_all_coefficients=retval) + a = inference_method.infer(data, bars_dictionary_fixture) + assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) + + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that LCA inference recovers the correct weights. + """ + LR = 5e-2 + THRESHOLD = 0.1 + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LCA( + coeff_lr=LR, + threshold=THRESHOLD, + n_iter=N_ITER, + ) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/sparsecoding/inference/lsm.py b/sparsecoding/inference/lsm.py new file mode 100644 index 0000000..83e5ed5 --- /dev/null +++ b/sparsecoding/inference/lsm.py @@ -0,0 +1,148 @@ +import torch + +from .inference_method import InferenceMethod + + +class LSM(InferenceMethod): + def __init__( + self, + n_iter=100, + n_iter_LSM=6, + beta=0.01, + alpha=80.0, + sigma=0.005, + sparse_threshold=10**-2, + solver=None, + return_all_coefficients=False, + ): + """Infer latent coefficients generating data given dictionary. + Method implemented according to "Group Sparse Coding with a Laplacian + Scale Mixture Prior" (P. J. Garrigues & B. A. Olshausen, 2010) + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run for an optimizer + n_iter_LSM : int, default=6 + Number of iterations to run the outer loop of LSM + beta : float, default=0.01 + LSM parameter used to update lambdas + alpha : float, default=80.0 + LSM parameter used to update lambdas + sigma : float, default=0.005 + LSM parameter used to compute the loss function + sparse_threshold : float, default=10**-2 + Threshold used to discard smallest coefficients in the final + solution SM parameter used to compute the loss function + return_all_coefficients : bool, default=False + Returns all coefficients during inference procedure if True + User beware: If n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Garrigues, P., & Olshausen, B. (2010). Group sparse coding with + a laplacian scale mixture prior. Advances in neural information + processing systems, 23. + """ + super().__init__(solver) + self.n_iter = n_iter + self.n_iter_LSM = n_iter_LSM + self.beta = beta + self.alpha = alpha + self.sigma = sigma + self.sparse_threshold = sparse_threshold + self.return_all_coefficients = return_all_coefficients + + def lsm_Loss(self, data, dictionary, coefficients, lambdas, sigma): + """Compute LSM loss according to equation (7) in (P. J. Garrigues & + B. A. Olshausen, 2010) + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + Data to be used in sparse coding + dictionary : array-like, shape [n_features, n_basis] + Dictionary to be used + coefficients : array-like, shape [batch_size, n_basis] + The current values of coefficients + lambdas : array-like, shape [batch_size, n_basis] + The current values of regularization coefficient for all basis + sigma : float, default=0.005 + LSM parameter used to compute the loss functions + + Returns + ------- + loss : array-like, shape [batch_size, 1] + Loss values for each data sample + """ + + # Compute loss + preds = torch.mm(dictionary, coefficients.t()).t() + mse_loss = (1 / (2 * (sigma**2))) * torch.sum(torch.square(data - preds), dim=1, keepdim=True) + sparse_loss = torch.sum(lambdas * torch.abs(coefficients), dim=1, keepdim=True) + loss = mse_loss + sparse_loss + return loss + + def infer(self, data, dictionary): + """Infer coefficients for each image in data using dict elements + dictionary using Laplacian Scale Mixture (LSM) + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + Data to be used in sparse coding + dictionary : array-like, shape [n_features, n_basis] + Dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like, shape [batch_size, n_basis] + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Initialize coefficients for the whole batch + coefficients = torch.zeros(batch_size, n_basis, device=device, requires_grad=True) + + # Set up optimizer + optimizer = torch.optim.Adam([coefficients], lr=1e-1) + + # Outer loop, set sparsity penalties (lambdas). + for i in range(self.n_iter_LSM): + # Compute the initial values of lambdas + lambdas = (self.alpha + 1) / (self.beta + torch.abs(coefficients.detach())) + + # Inner loop, optimize coefficients w/ current sparsity penalties. + # Exits early if converged before `n_iter`s. + last_loss = None + for t in range(self.n_iter): + # compute LSM loss for the current iteration + loss = self.lsm_Loss( + data=data, + dictionary=dictionary, + coefficients=coefficients, + lambdas=lambdas, + sigma=self.sigma, + ) + loss = torch.sum(loss) + + # Backward pass: compute gradient and update model parameters. + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Break if coefficients have converged. + if last_loss is not None and loss > 1.05 * last_loss: + break + + last_loss = loss + + # Sparsify the final solution by discarding the small coefficients + coefficients.data[torch.abs(coefficients.data) < self.sparse_threshold] = 0 + + return coefficients.detach() diff --git a/sparsecoding/inference/lsm_test.py b/sparsecoding/inference/lsm_test.py new file mode 100644 index 0000000..8cd3470 --- /dev/null +++ b/sparsecoding/inference/lsm_test.py @@ -0,0 +1,41 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that LSM inference returns expected shapes. + """ + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LSM(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that LSM inference recovers the correct weights. + """ + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.LSM(n_iter=N_ITER) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=6e-2) diff --git a/sparsecoding/inference/mp.py b/sparsecoding/inference/mp.py new file mode 100644 index 0000000..4305976 --- /dev/null +++ b/sparsecoding/inference/mp.py @@ -0,0 +1,76 @@ +import numpy as np +import torch + +from .inference_method import InferenceMethod + + +class MP(InferenceMethod): + """ + Infer coefficients for each image in data using elements dictionary. + Method description can be traced + to "Matching Pursuits with Time-Frequency Dictionaries" (S. G. Mallat & Z. Zhang, 1993) + """ + + def __init__(self, sparsity, solver=None, return_all_coefficients=False): + """ + + Parameters + ---------- + sparsity : scalar (1,) + sparsity of the solution + return_all_coefficients : string (1,) default=False + returns all coefficients during inference procedure if True + user beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This function typically used for + debugging + solver : default=None + """ + super().__init__(solver) + self.sparsity = sparsity + self.return_all_coefficients = return_all_coefficients + + def infer(self, data, dictionary): + """ + Infer coefficients for each image in data using dict elements dictionary using Matching Pursuit (MP) + + Parameters + ---------- + data : array-like (batch_size, n_features) + data to be used in sparse coding + dictionary : array-like, (n_features, n_basis) + dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like (batch_size, n_basis) + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Define signal sparsity + K = np.ceil(self.sparsity * n_basis).astype(int) + + # Get dictionary norms in case atoms are not normalized + dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) + + # Initialize coefficients for the whole batch + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) + + residual = data.clone() # [batch_size, n_features] + + for _ in range(K): + # Select which (coefficient, basis function) pair to update using the inner product. + candidate_coefs = residual @ dictionary # [batch_size, n_basis] + top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] + + # Update the coefficient. + top_coefs = candidate_coefs[torch.arange(batch_size), top_coef_idxs] # [batch_size] + coefficients[torch.arange(batch_size), top_coef_idxs] = top_coefs + + # Explain away/subtract the chosen coefficient and corresponding basis from the residual. + top_coef_bases = dictionary[..., top_coef_idxs] # [n_features, batch_size] + residual = residual - top_coefs.reshape(batch_size, 1) * top_coef_bases.T # [batch_size, n_features] + + return coefficients.detach() diff --git a/sparsecoding/inference/omp.py b/sparsecoding/inference/omp.py new file mode 100644 index 0000000..db99427 --- /dev/null +++ b/sparsecoding/inference/omp.py @@ -0,0 +1,85 @@ +import numpy as np +import torch + +from .inference_method import InferenceMethod + + +class OMP(InferenceMethod): + """ + Infer coefficients for each image in data using elements dictionary. + Method description can be traced to: + "Orthogonal Matching Pursuit: Recursive Function Approximation with Application to Wavelet Decomposition" + (Y. Pati & R. Rezaiifar & P. Krishnaprasad, 1993) + """ + + def __init__(self, sparsity, solver=None, return_all_coefficients=False): + """ + + Parameters + ---------- + sparsity : scalar (1,) + sparsity of the solution + return_all_coefficients : string (1,) default=False + returns all coefficients during inference procedure if True + user beware: if n_iter is large, setting this parameter to True + can result in large memory usage/potential exhaustion. This function typically used for + debugging + solver : default=None + """ + super().__init__(solver) + self.sparsity = sparsity + self.return_all_coefficients = return_all_coefficients + + def infer(self, data, dictionary): + """ + Infer coefficients for each image in data using dict elements dictionary using Orthogonal Matching Pursuit (OMP) + + Parameters + ---------- + data : array-like (batch_size, n_features) + data to be used in sparse coding + dictionary : array-like, (n_features, n_basis) + dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like (batch_size, n_basis) + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Define signal sparsity + K = np.ceil(self.sparsity * n_basis).astype(int) + + # Get dictionary norms in case atoms are not normalized + dictionary_norms = torch.norm(dictionary, p=2, dim=0, keepdim=True) + + # Initialize coefficients for the whole batch + coefficients = torch.zeros(batch_size, n_basis, requires_grad=False, device=device) + + residual = data.clone() # [batch_size, n_features] + + # The basis functions that are used to infer the coefficients will be updated each time. + used_basis_fns = torch.zeros((batch_size, n_basis), dtype=torch.bool) + + for t in range(K): + # Select which (coefficient, basis function) pair to update using the inner product. + candidate_coefs = residual @ dictionary # [batch_size, n_basis] + top_coef_idxs = torch.argmax(torch.abs(candidate_coefs) / dictionary_norms, dim=1) # [batch_size] + used_basis_fns[torch.arange(batch_size), top_coef_idxs] = True + + # Update the coefficients + used_dictionary = dictionary[..., used_basis_fns.nonzero()[:, 1]].reshape((n_features, batch_size, t + 1)) + + (used_coefficients, _, _, _) = torch.linalg.lstsq( + used_dictionary.permute((1, 0, 2)), # [batch_size, n_features, t + 1] + data.reshape(batch_size, n_features, 1), + ) # [batch_size, t + 1, 1] + coefficients[used_basis_fns] = used_coefficients.reshape(-1) + + # Update the residual. + residual = data.clone() - coefficients @ dictionary.T # [batch_size, n_features] + + return coefficients.detach() diff --git a/sparsecoding/inference/pytorch_optimizer.py b/sparsecoding/inference/pytorch_optimizer.py new file mode 100644 index 0000000..7c13079 --- /dev/null +++ b/sparsecoding/inference/pytorch_optimizer.py @@ -0,0 +1,81 @@ +import torch + +from .inference_method import InferenceMethod + + +class PyTorchOptimizer(InferenceMethod): + def __init__(self, optimizer_f, loss_f, n_iter=100, solver=None): + """Infer coefficients using provided loss functional and optimizer + + Parameters + ---------- + optimizer : function handle + Pytorch optimizer handle have single parameter: + (coefficients) + where coefficients is of shape [batch_size, n_basis] + loss_f : function handle + Must have parameters: + (data, dictionary, coefficients) + where data is of shape [batch_size, n_features] + and loss_f must return tensor of size [batch_size,] + n_iter : int, default=100 + Number of iterations to run for an optimizer + solver : default=None + """ + super().__init__(solver) + self.optimizer_f = optimizer_f + self.loss_f = loss_f + self.n_iter = n_iter + + def infer(self, data, dictionary, coeff_0=None): + """Infer coefficients for each image in data using dict elements + dictionary by minimizing provided loss function with provided + optimizer. + + Parameters + ---------- + data : array-like, shape [batch_size, n_features] + Data to be used in sparse coding + + dictionary : array-like, shape [n_features, n_basis] + Dictionary to be used to get the coefficients + + Returns + ------- + coefficients : array-like, shape [batch_size, n_basis] + """ + # Get input characteristics + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # Initialize coefficients for the whole batch + + # initialize + if coeff_0 is not None: + coefficients = coeff_0.requires_grad_(True) + else: + coefficients = torch.zeros((batch_size, n_basis), requires_grad=True, device=device) + + optimizer = self.optimizer_f([coefficients]) + + for i in range(self.n_iter): + + # compute LSM loss for the current iteration + loss = self.loss_f( + data=data, + dictionary=dictionary, + coefficients=coefficients, + ) + + optimizer.zero_grad() + + # Backward pass: compute gradient of the loss with respect to + # model parameters + loss.backward(torch.ones((batch_size,), device=device)) + + # Calling the step function on an Optimizer makes an update to its + # parameters + optimizer.step() + + return coefficients.detach() diff --git a/sparsecoding/inference/pytorch_optimizer_test.py b/sparsecoding/inference/pytorch_optimizer_test.py new file mode 100644 index 0000000..87961af --- /dev/null +++ b/sparsecoding/inference/pytorch_optimizer_test.py @@ -0,0 +1,80 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def lasso_loss(data, dictionary, coefficients, sparsity_penalty): + """ + Generic MSE + l1-norm loss. + """ + batch_size = data.shape[0] + datahat = (dictionary @ coefficients.t()).t() + + mse_loss = torch.linalg.vector_norm(datahat - data, dim=1).square() + sparse_loss = torch.sum(torch.abs(coefficients), axis=1) + + total_loss = (mse_loss + sparsity_penalty * sparse_loss) / batch_size + return total_loss + + +def loss_fn(data, dictionary, coefficients): + return lasso_loss( + data, + dictionary, + coefficients, + sparsity_penalty=1.0, + ) + + +def optimizer_fn(coefficients): + return torch.optim.Adam( + coefficients, + lr=0.1, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0, + ) + + +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that PyTorchOptimizer inference returns expected shapes. + """ + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.PyTorchOptimizer( + optimizer_fn, + loss_fn, + n_iter=10, + ) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that PyTorchOptimizer inference recovers the correct weights. + """ + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.PyTorchOptimizer( + optimizer_fn, + loss_fn, + n_iter=N_ITER, + ) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=2e-1, rtol=1e-1) diff --git a/sparsecoding/inference/vanilla.py b/sparsecoding/inference/vanilla.py new file mode 100644 index 0000000..11bbfa4 --- /dev/null +++ b/sparsecoding/inference/vanilla.py @@ -0,0 +1,131 @@ +import torch + +from .inference_method import InferenceMethod + + +class Vanilla(InferenceMethod): + def __init__( + self, + n_iter=100, + coeff_lr=1e-3, + sparsity_penalty=0.2, + stop_early=False, + epsilon=1e-2, + solver=None, + return_all_coefficients=False, + ): + """Gradient descent with Euler's method on model in Olshausen & Field + (1997) with laplace prior over coefficients (corresponding to l-1 norm + penalty). + + Parameters + ---------- + n_iter : int, default=100 + Number of iterations to run + coeff_lr : float, default=1e-3 + Update rate of coefficient dynamics + sparsity_penalty : float, default=0.2 + + stop_early : bool, default=False + Stops dynamics early based on change in coefficents + epsilon : float, default=1e-2 + Only used if stop_early True, specifies criteria to stop dynamics + return_all_coefficients : str, default=False + Returns all coefficients during inference procedure if True + User beware: If n_iter is large, setting this parameter to True + Can result in large memory usage/potential exhaustion. This + function typically used for debugging. + solver : default=None + + References + ---------- + [1] Olshausen, B. A., & Field, D. J. (1997). Sparse coding with an + overcomplete basis set: A strategy employed by V1?. Vision research, + 37(23), 3311-3325. + """ + super().__init__(solver) + self.coeff_lr = coeff_lr + self.sparsity_penalty = sparsity_penalty + self.stop_early = stop_early + self.epsilon = epsilon + self.n_iter = n_iter + self.return_all_coefficients = return_all_coefficients + + def grad(self, residual, dictionary, a): + """Compute the gradient step on coefficients + + Parameters + ---------- + residual : array-like, shape [batch_size, n_features] + Residual between reconstructed image and original + dictionary : array-like, shape [n_features,n_coefficients] + Dictionary + a : array-like, shape [batch_size, n_coefficients] + Coefficients + + Returns + ------- + da : array-like, shape [batch_size, n_coefficients] + Gradient of membrane potentials + """ + da = (dictionary.t() @ residual.t()).t() - self.sparsity_penalty * torch.sign(a) + return da + + def infer(self, data, dictionary, coeff_0=None, use_checknan=False): + """Infer coefficients using provided dictionary + + Parameters + ---------- + dictionary : array-like, shape [n_features, n_basis] + Dictionary + data : array like, shape [n_samples, n_features] + + coeff_0 : array-like, shape [n_samples, n_basis], optional + Initial coefficient values + use_checknan : bool, default=False + check for nans in coefficients on each iteration. Setting this to + False can speed up inference time + + Returns + ------- + coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis] + First case occurs if return_all_coefficients == "none". If + return_all_coefficients != "none", returned shape is second case. + Returned dimension along dim 1 can be less than n_iter when + stop_early==True and stopping criteria met. + """ + batch_size, n_features = data.shape + n_features, n_basis = dictionary.shape + device = dictionary.device + + # initialize + if coeff_0 is not None: + a = coeff_0.to(device) + else: + a = torch.rand((batch_size, n_basis)).to(device) - 0.5 + + coefficients = torch.zeros((batch_size, 0, n_basis)).to(device) + + residual = data - (dictionary @ a.t()).t() + for i in range(self.n_iter): + + if self.return_all_coefficients: + coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) + + if self.stop_early: + old_a = a.clone().detach() + + da = self.grad(residual, dictionary, a) + a = a + self.coeff_lr * da + + if self.stop_early: + if torch.linalg.norm(old_a - a) / torch.linalg.norm(old_a) < self.epsilon: + break + + residual = data - (dictionary @ a.t()).t() + + if use_checknan: + self.checknan(a, "coefficients") + + coefficients = torch.concat([coefficients, a.clone().unsqueeze(1)], dim=1) + return torch.squeeze(coefficients) diff --git a/sparsecoding/inference/vanilla_test.py b/sparsecoding/inference/vanilla_test.py new file mode 100644 index 0000000..9c556e5 --- /dev/null +++ b/sparsecoding/inference/vanilla_test.py @@ -0,0 +1,46 @@ +import torch + +from sparsecoding import inference +from sparsecoding.datasets import BarsDataset +from sparsecoding.test_utils import assert_allclose, assert_shape_equal + + +def test_shape( + patch_size_fixture: int, + dataset_size_fixture: int, + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that Vanilla inference returns expected shapes. + """ + N_ITER = 10 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.Vanilla(N_ITER) + a = inference_method.infer(data, bars_dictionary_fixture) + assert_shape_equal(a, dataset.weights) + + inference_method = inference.Vanilla(N_ITER, return_all_coefficients=True) + a = inference_method.infer(data, bars_dictionary_fixture) + assert a.shape == (dataset_size_fixture, N_ITER + 1, 2 * patch_size_fixture) + + +def test_inference( + bars_dictionary_fixture: torch.Tensor, + bars_datas_fixture: list[torch.Tensor], + bars_datasets_fixture: list[BarsDataset], +): + """ + Test that Vanilla inference recovers the correct weights. + """ + LR = 5e-2 + N_ITER = 1000 + + for (data, dataset) in zip(bars_datas_fixture, bars_datasets_fixture): + inference_method = inference.Vanilla(coeff_lr=LR, n_iter=N_ITER) + + a = inference_method.infer(data, bars_dictionary_fixture) + + assert_allclose(a, dataset.weights, atol=5e-2) diff --git a/sparsecoding/priors.py b/sparsecoding/priors.py deleted file mode 100644 index 026ac43..0000000 --- a/sparsecoding/priors.py +++ /dev/null @@ -1,167 +0,0 @@ -import torch -from torch.distributions.laplace import Laplace - -from abc import ABC, abstractmethod - - -class Prior(ABC): - """A distribution over weights. - - Parameters - ---------- - weights_dim : int - Number of weights for each sample. - """ - @abstractmethod - def D(self): - """ - Number of weights per sample. - """ - - @abstractmethod - def sample( - self, - num_samples: int = 1, - ): - """Sample weights from the prior. - - Parameters - ---------- - num_samples : int, default=1 - Number of samples. - - Returns - ------- - samples : Tensor, shape [num_samples, self.D] - Sampled weights. - """ - - -class SpikeSlabPrior(Prior): - """Prior where weights are drawn from a "spike-and-slab" distribution. - - The "spike" is at 0 and the "slab" is Laplacian. - - See: - https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf - for a good review of the spike-and-slab model. - - Parameters - ---------- - dim : int - Number of weights per sample. - p_spike : float - The probability of the weight being 0. - scale : float - The "scale" of the Laplacian distribution (larger is wider). - positive_only : bool - Ensure that the weights are positive by taking the absolute value - of weights sampled from the Laplacian. - """ - - def __init__( - self, - dim: int, - p_spike: float, - scale: float, - positive_only: bool = True, - ): - if dim < 0: - raise ValueError(f"`dim` should be nonnegative, got {dim}.") - if p_spike < 0 or p_spike > 1: - raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.") - if scale <= 0: - raise ValueError(f"`scale` must be positive, got {scale}.") - - self.dim = dim - self.p_spike = p_spike - self.scale = scale - self.positive_only = positive_only - - @property - def D(self): - return self.dim - - def sample(self, num_samples: int): - N = num_samples - - zero_weights = torch.zeros((N, self.D), dtype=torch.float32) - slab_weights = Laplace( - loc=zero_weights, - scale=torch.full((N, self.D), self.scale, dtype=torch.float32), - ).sample() # [N, D] - - if self.positive_only: - slab_weights = torch.abs(slab_weights) - - spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike - - weights = torch.where( - spike_over_slab, - zero_weights, - slab_weights, - ) - - return weights - - -class L0Prior(Prior): - """Prior with a distribution over the l0-norm of the weights. - - A class of priors where the weights are binary; - the distribution is over the l0-norm of the weight vector - (how many weights are active). - - Parameters - ---------- - prob_distr : Tensor, shape [D], dtype float32 - Probability distribution over the l0-norm of the weights. - """ - - def __init__( - self, - prob_distr: torch.Tensor, - ): - if prob_distr.dim() != 1: - raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") - if prob_distr.dtype != torch.float32: - raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") - if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)): - raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") - - self.prob_distr = prob_distr - - @property - def D(self): - return self.prob_distr.shape[0] - - def sample( - self, - num_samples: int - ): - N = num_samples - - num_active_weights = 1 + torch.multinomial( - input=self.prob_distr, - num_samples=num_samples, - replacement=True, - ) # [N] - - d_idxs = torch.arange(self.D) - active_idx_mask = ( - d_idxs.reshape(1, self.D) - < num_active_weights.reshape(N, 1) - ) # [N, self.D] - - n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] - # Need to shuffle here so that it's not always the first weights that are active. - shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] - shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] - - # [num_active_weights], [num_active_weights] - active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] - - weights = torch.zeros((N, self.D), dtype=torch.float32) - weights[active_weight_idxs] += 1. - - return weights diff --git a/sparsecoding/priors/__init__.py b/sparsecoding/priors/__init__.py new file mode 100644 index 0000000..75314d7 --- /dev/null +++ b/sparsecoding/priors/__init__.py @@ -0,0 +1,9 @@ +from .l0_prior import L0Prior +from .prior import Prior +from .spike_slab_prior import SpikeSlabPrior + +__all__ = [ + "Prior", + "L0Prior", + "SpikeSlabPrior", +] diff --git a/sparsecoding/priors/l0_prior.py b/sparsecoding/priors/l0_prior.py new file mode 100644 index 0000000..590bcad --- /dev/null +++ b/sparsecoding/priors/l0_prior.py @@ -0,0 +1,59 @@ +import torch + +from .prior import Prior + + +class L0Prior(Prior): + """Prior with a distribution over the l0-norm of the weights. + + A class of priors where the weights are binary; + the distribution is over the l0-norm of the weight vector + (how many weights are active). + + Parameters + ---------- + prob_distr : Tensor, shape [D], dtype float32 + Probability distribution over the l0-norm of the weights. + """ + + def __init__( + self, + prob_distr: torch.Tensor, + ): + if prob_distr.dim() != 1: + raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") + if prob_distr.dtype != torch.float32: + raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") + if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)): + raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") + + self.prob_distr = prob_distr + + @property + def D(self): + return self.prob_distr.shape[0] + + def sample(self, num_samples: int): + N = num_samples + + num_active_weights = 1 + torch.multinomial( + input=self.prob_distr, + num_samples=num_samples, + replacement=True, + ) # [N] + + d_idxs = torch.arange(self.D) + active_idx_mask = d_idxs.reshape(1, self.D) < num_active_weights.reshape(N, 1) # [N, self.D] + + n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] + # Need to shuffle here so that it's not always the first weights that are active. + shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] + shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] + + # [num_active_weights], [num_active_weights] + active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] + + weights = torch.zeros((N, self.D), dtype=torch.float32) + weights[active_weight_idxs] += 1.0 + + return weights diff --git a/sparsecoding/priors/lo_prior_test.py b/sparsecoding/priors/lo_prior_test.py new file mode 100644 index 0000000..7cfa206 --- /dev/null +++ b/sparsecoding/priors/lo_prior_test.py @@ -0,0 +1,33 @@ +import torch + +from sparsecoding.priors import L0Prior + + +def test_l0_prior(): + N = 10000 + prob_distr = torch.tensor([0.5, 0.25, 0, 0.25]) + + D = prob_distr.shape[0] + + l0_prior = L0Prior(prob_distr) + weights = l0_prior.sample(N) + + assert weights.shape == (N, D) + + # Check uniform distribution over which weights are active. + per_weight_hist = torch.sum(weights, dim=0) # [D] + normalized_per_weight_hist = per_weight_hist / torch.sum(per_weight_hist) # [D] + assert torch.allclose( + normalized_per_weight_hist, + torch.full((D,), 0.25, dtype=torch.float32), + atol=1e-2, + ) + + # Check the distribution over the l0-norm of the weights. + num_active_per_sample = torch.sum(weights, dim=1) # [N] + for num_active in range(1, 5): + assert torch.allclose( + torch.sum(num_active_per_sample == num_active) / N, + prob_distr[num_active - 1], + atol=1e-2, + ) diff --git a/sparsecoding/priors/prior.py b/sparsecoding/priors/prior.py new file mode 100644 index 0000000..061c62f --- /dev/null +++ b/sparsecoding/priors/prior.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod + + +class Prior(ABC): + """A distribution over weights. + + Parameters + ---------- + weights_dim : int + Number of weights for each sample. + """ + + @abstractmethod + def D(self): + """ + Number of weights per sample. + """ + + @abstractmethod + def sample( + self, + num_samples: int = 1, + ): + """Sample weights from the prior. + + Parameters + ---------- + num_samples : int, default=1 + Number of samples. + + Returns + ------- + samples : Tensor, shape [num_samples, self.D] + Sampled weights. + """ diff --git a/sparsecoding/priors/spike_slab_prior.py b/sparsecoding/priors/spike_slab_prior.py new file mode 100644 index 0000000..1fa059a --- /dev/null +++ b/sparsecoding/priors/spike_slab_prior.py @@ -0,0 +1,72 @@ +import torch +from torch.distributions.laplace import Laplace + +from .prior import Prior + + +class SpikeSlabPrior(Prior): + """Prior where weights are drawn from a "spike-and-slab" distribution. + + The "spike" is at 0 and the "slab" is Laplacian. + + See: + https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf + for a good review of the spike-and-slab model. + + Parameters + ---------- + dim : int + Number of weights per sample. + p_spike : float + The probability of the weight being 0. + scale : float + The "scale" of the Laplacian distribution (larger is wider). + positive_only : bool + Ensure that the weights are positive by taking the absolute value + of weights sampled from the Laplacian. + """ + + def __init__( + self, + dim: int, + p_spike: float, + scale: float, + positive_only: bool = True, + ): + if dim < 0: + raise ValueError(f"`dim` should be nonnegative, got {dim}.") + if p_spike < 0 or p_spike > 1: + raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.") + if scale <= 0: + raise ValueError(f"`scale` must be positive, got {scale}.") + + self.dim = dim + self.p_spike = p_spike + self.scale = scale + self.positive_only = positive_only + + @property + def D(self): + return self.dim + + def sample(self, num_samples: int): + N = num_samples + + zero_weights = torch.zeros((N, self.D), dtype=torch.float32) + slab_weights = Laplace( + loc=zero_weights, + scale=torch.full((N, self.D), self.scale, dtype=torch.float32), + ).sample() # [N, D] + + if self.positive_only: + slab_weights = torch.abs(slab_weights) + + spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike + + weights = torch.where( + spike_over_slab, + zero_weights, + slab_weights, + ) + + return weights diff --git a/sparsecoding/priors/spike_slab_prior_test.py b/sparsecoding/priors/spike_slab_prior_test.py new file mode 100644 index 0000000..d0054cc --- /dev/null +++ b/sparsecoding/priors/spike_slab_prior_test.py @@ -0,0 +1,52 @@ +import pytest +import torch + +from sparsecoding.priors import SpikeSlabPrior + + +@pytest.mark.parametrize("positive_only", [True, False]) +def test_spike_slab_prior(positive_only: bool): + N = 10000 + D = 4 + p_spike = 0.5 + scale = 1.0 + + p_slab = 1.0 - p_spike + + spike_slab_prior = SpikeSlabPrior( + D, + p_spike, + scale, + positive_only, + ) + weights = spike_slab_prior.sample(N) + + assert weights.shape == (N, D) + + # Check spike probability. + assert torch.allclose( + torch.sum(weights == 0.0) / (N * D), + torch.tensor(p_spike), + atol=1e-2, + ) + + # Check Laplacian distribution. + N_slab = p_slab * N * D + if positive_only: + assert torch.sum(weights < 0.0) == 0 + else: + assert torch.allclose( + torch.sum(weights < 0.0) / N_slab, + torch.sum(weights > 0.0) / N_slab, + atol=2e-2, + ) + weights = torch.abs(weights) + + laplace_weights = weights[weights > 0.0] + for quantile in torch.arange(5) / 5.0: + cutoff = -torch.log(1.0 - quantile) + assert torch.allclose( + torch.sum(laplace_weights < cutoff) / N_slab, + quantile, + atol=1e-2, + ) diff --git a/sparsecoding/test_utils/__init__.py b/sparsecoding/test_utils/__init__.py new file mode 100644 index 0000000..8c18de1 --- /dev/null +++ b/sparsecoding/test_utils/__init__.py @@ -0,0 +1,16 @@ +from .asserts import assert_allclose, assert_shape_equal +from .constant_fixtures import dataset_size_fixture, patch_size_fixture +from .dataset_fixtures import bars_datas_fixture, bars_datasets_fixture +from .model_fixtures import bars_dictionary_fixture +from .prior_fixtures import priors_fixture + +__all__ = [ + "assert_allclose", + "assert_shape_equal", + "dataset_size_fixture", + "patch_size_fixture", + "bars_datas_fixture", + "bars_datasets_fixture", + "bars_dictionary_fixture", + "priors_fixture", +] diff --git a/sparsecoding/test_utils/asserts.py b/sparsecoding/test_utils/asserts.py new file mode 100644 index 0000000..64db496 --- /dev/null +++ b/sparsecoding/test_utils/asserts.py @@ -0,0 +1,13 @@ +import numpy as np + +# constants +DEFAULT_ATOL = 1e-6 +DEFAULT_RTOL = 1e-5 + + +def assert_allclose(a: np.ndarray, b: np.ndarray, rtol: float = DEFAULT_RTOL, atol: float = DEFAULT_ATOL) -> None: + return np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) + + +def assert_shape_equal(a: np.ndarray, b: np.ndarray) -> None: + assert a.shape == b.shape diff --git a/sparsecoding/test_utils/asserts_test.py b/sparsecoding/test_utils/asserts_test.py new file mode 100644 index 0000000..1af9705 --- /dev/null +++ b/sparsecoding/test_utils/asserts_test.py @@ -0,0 +1,28 @@ +import numpy as np +import torch + +from .asserts import assert_allclose, assert_shape_equal + + +def test_pytorch_all_close(): + result = torch.ones([10, 10]) + 1e-10 + expected = torch.ones([10, 10]) + assert_allclose(result, expected) + + +def test_np_all_close(): + result = np.ones([100, 100]) + 1e-10 + expected = np.ones([100, 100]) + assert_allclose(result, expected) + + +def test_assert_pytorch_shape_equal(): + a = torch.zeros([10, 10]) + b = torch.ones([10, 10]) + assert_shape_equal(a, b) + + +def test_assert_np_shape_equal(): + a = np.zeros([100, 100]) + b = np.ones([100, 100]) + assert_shape_equal(a, b) diff --git a/sparsecoding/test_utils/constant_fixtures.py b/sparsecoding/test_utils/constant_fixtures.py new file mode 100644 index 0000000..1c04698 --- /dev/null +++ b/sparsecoding/test_utils/constant_fixtures.py @@ -0,0 +1,14 @@ +import pytest + +PATCH_SIZE = 8 +DATASET_SIZE = 1000 + + +@pytest.fixture() +def patch_size_fixture() -> int: + return PATCH_SIZE + + +@pytest.fixture() +def dataset_size_fixture() -> int: + return DATASET_SIZE diff --git a/sparsecoding/test_utils/dataset_fixtures.py b/sparsecoding/test_utils/dataset_fixtures.py new file mode 100644 index 0000000..c3f20d4 --- /dev/null +++ b/sparsecoding/test_utils/dataset_fixtures.py @@ -0,0 +1,29 @@ +import pytest +import torch + +from sparsecoding.datasets import BarsDataset +from sparsecoding.priors import Prior + + +@pytest.fixture() +def bars_datasets_fixture( + patch_size_fixture: int, dataset_size_fixture: int, priors_fixture: list[Prior] +) -> list[BarsDataset]: + return [ + BarsDataset( + patch_size=patch_size_fixture, + dataset_size=dataset_size_fixture, + prior=prior, + ) + for prior in priors_fixture + ] + + +@pytest.fixture() +def bars_datas_fixture( + patch_size_fixture: int, dataset_size_fixture: int, bars_datasets_fixture: list[BarsDataset] +) -> list[torch.Tensor]: + return [ + dataset.data.reshape((dataset_size_fixture, patch_size_fixture * patch_size_fixture)) + for dataset in bars_datasets_fixture + ] diff --git a/sparsecoding/test_utils/model_fixtures.py b/sparsecoding/test_utils/model_fixtures.py new file mode 100644 index 0000000..1a0d895 --- /dev/null +++ b/sparsecoding/test_utils/model_fixtures.py @@ -0,0 +1,10 @@ +import pytest +import torch + +from sparsecoding.datasets import BarsDataset + + +@pytest.fixture() +def bars_dictionary_fixture(patch_size_fixture: int, bars_datasets_fixture: list[BarsDataset]) -> torch.Tensor: + """Return a bars dataset basis reshaped to represent a dictionary.""" + return bars_datasets_fixture[0].basis.reshape((2 * patch_size_fixture, patch_size_fixture * patch_size_fixture)).T diff --git a/sparsecoding/test_utils/prior_fixtures.py b/sparsecoding/test_utils/prior_fixtures.py new file mode 100644 index 0000000..08a3125 --- /dev/null +++ b/sparsecoding/test_utils/prior_fixtures.py @@ -0,0 +1,24 @@ +import pytest +import torch + +from sparsecoding.priors import L0Prior, Prior, SpikeSlabPrior + + +@pytest.fixture() +def priors_fixture(patch_size_fixture: int) -> list[Prior]: + return [ + SpikeSlabPrior( + dim=2 * patch_size_fixture, + p_spike=0.8, + scale=1.0, + positive_only=True, + ), + L0Prior( + prob_distr=( + torch.nn.functional.one_hot( + torch.tensor(1), + num_classes=2 * patch_size_fixture, + ).type(torch.float32) + ), + ), + ] diff --git a/sparsecoding/transforms/__init__.py b/sparsecoding/transforms/__init__.py index 8b77c4d..f01adca 100644 --- a/sparsecoding/transforms/__init__.py +++ b/sparsecoding/transforms/__init__.py @@ -1,7 +1,20 @@ from .whiten import whiten, compute_whitening_stats -from .images import whiten_images, compute_image_whitening_stats, WhiteningTransform, \ - quilt, patchify, sample_random_patches +from .images import ( + whiten_images, + compute_image_whitening_stats, + WhiteningTransform, + quilt, + patchify, + sample_random_patches, +) -__all__ = ['quilt', 'patchify', 'sample_random_patches', 'whiten', - 'compute_whitening_stats', 'compute_image_whitening_stats', - 'WhiteningTransform', 'whiten_images'] +__all__ = [ + "quilt", + "patchify", + "sample_random_patches", + "whiten", + "compute_whitening_stats", + "compute_image_whitening_stats", + "WhiteningTransform", + "whiten_images", +] diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py index eb1ecf3..9e0f20b 100644 --- a/sparsecoding/transforms/images.py +++ b/sparsecoding/transforms/images.py @@ -6,32 +6,34 @@ from .whiten import whiten, compute_whitening_stats -def check_images(images: torch.Tensor, algorithm: str = 'zca'): - """Verify that tensor is in the shape [N, C, H, W] and C != when using fourier based method - """ +def check_images(images: torch.Tensor, algorithm: str = "zca"): + """Verify that tensor is in the shape [N, C, H, W] and C != when using fourier based method""" if len(images.shape) != 4: - raise ValueError('Images must be in shape [N, C, H, W]') + raise ValueError("Images must be in shape [N, C, H, W]") - if images.shape[1] != 1 and algorithm == 'frequency': - raise ValueError("When using frequency based decorrelation, images must" + - f"be grayscale, received {images.shape[1]} channels") + if images.shape[1] != 1 and algorithm == "frequency": + raise ValueError( + "When using frequency based decorrelation, images must" + + f"be grayscale, received {images.shape[1]} channels" + ) # Running cov based methods on large images can eat memory - if algorithm in ['zca', 'pca', 'cholesky'] and (images.shape[2] > 64 or images.shape[3] > 64): - print(f"WARNING: Running covaraince based whitening for images of size {images.shape[2]}x{images.shape[3]}." + - "It is not recommended to use this for images smaller than 64x64") + if algorithm in ["zca", "pca", "cholesky"] and (images.shape[2] > 64 or images.shape[3] > 64): + print( + f"WARNING: Running covaraince based whitening for images of size {images.shape[2]}x{images.shape[3]}." + + "It is not recommended to use this for images smaller than 64x64" + ) # Running cov based methods on large images can eat memory - if algorithm == 'frequency' and (images.shape[2] <= 64 or images.shape[3] <= 64): - print(f"WARNING: Running frequency based whitening for images of size {images.shape[2]}x{images.shape[3]}." + - "It is recommended to use this for images larger than 64x64") + if algorithm == "frequency" and (images.shape[2] <= 64 or images.shape[3] <= 64): + print( + f"WARNING: Running frequency based whitening for images of size {images.shape[2]}x{images.shape[3]}." + + "It is recommended to use this for images larger than 64x64" + ) -def whiten_images(images: torch.Tensor, - algorithm: str, - stats: Dict = None, - **kwargs) -> torch.Tensor: +def whiten_images(images: torch.Tensor, algorithm: str, stats: Dict = None, **kwargs) -> torch.Tensor: """ Wrapper for all whitening transformations @@ -48,18 +50,20 @@ def whiten_images(images: torch.Tensor, check_images(images, algorithm) - if algorithm == 'frequency': + if algorithm == "frequency": return frequency_whitening(images, **kwargs) - elif algorithm in ['zca', 'pca', 'cholesky']: + elif algorithm in ["zca", "pca", "cholesky"]: N, C, H, W = images.shape flattened_images = images.flatten(start_dim=1) whitened = whiten(flattened_images, algorithm, stats, **kwargs) return whitened.reshape((N, C, H, W)) else: - raise ValueError(f"Unknown whitening algorithm: {algorithm}, \ - must be one of ['frequency', 'pca', 'zca', 'cholesky]") + raise ValueError( + f"Unknown whitening algorithm: {algorithm}, \ + must be one of ['frequency', 'pca', 'zca', 'cholesky]" + ) def compute_image_whitening_stats(images: torch.Tensor) -> Dict: @@ -95,13 +99,13 @@ def create_frequency_filter(image_size: int, f0_factor: float = 0.4) -> torch.Te ---------- torch.Tensor: Frequency domain filter """ - fx = torch.linspace(-image_size/2, image_size/2-1, image_size) - fy = torch.linspace(-image_size/2, image_size/2-1, image_size) - fx, fy = torch.meshgrid(fx, fy, indexing='xy') + fx = torch.linspace(-image_size / 2, image_size / 2 - 1, image_size) + fy = torch.linspace(-image_size / 2, image_size / 2 - 1, image_size) + fx, fy = torch.meshgrid(fx, fy, indexing="xy") rho = torch.sqrt(fx**2 + fy**2) f_0 = f0_factor * image_size - filt = rho * torch.exp(-(rho/f_0)**4) + filt = rho * torch.exp(-((rho / f_0) ** 4)) return fft.fftshift(filt) @@ -123,7 +127,7 @@ def get_cached_filter(image_size: int, f0_factor: float = 0.4) -> torch.Tensor: return create_frequency_filter(image_size, f0_factor) -def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.) -> torch.Tensor: +def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.0) -> torch.Tensor: """ Normalize the variance of a tensor to a target value. @@ -146,11 +150,7 @@ def normalize_variance(tensor: torch.Tensor, target_variance: float = 1.) -> tor return centered -def whiten_channel( - channel: torch.Tensor, - filt: torch.Tensor, - target_variance: float = 1. -) -> torch.Tensor: +def whiten_channel(channel: torch.Tensor, filt: torch.Tensor, target_variance: float = 1.0) -> torch.Tensor: """ Apply frequency domain whitening to a single channel. @@ -181,11 +181,7 @@ def whiten_channel( return whitened -def frequency_whitening( - images: torch.Tensor, - target_variance: float = 0.1, - f0_factor: float = 0.4 -) -> torch.Tensor: +def frequency_whitening(images: torch.Tensor, target_variance: float = 0.1, f0_factor: float = 0.4) -> torch.Tensor: """ Apply frequency domain decorrelation to batched images. Method used in original sparsenet in Olshausen and Field in Nature @@ -211,9 +207,7 @@ def frequency_whitening( # Process each image in the batch whitened_batch = [] for img in images: - whitened_batch.append( - whiten_channel(img[0], filt, target_variance) - ) + whitened_batch.append(whiten_channel(img[0], filt, target_variance)) return torch.stack(whitened_batch).unsqueeze(1) @@ -223,13 +217,8 @@ class WhiteningTransform(object): A PyTorch transform for image whitening that can be used in a transform pipeline. Supports frequency, PCA, and ZCA whitening methods. """ - def __init__( - self, - algorithm: str = 'zca', - stats: Optional[Dict] = None, - compute_stats: bool = False, - **kwargs - ): + + def __init__(self, algorithm: str = "zca", stats: Optional[Dict] = None, compute_stats: bool = False, **kwargs): """ Initialize whitening transform. @@ -266,12 +255,7 @@ def __call__(self, images: torch.Tensor) -> torch.Tensor: check_images(images) # Apply whitening - whitened = whiten_images( - images, - self.algorithm, - self.stats, - **self.kwargs - ) + whitened = whiten_images(images, self.algorithm, self.stats, **self.kwargs) # Remove batch dimension if input was single image if single_image: @@ -322,18 +306,11 @@ def sample_random_patches( size=(N,), ) - h_patch_idxs, w_patch_idxs = torch.meshgrid( - torch.arange(P), - torch.arange(P), - indexing='ij' - ) + h_patch_idxs, w_patch_idxs = torch.meshgrid(torch.arange(P), torch.arange(P), indexing="ij") h_idxs = h_start_idx.reshape(N, 1, 1) + h_patch_idxs w_idxs = w_start_idx.reshape(N, 1, 1) + w_patch_idxs - leading_idxs = [ - torch.randint(low=0, high=image.shape[d], size=(N, 1, 1)) - for d in range(image.dim() - 3) - ] + leading_idxs = [torch.randint(low=0, high=image.shape[d], size=(N, 1, 1)) for d in range(image.dim() - 3)] idxs = leading_idxs + [slice(None), h_idxs, w_idxs] @@ -379,20 +356,14 @@ def patchify( if stride is None: stride = P - if ( - H % P != 0 - or W % P != 0 - ): + if H % P != 0 or W % P != 0: warnings.warn( f"Image size ({H, W}) not evenly divisible by `patch_size` ({P})," f"parts on the bottom and/or right will be cropped.", UserWarning, ) - N = ( - int((H - P + 1 + stride) // stride) - * int((W - P + 1 + stride) // stride) - ) + N = int((H - P + 1 + stride) // stride) * int((W - P + 1 + stride) // stride) patches = torch.nn.functional.unfold( input=image.reshape(-1, C, H, W), @@ -439,22 +410,16 @@ def quilt( W = width if int(H / P) * int(W / P) != N: - raise ValueError( - f"Expected {N} patches per image, " - f"got int(H/P) * int(W/P) = {int(H / P) * int(W / P)}." - ) + raise ValueError(f"Expected {N} patches per image, " f"got int(H/P) * int(W/P) = {int(H / P) * int(W / P)}.") - if ( - H % P != 0 - or W % P != 0 - ): + if H % P != 0 or W % P != 0: warnings.warn( f"Image size ({H, W}) not evenly divisible by `patch_size` ({P})," f"parts on the bottom and/or right will be zeroed.", UserWarning, ) - patches = patches.reshape(-1, N, C*P*P) # [prod(*), N, C*P*P] + patches = patches.reshape(-1, N, C * P * P) # [prod(*), N, C*P*P] patches = torch.permute(patches, (0, 2, 1)) # [prod(*), C*P*P, N] image = torch.nn.functional.fold( input=patches, diff --git a/sparsecoding/transforms/images_test.py b/sparsecoding/transforms/images_test.py new file mode 100644 index 0000000..d7999c5 --- /dev/null +++ b/sparsecoding/transforms/images_test.py @@ -0,0 +1,49 @@ +import torch + +from sparsecoding.transforms import patchify, quilt, sample_random_patches + + +def test_patchify_quilt_cycle(): + X, Y, Z = 3, 4, 5 + C = 3 + P = 8 + H = 6 * P + W = 8 * P + + images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) + + patches = patchify(P, images) + assert patches.shape == (X, Y, Z, int(H / P) * int(W / P), C, P, P) + + quilted_images = quilt(H, W, patches) + assert torch.allclose( + images, + quilted_images, + ), "Quilted images should be equal to input images." + + +def test_sample_random_patches(): + X, Y, Z = 3, 4, 5 + C = 3 + P = 8 + H = 4 * P + W = 8 * P + N = 10 + + images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) + + random_patches = sample_random_patches(P, N, images) + assert random_patches.shape == (N, C, P, P) + + # Check that patches are actually taken from one of the images. + all_patches = torch.nn.functional.unfold( + input=images.reshape(-1, C, H, W), + kernel_size=P, + ) # [prod(*), C*P*P, L] + all_patches = torch.permute(all_patches, (0, 2, 1)) # [prod(*), L, C*P*P] + all_patches = torch.reshape(all_patches, (-1, C * P * P)) + for n in range(N): + patch = random_patches[n].reshape(1, C * P * P) + delta = torch.abs(patch - all_patches) # [-1, C*P*P] + patchwise_delta = torch.sum(delta, dim=1) # [-1] + assert torch.min(patchwise_delta) == 0.0 diff --git a/sparsecoding/transforms/whiten.py b/sparsecoding/transforms/whiten.py index 64641e4..4206adc 100644 --- a/sparsecoding/transforms/whiten.py +++ b/sparsecoding/transforms/whiten.py @@ -26,21 +26,17 @@ def compute_whitening_stats(X: torch.Tensor): eigenvalues = torch.flip(eigenvalues, dims=[0]) eigenvectors = torch.flip(eigenvectors, dims=[1]) - return { - 'mean': mean, - 'eigenvalues': eigenvalues, - 'eigenvectors': eigenvectors, - 'covariance': Sigma - } - - -def whiten(X: torch.Tensor, - algorithm: str = 'zca', - stats: Dict = None, - n_components: float = None, - epsilon: float = 0., - return_W: bool = False - ) -> torch.Tensor: + return {"mean": mean, "eigenvalues": eigenvalues, "eigenvectors": eigenvectors, "covariance": Sigma} + + +def whiten( + X: torch.Tensor, + algorithm: str = "zca", + stats: Dict = None, + n_components: float = None, + epsilon: float = 0.0, + return_W: bool = False, +) -> torch.Tensor: """ Apply whitening transform to data using pre-computed statistics. @@ -73,21 +69,21 @@ def whiten(X: torch.Tensor, if stats is None: stats = compute_whitening_stats(X) - x_centered = X - stats.get('mean') + x_centered = X - stats.get("mean") - if algorithm == 'pca' or algorithm == 'zca': + if algorithm == "pca" or algorithm == "zca": - scaling = 1. / torch.sqrt(stats.get('eigenvalues') + epsilon) + scaling = 1.0 / torch.sqrt(stats.get("eigenvalues") + epsilon) if n_components is not None: if isinstance(n_components, float): if not 0 < n_components <= 1: raise ValueError("If n_components is float, it must be between 0 and 1") - explained_variance_ratio = stats.get('eigenvalues') / torch.sum(stats.get('eigenvalues')) + explained_variance_ratio = stats.get("eigenvalues") / torch.sum(stats.get("eigenvalues")) cumulative_variance_ratio = torch.cumsum(explained_variance_ratio, dim=0) n_components = torch.sum(cumulative_variance_ratio <= n_components) + 1 elif isinstance(n_components, int): - if not 0 < n_components <= len(stats.get('eigenvalues')): + if not 0 < n_components <= len(stats.get("eigenvalues")): raise ValueError(f"n_components must be between 1 and {len(stats.get('eigenvalues'))}") else: raise ValueError("n_components must be int or float") @@ -98,17 +94,15 @@ def whiten(X: torch.Tensor, scaling = torch.diag(scaling) - if algorithm == 'pca': + if algorithm == "pca": # For PCA: project onto eigenvectors and scale - W = scaling @ stats.get('eigenvectors').T + W = scaling @ stats.get("eigenvectors").T else: # For ZCA: project, scale, and rotate back - W = (stats.get('eigenvectors') @ - scaling @ - stats.get('eigenvectors').T) - elif algorithm == 'cholesky': + W = stats.get("eigenvectors") @ scaling @ stats.get("eigenvectors").T + elif algorithm == "cholesky": # Based on Cholesky decomp, related to QR decomp - L = torch.linalg.cholesky(stats.get('covariance')) + L = torch.linalg.cholesky(stats.get("covariance")) Identity = torch.eye(L.shape[0], device=L.device, dtype=L.dtype) # Solve L @ W = I for W, more stable and quicker than inv(L) W = torch.linalg.solve_triangular(L, Identity, upper=False) diff --git a/sparsecoding/transforms/whiten_test.py b/sparsecoding/transforms/whiten_test.py new file mode 100644 index 0000000..9017da0 --- /dev/null +++ b/sparsecoding/transforms/whiten_test.py @@ -0,0 +1,63 @@ +import torch + +from sparsecoding.transforms import whiten + + +def test_zca(): + N = 5000 + D = 32 * 32 + + X = torch.rand((N, D), dtype=torch.float32) + + X_whitened = whiten(X) + + assert torch.allclose( + torch.mean(X_whitened, dim=0), + torch.zeros(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have zero mean." + assert torch.allclose( + torch.cov(X_whitened.T), + torch.eye(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have unit (identity) covariance." + + +def test_pca(): + N = 5000 + D = 32 * 32 + + X = torch.rand((N, D), dtype=torch.float32) + + X_whitened = whiten(X, algorithm="pca") + + assert torch.allclose( + torch.mean(X_whitened, dim=0), + torch.zeros(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have zero mean." + assert torch.allclose( + torch.cov(X_whitened.T), + torch.eye(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have unit (identity) covariance." + + +def test_cholesky(): + N = 5000 + D = 32 * 32 + + X = torch.rand((N, D), dtype=torch.float32) + + X_whitened = whiten(X, algorithm="cholesky") + + assert torch.allclose( + torch.mean(X_whitened, dim=0), + torch.zeros(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have zero mean." + assert torch.allclose( + torch.cov(X_whitened.T), + torch.eye(D, dtype=torch.float32), + atol=1e-3, + ), "Whitened data should have unit (identity) covariance." diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/inference/common.py b/tests/inference/common.py deleted file mode 100644 index 2305ed7..0000000 --- a/tests/inference/common.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch - -from sparsecoding.priors import L0Prior, SpikeSlabPrior -from sparsecoding.datasets import BarsDataset - -torch.manual_seed(1997) - -PATCH_SIZE = 8 -DATASET_SIZE = 1000 - -PRIORS = [ - SpikeSlabPrior( - dim=2 * PATCH_SIZE, - p_spike=0.8, - scale=1.0, - positive_only=True, - ), - L0Prior( - prob_distr=( - torch.nn.functional.one_hot( - torch.tensor(1), - num_classes=2 * PATCH_SIZE, - ).type(torch.float32) - ), - ), -] - -DATASET = [ - BarsDataset( - patch_size=PATCH_SIZE, - dataset_size=DATASET_SIZE, - prior=prior, - ) - for prior in PRIORS -] - -DATAS = [ - dataset.data.reshape((DATASET_SIZE, PATCH_SIZE * PATCH_SIZE)) - for dataset in DATASET -] -DICTIONARY = DATASET[0].basis.reshape((2 * PATCH_SIZE, PATCH_SIZE * PATCH_SIZE)).T diff --git a/tests/inference/test_ISTA.py b/tests/inference/test_ISTA.py deleted file mode 100644 index 0edb50e..0000000 --- a/tests/inference/test_ISTA.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import ( - DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE -) - - -class TestISTA(TestCase): - def test_shape(self): - """ - Test that ISTA inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.ISTA(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - inference_method = inference.ISTA(N_ITER, return_all_coefficients=True) - a = inference_method.infer(data, DICTIONARY) - self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)) - - def test_inference(self): - """ - Test that ISTA inference recovers the correct weights. - """ - N_ITER = 5000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.ISTA(n_iter=N_ITER) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_LCA.py b/tests/inference/test_LCA.py deleted file mode 100644 index 8628377..0000000 --- a/tests/inference/test_LCA.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import ( - DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE -) - - -class TestLCA(TestCase): - def test_shape(self): - """ - Test that LCA inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LCA(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - for retval in ["active", "membrane"]: - inference_method = inference.LCA(N_ITER, return_all_coefficients=retval) - a = inference_method.infer(data, DICTIONARY) - self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)) - - def test_inference(self): - """ - Test that LCA inference recovers the correct weights. - """ - LR = 5e-2 - THRESHOLD = 0.1 - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LCA( - coeff_lr=LR, - threshold=THRESHOLD, - n_iter=N_ITER, - ) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_LSM.py b/tests/inference/test_LSM.py deleted file mode 100644 index 0084daf..0000000 --- a/tests/inference/test_LSM.py +++ /dev/null @@ -1,35 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import DATAS, DATASET, DICTIONARY - - -class TestLSM(TestCase): - def test_shape(self): - """ - Test that LSM inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LSM(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - def test_inference(self): - """ - Test that LSM inference recovers the correct weights. - """ - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.LSM(n_iter=N_ITER) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_PyTorchOptimizer.py b/tests/inference/test_PyTorchOptimizer.py deleted file mode 100644 index 3c69885..0000000 --- a/tests/inference/test_PyTorchOptimizer.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import DATAS, DATASET, DICTIONARY - - -class TestPyTorchOptimizer(TestCase): - def lasso_loss(data, dictionary, coefficients, sparsity_penalty): - """ - Generic MSE + l1-norm loss. - """ - batch_size = data.shape[0] - datahat = (dictionary@coefficients.t()).t() - - mse_loss = torch.linalg.vector_norm(datahat-data, dim=1).square() - sparse_loss = torch.sum(torch.abs(coefficients), axis=1) - - total_loss = (mse_loss + sparsity_penalty*sparse_loss)/batch_size - return total_loss - - def loss_fn(data, dictionary, coefficients): - return TestPyTorchOptimizer.lasso_loss( - data, - dictionary, - coefficients, - sparsity_penalty=1., - ) - - def optimizer_fn(coefficients): - return torch.optim.Adam( - coefficients, - lr=0.1, - betas=(0.9, 0.999), - eps=1e-08, - weight_decay=0, - ) - - def test_shape(self): - """ - Test that PyTorchOptimizer inference returns expected shapes. - """ - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.PyTorchOptimizer( - TestPyTorchOptimizer.optimizer_fn, - TestPyTorchOptimizer.loss_fn, - n_iter=10, - ) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - def test_inference(self): - """ - Test that PyTorchOptimizer inference recovers the correct weights. - """ - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.PyTorchOptimizer( - TestPyTorchOptimizer.optimizer_fn, - TestPyTorchOptimizer.loss_fn, - n_iter=N_ITER, - ) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=1e-1, rtol=1e-1) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/inference/test_Vanilla.py b/tests/inference/test_Vanilla.py deleted file mode 100644 index af9ae2f..0000000 --- a/tests/inference/test_Vanilla.py +++ /dev/null @@ -1,42 +0,0 @@ -import unittest - -from sparsecoding import inference -from tests.testing_utilities import TestCase -from tests.inference.common import ( - DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE -) - - -class TestVanilla(TestCase): - def test_shape(self): - """ - Test that Vanilla inference returns expected shapes. - """ - N_ITER = 10 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.Vanilla(N_ITER) - a = inference_method.infer(data, DICTIONARY) - self.assertShapeEqual(a, dataset.weights) - - inference_method = inference.Vanilla(N_ITER, return_all_coefficients=True) - a = inference_method.infer(data, DICTIONARY) - self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE)) - - def test_inference(self): - """ - Test that Vanilla inference recovers the correct weights. - """ - LR = 5e-2 - N_ITER = 1000 - - for (data, dataset) in zip(DATAS, DATASET): - inference_method = inference.Vanilla(coeff_lr=LR, n_iter=N_ITER) - - a = inference_method.infer(data, DICTIONARY) - - self.assertAllClose(a, dataset.weights, atol=5e-2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/priors/__init__.py b/tests/priors/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/priors/test_l0.py b/tests/priors/test_l0.py deleted file mode 100644 index 12dcce6..0000000 --- a/tests/priors/test_l0.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -import unittest - -from sparsecoding.priors import L0Prior - - -class TestL0Prior(unittest.TestCase): - def test_l0_prior(self): - N = 10000 - prob_distr = torch.tensor([0.5, 0.25, 0, 0.25]) - - torch.manual_seed(1997) - - D = prob_distr.shape[0] - - l0_prior = L0Prior(prob_distr) - weights = l0_prior.sample(N) - - assert weights.shape == (N, D) - - # Check uniform distribution over which weights are active. - per_weight_hist = torch.sum(weights, dim=0) # [D] - normalized_per_weight_hist = per_weight_hist / torch.sum(per_weight_hist) # [D] - assert torch.allclose( - normalized_per_weight_hist, - torch.full((D,), 0.25, dtype=torch.float32), - atol=1e-2, - ) - - # Check the distribution over the l0-norm of the weights. - num_active_per_sample = torch.sum(weights, dim=1) # [N] - for num_active in range(1, 5): - assert torch.allclose( - torch.sum(num_active_per_sample == num_active) / N, - prob_distr[num_active - 1], - atol=1e-2, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/priors/test_spike_slab.py b/tests/priors/test_spike_slab.py deleted file mode 100644 index 20804fe..0000000 --- a/tests/priors/test_spike_slab.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import unittest - -from sparsecoding.priors import SpikeSlabPrior - - -class TestSpikeSlabPrior(unittest.TestCase): - def test_spike_slab_prior(self): - N = 10000 - D = 4 - p_spike = 0.5 - scale = 1. - - torch.manual_seed(1997) - - p_slab = 1. - p_spike - - for positive_only in [True, False]: - spike_slab_prior = SpikeSlabPrior( - D, - p_spike, - scale, - positive_only, - ) - weights = spike_slab_prior.sample(N) - - assert weights.shape == (N, D) - - # Check spike probability. - assert torch.allclose( - torch.sum(weights == 0.) / (N * D), - torch.tensor(p_spike), - atol=1e-2, - ) - - # Check Laplacian distribution. - N_slab = p_slab * N * D - if positive_only: - assert torch.sum(weights < 0.) == 0 - else: - assert torch.allclose( - torch.sum(weights < 0.) / N_slab, - torch.sum(weights > 0.) / N_slab, - atol=2e-2, - ) - weights = torch.abs(weights) - - laplace_weights = weights[weights > 0.] - for quantile in torch.arange(5) / 5.: - cutoff = -torch.log(1. - quantile) - assert torch.allclose( - torch.sum(laplace_weights < cutoff) / N_slab, - quantile, - atol=1e-2, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_test.py b/tests/test_test.py deleted file mode 100644 index ff3e42e..0000000 --- a/tests/test_test.py +++ /dev/null @@ -1,35 +0,0 @@ -from tests.testing_utilities import TestCase -import torch -import numpy as np - - -class TestTestCaseBaseClass(TestCase): - - def test_pytorch_all_close(self): - result = torch.ones([10, 10]) + 1e-10 - expected = torch.ones([10, 10]) - self.assertAllClose(result, expected) - - def test_np_all_close(self): - result = np.ones([100, 100]) + 1e-10 - expected = np.ones([100, 100]) - self.assertAllClose(result, expected) - - def test_assert_true(self): - self.assertTrue(True) - - def test_assert_false(self): - self.assertFalse(False) - - def test_assert_equal(self): - self.assertEqual('sparse coding', 'sparse coding') - - def test_assert_pytorch_shape_equal(self): - a = torch.zeros([10, 10]) - b = torch.ones([10, 10]) - self.assertShapeEqual(a, b) - - def test_assert_np_shape_equal(self): - a = np.zeros([100, 100]) - b = np.ones([100, 100]) - self.assertShapeEqual(a, b) diff --git a/tests/testing_utilities.py b/tests/testing_utilities.py deleted file mode 100644 index 8a3de74..0000000 --- a/tests/testing_utilities.py +++ /dev/null @@ -1,17 +0,0 @@ -import numpy as np -import unittest - - -# constants -default_atol = 1e-6 -default_rtol = 1e-5 - - -class TestCase(unittest.TestCase): - '''Base class for testing''' - - def assertAllClose(self, a, b, rtol=default_rtol, atol=default_atol): - return np.testing.assert_allclose(a, b, rtol=rtol, atol=atol) - - def assertShapeEqual(self, a, b): - assert a.shape == b.shape diff --git a/tests/transforms/__init__.py b/tests/transforms/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/transforms/test_patch.py b/tests/transforms/test_patch.py deleted file mode 100644 index 22cafbc..0000000 --- a/tests/transforms/test_patch.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import unittest - -from sparsecoding.transforms import sample_random_patches, patchify, quilt - - -class TestPatcher(unittest.TestCase): - def test_patchify_quilt_cycle(self): - X, Y, Z = 3, 4, 5 - C = 3 - P = 8 - H = 6 * P - W = 8 * P - - torch.manual_seed(1997) - - images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) - - patches = patchify(P, images) - assert patches.shape == (X, Y, Z, int(H / P) * int(W / P), C, P, P) - - quilted_images = quilt(H, W, patches) - assert torch.allclose( - images, - quilted_images, - ), "Quilted images should be equal to input images." - - def test_sample_random_patches(self): - X, Y, Z = 3, 4, 5 - C = 3 - P = 8 - H = 4 * P - W = 8 * P - N = 10 - - torch.manual_seed(1997) - - images = torch.rand((X, Y, Z, C, H, W), dtype=torch.float32) - - random_patches = sample_random_patches(P, N, images) - assert random_patches.shape == (N, C, P, P) - - # Check that patches are actually taken from one of the images. - all_patches = torch.nn.functional.unfold( - input=images.reshape(-1, C, H, W), - kernel_size=P, - ) # [prod(*), C*P*P, L] - all_patches = torch.permute(all_patches, (0, 2, 1)) # [prod(*), L, C*P*P] - all_patches = torch.reshape(all_patches, (-1, C*P*P)) - for n in range(N): - patch = random_patches[n].reshape(1, C*P*P) - delta = torch.abs(patch - all_patches) # [-1, C*P*P] - patchwise_delta = torch.sum(delta, dim=1) # [-1] - assert torch.min(patchwise_delta) == 0. - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/transforms/test_whiten.py b/tests/transforms/test_whiten.py deleted file mode 100644 index 3fa4406..0000000 --- a/tests/transforms/test_whiten.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import unittest - -from sparsecoding.transforms import whiten - - -class TestWhitener(unittest.TestCase): - def test_zca(self): - N = 5000 - D = 32*32 - - torch.manual_seed(1997) - - X = torch.rand((N, D), dtype=torch.float32) - - X_whitened = whiten(X) - - assert torch.allclose( - torch.mean(X_whitened, dim=0), - torch.zeros(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have zero mean." - assert torch.allclose( - torch.cov(X_whitened.T), - torch.eye(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have unit (identity) covariance." - - def test_pca(self): - N = 5000 - D = 32*32 - - torch.manual_seed(1997) - - X = torch.rand((N, D), dtype=torch.float32) - - X_whitened = whiten(X, algorithm='pca') - - assert torch.allclose( - torch.mean(X_whitened, dim=0), - torch.zeros(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have zero mean." - assert torch.allclose( - torch.cov(X_whitened.T), - torch.eye(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have unit (identity) covariance." - - def test_cholesky(self): - N = 5000 - D = 32*32 - - torch.manual_seed(1997) - - X = torch.rand((N, D), dtype=torch.float32) - - X_whitened = whiten(X, algorithm='cholesky') - - assert torch.allclose( - torch.mean(X_whitened, dim=0), - torch.zeros(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have zero mean." - assert torch.allclose( - torch.cov(X_whitened.T), - torch.eye(D, dtype=torch.float32), - atol=1e-3, - ), "Whitened data should have unit (identity) covariance." - - -if __name__ == "__main__": - unittest.main()