From b680e97428dea8427fd7baed5de0a814d8a5d473 Mon Sep 17 00:00:00 2001
From: alvinzz
Date: Wed, 25 May 2022 03:47:02 -0700
Subject: [PATCH 1/3] add lsm prior

---
 sparsecoding/priors/lsm.py        |  97 +++++++++++++++++++++++++++++
 sparsecoding/priors/spike_slab.py |  18 +++---
 tests/priors/test_lsm.py          | 100 ++++++++++++++++++++++++++++++
 tests/priors/test_spike_slab.py   |  23 ++++---
 4 files changed, 215 insertions(+), 23 deletions(-)
 create mode 100644 sparsecoding/priors/lsm.py
 create mode 100644 tests/priors/test_lsm.py

diff --git a/sparsecoding/priors/lsm.py b/sparsecoding/priors/lsm.py
new file mode 100644
index 0000000..70d01da
--- /dev/null
+++ b/sparsecoding/priors/lsm.py
@@ -0,0 +1,97 @@
+import torch
+from torch.distributions.laplace import Laplace
+from torch.distributions.gamma import Gamma
+
+from sparsecoding.priors.common import Prior
+
+
+class LSMPrior(Prior):
+    """Prior where weights are drawn i.i.d. from Laplacian scale mixtures.
+
+    The Laplacian scale mixture is defined in:
+        Garrigues & Olshausen (2010)
+        https://papers.nips.cc/paper/2010/hash/2d6cc4b2d139a53512fb8cbb3086ae2e-Abstract.html
+    .
+
+    Conceptually, a Laplacian scale mixture is a continuous mixture of
+    zero-mean Laplacian distributions with different scales.
+
+    Following the paper, a Gamma distribution is placed over
+    the inverse of the Laplacian's scale parameter,
+    since the Gamma distribution is the conjugate prior
+    for that parameter.
+
+    Parameters
+    ----------
+    dim : int
+        Number of weights per sample.
+    alpha : float
+        Shape or concentration parameter of the Gamma distribution
+        over the Laplacian's inverse scale.
+    beta : float
+        Rate or inverse scale parameter of the Gamma distribution
+        over the Laplacian's inverse scale.
+    positive_only : bool
+        Ensure that the weights are positive by taking the absolute value
+        of weights sampled from the Laplacian.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        alpha: float,
+        beta: float,
+        positive_only: bool = True,
+    ):
+        if dim < 0:
+            raise ValueError(f"`dim` should be nonnegative, got {dim}.")
+        if alpha <= 0:
+            raise ValueError(f"Must have alpha > 0, got `alpha`={alpha}.")
+        if beta <= 0:
+            raise ValueError(f"Must have beta > 0, got `beta`={beta}.")
+
+        self.dim = dim
+        self.alpha = alpha
+        self.beta = beta
+        self.positive_only = positive_only
+
+        self.gamma_distr = Gamma(self.alpha, self.beta)
+
+    @property
+    def D(self):
+        return self.dim
+
+    def sample(self, num_samples: int):
+        N = num_samples
+
+        inverse_lambdas = self.gamma_distr.sample((N, self.D))  # inverses of the Laplacian scales
+
+        weights = Laplace(
+            loc=torch.zeros((N, self.D), dtype=torch.float32),
+            scale=1. / inverse_lambdas,
+        ).sample()
+
+        if self.positive_only:
+            weights = torch.abs(weights)
+
+        return weights
+
+    def log_prob(
+        self,
+        sample: torch.Tensor,
+    ):
+        super().check_sample_input(sample)
+
+        log_prob = (  # log of the folded marginal, alpha * beta^alpha / (beta + |x|)^(alpha + 1)
+            torch.log(torch.tensor(self.alpha))
+            + self.alpha * torch.log(torch.tensor(self.beta))
+            - (self.alpha + 1) * torch.log(self.beta + torch.abs(sample))
+        )  # [N, D]
+        if self.positive_only:
+            log_prob[sample < 0.] = -torch.inf
+        else:
+            log_prob -= torch.log(torch.tensor(2.))
+
+        log_prob = torch.sum(log_prob, dim=1)  # [N]
+
+        return log_prob
diff --git a/sparsecoding/priors/spike_slab.py b/sparsecoding/priors/spike_slab.py
index 54c1598..e6c6ccb 100644
--- a/sparsecoding/priors/spike_slab.py
+++ b/sparsecoding/priors/spike_slab.py
@@ -5,7 +5,7 @@
 
 
 class SpikeSlabPrior(Prior):
-    """Prior where weights are drawn from a "spike-and-slab" distribution.
+    """Prior where weights are drawn i.i.d. from a "spike-and-slab" distribution.
The "spike" is at 0 and the "slab" is Laplacian. @@ -88,19 +88,15 @@ def log_prob( log_prob[spike_mask] = torch.log(torch.tensor(self.p_spike)) # Add log-probability for slab. + log_prob[slab_mask] = ( + torch.log(torch.tensor(1. - self.p_spike)) + - torch.log(torch.tensor(self.scale)) + - sample[slab_mask] / self.scale + ) if self.positive_only: - log_prob[slab_mask] = ( - torch.log(torch.tensor(1. - self.p_spike)) - - torch.log(torch.tensor(self.scale)) - - sample[slab_mask] / self.scale - ) log_prob[sample < 0.] = -torch.inf else: - log_prob[slab_mask] = ( - torch.log(torch.tensor(1. - self.p_spike)) - - torch.log(torch.tensor(2. * self.scale)) - - torch.abs(sample[slab_mask]) / self.scale - ) + log_prob[slab_mask] -= torch.log(torch.tensor(2.)) log_prob = torch.sum(log_prob, dim=1) # [N] diff --git a/tests/priors/test_lsm.py b/tests/priors/test_lsm.py new file mode 100644 index 0000000..bc3c4e7 --- /dev/null +++ b/tests/priors/test_lsm.py @@ -0,0 +1,100 @@ +import numpy as np +import torch +import unittest + +from sparsecoding.priors.lsm import LSMPrior + + +class TestLSMPrior(unittest.TestCase): + def test_sample(self): + N = 10000 + D = 4 + alpha = 2 + beta = 2 + + torch.manual_seed(1997) + + for positive_only in [True, False]: + lsm_prior = LSMPrior( + D, + alpha, + beta, + positive_only, + ) + weights = lsm_prior.sample(N) + + assert weights.shape == (N, D) + + # Check distribution. + if positive_only: + assert torch.sum(weights < 0.) == 0 + else: + assert torch.allclose( + torch.sum(weights < 0.) / (N * D), + torch.sum(weights > 0.) / (N * D), + atol=2e-2, + ) + weights = torch.abs(weights) + + # Note: + # Antiderivative of positive-only is: + # -Beta^alpha * (Beta + x)^(-alpha), + # cdf is: + # 1. - Beta^alpha * (B + x)^(-alpha), + # quantile fn is: + # -Beta + exp((log(1-y) - alpha*log(Beta)) / -alpha) + + for quantile in torch.arange(5) / 5.: + cutoff = ( + -beta + + np.exp( + (np.log(1. - quantile) - alpha * np.log(beta)) + / (-alpha) + ) + ) + assert torch.allclose( + torch.sum(weights < cutoff) / (N * D), + quantile, + atol=1e-2, + ) + + def test_log_prob(self): + D = 3 + alpha = 2 + beta = 2 + + samples = torch.Tensor([[-1., 0., 1.]]) + + pos_only_log_prob = ( + torch.log(torch.tensor(alpha)) - torch.log(torch.tensor(beta)) + + 2 * ( + torch.log(torch.tensor(alpha)) + alpha * torch.log(torch.tensor(beta)) + - (alpha + 1) * torch.log(torch.tensor(1 + beta)) + ) + ) + + for positive_only in [True, False]: + lsm_prior = LSMPrior( + D, + alpha, + beta, + positive_only, + ) + + if positive_only: + assert lsm_prior.log_prob(samples)[0] == -torch.inf + + samples = torch.abs(samples) + assert torch.allclose( + lsm_prior.log_prob(samples)[0], + pos_only_log_prob, + ) + else: + assert torch.allclose( + lsm_prior.log_prob(samples)[0], + pos_only_log_prob - D * torch.log(torch.tensor(2.)), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/priors/test_spike_slab.py b/tests/priors/test_spike_slab.py index dceb53e..c893d5b 100644 --- a/tests/priors/test_spike_slab.py +++ b/tests/priors/test_spike_slab.py @@ -59,6 +59,15 @@ def test_log_prob(self): p_spike = 0.5 scale = 1. + samples = torch.Tensor([[-1., 0., 1.]]) + + pos_only_log_prob = ( + torch.log(torch.tensor(p_spike)) + + 2 * ( + -1. + torch.log(torch.tensor(1. 
+            )
+        )
+
         for positive_only in [True, False]:
             spike_slab_prior = SpikeSlabPrior(
                 D,
@@ -67,28 +76,18 @@
                 positive_only,
             )
 
-            samples = torch.Tensor([[-1., 0., 1.]])
-
             if positive_only:
                 assert spike_slab_prior.log_prob(samples)[0] == -torch.inf
 
                 samples = torch.abs(samples)
                 assert torch.allclose(
                     spike_slab_prior.log_prob(samples)[0],
-                    (
-                        -1. + torch.log(torch.tensor(1. - p_spike))
-                        + torch.log(torch.tensor(p_spike))
-                        - 1. + torch.log(torch.tensor(1. - p_spike))
-                    )
+                    pos_only_log_prob,
                 )
             else:
                 assert torch.allclose(
                     spike_slab_prior.log_prob(samples)[0],
-                    (
-                        -1. + torch.log(torch.tensor(1. - p_spike)) - torch.log(torch.tensor(2.))
-                        + torch.log(torch.tensor(p_spike))
-                        - 1. + torch.log(torch.tensor(1. - p_spike)) - torch.log(torch.tensor(2.))
-                    )
+                    pos_only_log_prob - (D - 1) * torch.log(torch.tensor(2.)),
                 )
 
 

From 5d2938e0de7ce4c8f6413586b57de50491a8a60c Mon Sep 17 00:00:00 2001
From: alvinzz
Date: Wed, 25 May 2022 03:59:12 -0700
Subject: [PATCH 2/3] test with LSM prior

---
 tests/inference/common.py   | 7 +++++++
 tests/inference/test_LSM.py | 2 +-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/tests/inference/common.py b/tests/inference/common.py
index 2ff14d2..b090722 100644
--- a/tests/inference/common.py
+++ b/tests/inference/common.py
@@ -1,6 +1,7 @@
 import torch
 
 from sparsecoding.priors.l0 import L0Prior
+from sparsecoding.priors.lsm import LSMPrior
 from sparsecoding.priors.spike_slab import SpikeSlabPrior
 from sparsecoding.data.datasets.bars import BarsDataset
 
@@ -24,6 +25,12 @@
         ).type(torch.float32)
         ),
     ),
+    LSMPrior(
+        dim=2 * PATCH_SIZE,
+        alpha=80.0,
+        beta=0.02,
+        positive_only=False,
+    ),
 ]
 
 DATASET = [
diff --git a/tests/inference/test_LSM.py b/tests/inference/test_LSM.py
index 0084daf..e6ab6f6 100644
--- a/tests/inference/test_LSM.py
+++ b/tests/inference/test_LSM.py
@@ -28,7 +28,7 @@ def test_inference(self):
 
         a = inference_method.infer(data, DICTIONARY)
 
-        self.assertAllClose(a, dataset.weights, atol=5e-2)
+        self.assertAllClose(a, dataset.weights, atol=7.5e-2)
 
 
 if __name__ == "__main__":

From f7e041b227bf0f8b801c88a9e9f379f8a283f303 Mon Sep 17 00:00:00 2001
From: alvinzz
Date: Wed, 1 Jun 2022 11:59:34 -0700
Subject: [PATCH 3/3] bugfix spike_slab log_prob

---
 sparsecoding/priors/spike_slab.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sparsecoding/priors/spike_slab.py b/sparsecoding/priors/spike_slab.py
index e6c6ccb..6336cf0 100644
--- a/sparsecoding/priors/spike_slab.py
+++ b/sparsecoding/priors/spike_slab.py
@@ -91,7 +91,7 @@ def log_prob(
         log_prob[slab_mask] = (
             torch.log(torch.tensor(1. - self.p_spike))
             - torch.log(torch.tensor(self.scale))
-            - sample[slab_mask] / self.scale
+            - torch.abs(sample[slab_mask]) / self.scale
        )
         if self.positive_only:
             log_prob[sample < 0.] = -torch.inf
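
Note on the math behind `LSMPrior.log_prob` (a standalone sketch, not part of the
patch series above; the sample count, seed, and tolerance below are arbitrary
illustrative choices): integrating Laplace(x; 0, 1/lambda) against
Gamma(lambda; alpha, beta) gives the closed-form marginal

    p(x) = (alpha / 2) * beta^alpha / (beta + |x|)^(alpha + 1),

which is what the patched code computes (without the 1/2 in the positive-only
case, since folding the distribution onto x >= 0 doubles the density). The
snippet below cross-checks the closed form against a Monte Carlo average over
the Gamma-distributed inverse scales:

    import torch
    from torch.distributions.gamma import Gamma

    alpha, beta = 2.0, 2.0
    x = torch.linspace(-3., 3., 7)

    # Closed-form log-marginal (symmetric case), as in LSMPrior.log_prob.
    closed_form = (
        torch.log(torch.tensor(alpha))
        + alpha * torch.log(torch.tensor(beta))
        - (alpha + 1) * torch.log(beta + torch.abs(x))
        - torch.log(torch.tensor(2.))
    )

    # Monte Carlo estimate of p(x) = E_lambda[(lambda / 2) * exp(-lambda * |x|)],
    # with lambda ~ Gamma(alpha, beta) the Laplacian's inverse scale.
    torch.manual_seed(0)
    lambdas = Gamma(alpha, beta).sample((1_000_000,))  # [M]
    monte_carlo = torch.log(torch.mean(
        (lambdas[:, None] / 2.) * torch.exp(-lambdas[:, None] * torch.abs(x)[None, :]),
        dim=0,
    ))

    assert torch.allclose(closed_form, monte_carlo, atol=1e-2)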