Binary file added .DS_Store
Binary file not shown.
235 changes: 235 additions & 0 deletions examples/cel0_inference_bars_example.ipynb

Large diffs are not rendered by default.

131 changes: 131 additions & 0 deletions sparsecoding/inference.py
@@ -932,3 +932,134 @@ def infer(self, data, dictionary):
residual = data.clone() - coefficients @ dictionary.T # [batch_size, n_features]

return coefficients.detach()


class CEL0(InferenceMethod):
def __init__(self, n_iter=100, coeff_lr=1e-3, threshold=1e-2, return_all_coefficients="none", solver=None):
"""
Parameters
----------
n_iter : int, default=100
Number of iterations to run
coeff_lr : float, default=1e-3
Update rate of coefficient dynamics
threshold : float, default=1e-2
Threshold for non-linearity
return_all_coefficients : str, {"none", "active"}, default="none"
If "none", only the final coefficients are returned. If "active",
the active units (the output of the thresholding function applied
to u) are returned at every iteration. User beware: if n_iter is
large, setting this parameter to "active" can result in large
memory usage/potential exhaustion. This option is typically used
for debugging.
solver : default=None
Passed through to InferenceMethod.

References
----------
[1] https://arxiv.org/abs/2301.10002
"""
super().__init__(solver)
self.threshold = threshold
self.coeff_lr = coeff_lr
self.n_iter = n_iter
self.return_all_coefficients = return_all_coefficients
self.dictionary_norms = None

def threshold_nonlinearity(self, u, a=1):
'''
CEL0 thresholding function: a continuous exact l0 penalty.

Note: it is assumed that the dictionary is normalized.

Parameters
----------
u : array-like, shape [batch_size, n_basis]
Pre-activation coefficients.
a : float, default=1
Norm of the corresponding dictionary column.

Returns
-------
re : array-like, shape [batch_size, n_basis]
Thresholded coefficients.
'''
# Soft-threshold branch: only valid when the denominator
# 1 - a ** 2 * self.coeff_lr is positive.
if a ** 2 * self.coeff_lr < 1:
num = torch.clamp(torch.abs(u) - (2 * self.threshold) ** 0.5 * a * self.coeff_lr, min=0)
den = 1 - a ** 2 * self.coeff_lr
re = torch.sign(u) * torch.minimum(torch.abs(u), num / den)
else:
# TODO: This is not the same as the paper, which prescribes a
# set-valued hard threshold in this regime; here we approximate
# it with elementwise hard thresholding.
cutoff = (2 * self.threshold * self.coeff_lr) ** 0.5
re = torch.where(torch.abs(u) >= cutoff, u, torch.zeros_like(u))
return re

def infer(self, data, dictionary, coeff_0=None, use_checknan=False):
"""Infer coefficients using provided dictionary

Parameters
----------
data : array-like, shape [n_samples, n_features]

dictionary : array-like, shape [n_features, n_basis]

coeff_0 : array-like, shape [n_samples, n_basis], optional
Initial coefficient values
use_checknan : bool, default=False
Check for NaNs in the coefficients on each iteration. Setting
this to False speeds up inference.

Returns
-------
coefficients : array-like, shape [n_samples, n_basis] OR [n_samples, n_iter+1, n_basis]
The first shape is returned if return_all_coefficients == "none";
otherwise the full history of coefficients over all iterations
(plus the final values) is returned.
"""
batch_size, n_features = data.shape
n_features, n_basis = dictionary.shape
device = dictionary.device

# initialize
if coeff_0 is not None:
u = coeff_0.to(device)
else:
u = torch.zeros((batch_size, n_basis)).to(device)

coefficients = torch.zeros((batch_size, 0, n_basis)).to(device)

self.dictionary_norms = torch.norm(dictionary, dim=0)
assert torch.allclose(self.dictionary_norms, torch.ones_like(self.dictionary_norms)), "Dictionary must be normalized"

for i in range(self.n_iter):
# check return all
if self.return_all_coefficients != "none":
if self.return_all_coefficients == "active":
coefficients = torch.concat(
[coefficients, self.threshold_nonlinearity(u).clone().unsqueeze(1)], dim=1)
else:
coefficients = torch.concat(
[coefficients, u.clone().unsqueeze(1)], dim=1)

# Step 1: gradient step on u to reduce reconstruction error
recon = u @ dictionary.T
residual = data - recon
dLda = residual @ dictionary
u = u + self.coeff_lr * dLda

# Step 2: Thresholding
u = self.threshold_nonlinearity(u)

if use_checknan:
self.checknan(u, "coefficients")

# append the final coefficients (the active units after thresholding)
coefficients = torch.concat([coefficients, u.clone().unsqueeze(1)], dim=1)

# squeeze out the iteration dim when only the final coefficients are kept
return coefficients.squeeze(1)
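For context, a minimal usage sketch of the new class (the sizes and variable names below are illustrative, not part of the PR); note that the dictionary columns must be unit-norm to pass the normalization assert in infer:

import torch
from sparsecoding import inference

n_features, n_basis, batch_size = 64, 128, 8

# Dictionary with unit-norm columns, as CEL0.infer requires.
dictionary = torch.randn(n_features, n_basis)
dictionary = dictionary / torch.norm(dictionary, dim=0, keepdim=True)

data = torch.randn(batch_size, n_features)

method = inference.CEL0(n_iter=100, coeff_lr=1e-1, threshold=5e-1)
coefficients = method.infer(data, dictionary)  # [batch_size, n_basis]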
41 changes: 41 additions & 0 deletions tests/inference/test_CEL0.py
@@ -0,0 +1,41 @@
import unittest

from sparsecoding import inference
from tests.testing_utilities import TestCase
from tests.inference.common import (
DATAS, DATASET_SIZE, DATASET, DICTIONARY, PATCH_SIZE
)


class TestCEL0(TestCase):
def test_shape(self):
"""
Test that CEL0 inference returns expected shapes.
"""
N_ITER = 10

for (data, dataset) in zip(DATAS, DATASET):
inference_method = inference.CEL0(N_ITER)
a = inference_method.infer(data, DICTIONARY)
self.assertShapeEqual(a, dataset.weights)

inference_method = inference.CEL0(N_ITER, return_all_coefficients="active")
a = inference_method.infer(data, DICTIONARY)
self.assertEqual(a.shape, (DATASET_SIZE, N_ITER + 1, 2 * PATCH_SIZE))

def test_inference(self):
"""
Test that CEL0 inference recovers the correct weights.
"""
N_ITER = 1000

for (data, dataset) in zip(DATAS, DATASET):
inference_method = inference.CEL0(n_iter=N_ITER, coeff_lr=1e-1, threshold=5e-1)

a = inference_method.infer(data, DICTIONARY)

self.assertAllClose(a, dataset.weights, atol=5e-2, rtol=1e-1)


if __name__ == "__main__":
unittest.main()
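Assuming the repository's usual test layout, the new tests should run under the standard unittest runner, e.g.:

python -m unittest tests.inference.test_CEL0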
Binary file added tutorials/.DS_Store
Binary file not shown.