Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions examples/inference_bars_example.ipynb

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions examples/inference_natural_images.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions sparsecoding/data/datasets/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import torch
from torch.utils.data import Dataset

from sparsecoding.data.datasets.common import RAW_DATA_DIR
from sparsecoding.data.transform.patch import patchify
from sparsecoding.data.transforms.patch import patchify


class FieldDataset(Dataset):
Expand Down Expand Up @@ -42,7 +41,7 @@ def __init__(
os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat")
os.system(f"mv IMAGES.mat {root}/field.mat")

self.images = torch.tensor(loadmat(f"{RAW_DATA_DIR}/field.mat")["IMAGES"]) # [H, W, B]
self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B]
assert self.images.shape == (self.H, self.W, self.B)

self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W]
Expand Down
18 changes: 14 additions & 4 deletions sparsecoding/data/transforms/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ class Whitener(object):
where:
N is the number of data points,
D is the dimension of the data points.

Attributes
----------
mean : Tensor, shape [D]
Mean of the input data for the Whitener.
covariance : Tensor, shape [D, D]
Co-variance of the input data for the Whitener.
eigenvalues : Tensor, shape [D]
Eigenvalues of `self.covariance`.
eigenvectors : Tensor, shape [D, D]
Eigenvectors of `self.covariance`.
"""

def __init__(
Expand All @@ -27,15 +38,14 @@ def __init__(

self.mean = torch.mean(data, dim=1) # [D]

covariance = torch.cov(data) # [D, D]
self.eigenvalues, self.eigenvectors = torch.linalg.eigh(covariance) # [D], [D, D]
self.covariance = torch.cov(data) # [D, D]
self.eigenvalues, self.eigenvectors = torch.linalg.eigh(self.covariance) # [D], [D, D]

def whiten(
self,
data: torch.Tensor,
):
"""
Whitens the input `data` to have zero mean and unit (identity) covariance.
"""Whitens the input `data` to have zero mean and unit (identity) covariance.

Uses statistics of the data from class initialization.

Expand Down
4 changes: 2 additions & 2 deletions sparsecoding/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ def infer(self, data, dictionary, coeff_0=None, use_checknan=False):

# Calculate stepsize based on largest eigenvalue of
# dictionary.T @ dictionary.
lipschitz_constant = torch.symeig(
torch.mm(dictionary.T, dictionary))[0][-1]
lipschitz_constant = torch.linalg.eigvalsh(
torch.mm(dictionary.T, dictionary))[-1]
stepsize = 1. / lipschitz_constant
self.threshold = stepsize * self.sparsity_penalty

Expand Down
36 changes: 35 additions & 1 deletion sparsecoding/priors/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod

import torch


class Prior(ABC):
"""A distribution over weights.
Expand Down Expand Up @@ -29,6 +31,38 @@ def sample(

Returns
-------
samples : Tensor, shape [num_samples, self.D]
samples : Tensor, shape [num_samples, self.D()]
Sampled weights.
"""

@abstractmethod
def log_prob(
    self,
    sample: torch.Tensor,
):
    """Get the log-probability of the sample under this distribution.

    Abstract; each concrete prior must implement it.
    `self.check_sample_input()` is available for validating `sample`
    before use.

    Parameters
    ----------
    sample : Tensor, shape [num_samples, self.D()]
        Sample to get the log-probability for.
        # NOTE(review): `check_sample_input` compares against `self.D`
        # (attribute/property access, no call) — confirm whether the
        # docstring should read `self.D` instead of `self.D()`.

    Returns
    -------
    log_prob : Tensor, shape [num_samples]
        Log-probability of `sample`.
    """

def check_sample_input(
    self,
    sample: torch.Tensor,
):
    """Validate the dtype and shape of `sample`.

    Intended for use by `self.log_prob()` implementations. Raises
    ValueError on a malformed input; returns None otherwise.
    """
    dtype = sample.dtype
    if dtype != torch.float32:
        raise ValueError(f"`sample` dtype should be float32, got {dtype}.")

    # Short-circuit on rank so `sample.shape[1]` is never touched
    # for a non-2-D tensor.
    bad_rank = sample.dim() != 2
    if bad_rank or sample.shape[1] != self.D:
        raise ValueError(f"`sample` should have shape [N, {self.D}], got {sample.shape}.")
19 changes: 17 additions & 2 deletions sparsecoding/priors/l0.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.")
if prob_distr.dtype != torch.float32:
raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.")
if not torch.allclose(torch.sum(prob_distr), torch.ones_like(prob_distr)):
if not torch.allclose(torch.sum(prob_distr), torch.ones(1, dtype=torch.float32)):
raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.")

self.prob_distr = prob_distr
Expand All @@ -35,7 +35,7 @@ def D(self):

def sample(
self,
num_samples: int
num_samples: int,
):
N = num_samples

Expand Down Expand Up @@ -63,3 +63,18 @@ def sample(
weights[active_weight_idxs] += 1.

return weights

def log_prob(
    self,
    sample: torch.Tensor,
):
    """Log-probability of `sample` under the L0 prior.

    The probability of a sample depends only on its number of non-zero
    entries; a sample with no active units has probability zero
    (log-probability -inf).

    Parameters
    ----------
    sample : Tensor, shape [num_samples, self.D]
        Sample to get the log-probability for.

    Returns
    -------
    log_prob : Tensor, shape [num_samples]
        Log-probability of `sample`.
    """
    super().check_sample_input(sample)

    # Number of active (non-zero) units in each sample.
    num_active = (sample != 0.).sum(dim=1).long()  # [num_samples]

    # `prob_distr[k - 1]` is the probability of exactly `k` active
    # units. For num_active == 0 this indexes entry -1, but that
    # value is discarded by the -inf overwrite below.
    log_prob = torch.log(self.prob_distr[num_active - 1])
    log_prob[num_active == 0] = -torch.inf

    return log_prob

# TODO: Add L0ExpPrior, where the number of active units is distributed exponentially.

# TODO: Add L0IidPrior, where the magnitude of an active unit is distributed according to an i.i.d. Prior.
35 changes: 35 additions & 0 deletions sparsecoding/priors/spike_slab.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,38 @@ def sample(self, num_samples: int):
)

return weights

def log_prob(
    self,
    sample: torch.Tensor,
):
    """Log-probability of `sample` under the spike-and-slab prior.

    Coordinates contribute independently: zero entries get the spike
    mass `p_spike`; non-zero entries get the slab log-density
    (exponential with scale `self.scale` when `self.positive_only`,
    Laplace otherwise). Negative entries are outside the support of
    the positive-only slab and yield -inf.

    Parameters
    ----------
    sample : Tensor, shape [num_samples, self.D]
        Sample to get the log-probability for.

    Returns
    -------
    log_prob : Tensor, shape [num_samples]
        Log-probability of `sample`.
    """
    super().check_sample_input(sample)

    num_samples = sample.shape[0]

    # Per-coordinate log-probabilities; summed over dim 1 at the end.
    coord_log_prob = torch.zeros((num_samples, self.D), dtype=torch.float32)

    zero_mask = sample == 0.
    nonzero_mask = sample != 0.

    # Spike: point mass at exactly zero.
    coord_log_prob[zero_mask] = torch.log(torch.tensor(self.p_spike))

    # Slab density on the non-zero coordinates.
    if self.positive_only:
        coord_log_prob[nonzero_mask] = (
            torch.log(torch.tensor(1. - self.p_spike))
            - torch.log(torch.tensor(self.scale))
            - sample[nonzero_mask] / self.scale
        )
        # Negative values cannot come from a positive-only slab.
        coord_log_prob[sample < 0.] = -torch.inf
    else:
        coord_log_prob[nonzero_mask] = (
            torch.log(torch.tensor(1. - self.p_spike))
            - torch.log(torch.tensor(2. * self.scale))
            - torch.abs(sample[nonzero_mask]) / self.scale
        )

    return torch.sum(coord_log_prob, dim=1)  # [num_samples]
28 changes: 27 additions & 1 deletion tests/priors/test_l0.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from itertools import product

import torch
import unittest

from sparsecoding.priors.l0 import L0Prior


class TestL0Prior(unittest.TestCase):
def test_l0_prior(self):
def test_sample(self):
N = 10000
prob_distr = torch.tensor([0.5, 0.25, 0, 0.25])

Expand Down Expand Up @@ -36,6 +38,30 @@ def test_l0_prior(self):
atol=1e-2,
)

def test_log_prob(self):
    """Check L0Prior.log_prob over every binary sample of dimension 3."""
    prior = L0Prior(torch.tensor([0.75, 0.25, 0.]))

    # All 2**3 binary vectors; a row's l0-norm is the popcount of
    # its index in this ordering.
    binary_rows = torch.tensor(
        list(product([0, 1], repeat=3)),
        dtype=torch.float32,
    )  # [2**D, D]

    log_probs = prior.log_prob(binary_rows)

    expected_one_active = torch.log(torch.tensor(0.75))
    expected_two_active = torch.log(torch.tensor(0.25))

    assert log_probs[0] == -torch.inf  # zero active units
    for idx in [1, 2, 4]:  # one active unit
        assert torch.allclose(log_probs[idx], expected_one_active)
    for idx in [3, 5, 6]:  # two active units
        assert torch.allclose(log_probs[idx], expected_two_active)
    assert log_probs[7] == -torch.inf  # three active units: prob. 0.


if __name__ == "__main__":
unittest.main()
39 changes: 38 additions & 1 deletion tests/priors/test_spike_slab.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class TestSpikeSlabPrior(unittest.TestCase):
def test_spike_slab_prior(self):
def test_sample(self):
N = 10000
D = 4
p_spike = 0.5
Expand Down Expand Up @@ -54,6 +54,43 @@ def test_spike_slab_prior(self):
atol=1e-2,
)

def test_log_prob(self):
    """Check SpikeSlabPrior.log_prob for positive-only and signed slabs."""
    D = 3
    p_spike = 0.5
    scale = 1.

    log_p_spike = torch.log(torch.tensor(p_spike))
    log_p_slab = torch.log(torch.tensor(1. - p_spike))

    for positive_only in [True, False]:
        prior = SpikeSlabPrior(
            D,
            p_spike,
            scale,
            positive_only,
        )

        samples = torch.Tensor([[-1., 0., 1.]])

        if positive_only:
            # The negative entry is outside the exponential slab's support.
            assert prior.log_prob(samples)[0] == -torch.inf

            # All-non-negative sample: exponential slab + spike.
            samples = torch.abs(samples)
            expected = (
                -1. + log_p_slab
                + log_p_spike
                - 1. + log_p_slab
            )
            assert torch.allclose(prior.log_prob(samples)[0], expected)
        else:
            # Laplace slab handles the negative entry directly.
            expected = (
                -1. + log_p_slab - torch.log(torch.tensor(2.))
                + log_p_spike
                - 1. + log_p_slab - torch.log(torch.tensor(2.))
            )
            assert torch.allclose(prior.log_prob(samples)[0], expected)


if __name__ == "__main__":
unittest.main()