diff --git a/sparsecoding/priors/laplace.py b/sparsecoding/priors/laplace.py new file mode 100644 index 0000000..fe2157d --- /dev/null +++ b/sparsecoding/priors/laplace.py @@ -0,0 +1,60 @@ +import torch +from torch.distributions.laplace import Laplace + +from sparsecoding.priors.common import Prior + + +class LaplacePrior(Prior): + """Prior corresponding to a Laplacian distribution. + + Parameters + ---------- + dim : int + Number of weights per sample. + scale : float + The "scale" of the Laplacian distribution (larger is wider). + positive_only : bool + Ensure that the weights are positive by taking the absolute value + of weights sampled from the Laplacian. + """ + + def __init__( + self, + dim: int, + scale: float, + positive_only: bool = True, + ): + if dim < 0: + raise ValueError(f"`dim` should be nonnegative, got {dim}.") + if scale <= 0: + raise ValueError(f"`scale` must be positive, got {scale}.") + + self.dim = dim + self.scale = scale + self.positive_only = positive_only + + self.distr = Laplace(loc=torch.tensor(0.), scale=torch.tensor(self.scale)) + + @property + def D(self): + return self.dim + + def sample(self, num_samples: int): + weights = self.distr.rsample((num_samples, self.D)) + if self.positive_only: + weights = torch.abs(weights) + return weights + + def log_prob( + self, + sample: torch.Tensor, + ): + super().check_sample_input(sample) + + log_prob = self.distr.log_prob(sample) + if self.positive_only: + log_prob += torch.log(torch.tensor(2.)) + log_prob[sample < 0.] = -torch.inf + log_prob = torch.sum(log_prob, dim=1) # [N] + + return log_prob diff --git a/sparsecoding/priors/spike_slab.py b/sparsecoding/priors/spike_slab.py index 6336cf0..7b98ca9 100644 --- a/sparsecoding/priors/spike_slab.py +++ b/sparsecoding/priors/spike_slab.py @@ -1,5 +1,4 @@ import torch -from torch.distributions.laplace import Laplace from sparsecoding.priors.common import Prior @@ -7,8 +6,6 @@ class SpikeSlabPrior(Prior): """Prior where weights are drawn i.i.d. from a "spike-and-slab" distribution. - The "spike" is at 0 and the "slab" is Laplacian. - See: https://wesselb.github.io/assets/write-ups/Bruinsma,%20Spike%20and%20Slab%20Priors.pdf for a good review of the spike-and-slab model. @@ -19,31 +16,31 @@ class SpikeSlabPrior(Prior): Number of weights per sample. p_spike : float The probability of the weight being 0. - scale : float - The "scale" of the Laplacian distribution (larger is wider). - positive_only : bool - Ensure that the weights are positive by taking the absolute value - of weights sampled from the Laplacian. + slab : Prior + The distribution of the "slab". + Since weights drawn from this distribution must be i.i.d., + we enforce `slab.D` to be 1. """ def __init__( self, dim: int, p_spike: float, - scale: float, - positive_only: bool = True, + slab: Prior, ): if dim < 0: raise ValueError(f"`dim` should be nonnegative, got {dim}.") if p_spike < 0 or p_spike > 1: raise ValueError(f"Must have 0 <= `p_spike` <= 1, got `p_spike`={p_spike}.") - if scale <= 0: - raise ValueError(f"`scale` must be positive, got {scale}.") + if slab.D != 1: + raise ValueError( + f"`slab.D` must be 1 (got {slab.D}). " + f"This enforces that can sample i.i.d. weights." + ) self.dim = dim self.p_spike = p_spike - self.scale = scale - self.positive_only = positive_only + self.slab = slab @property def D(self): @@ -53,13 +50,8 @@ def sample(self, num_samples: int): N = num_samples zero_weights = torch.zeros((N, self.D), dtype=torch.float32) - slab_weights = Laplace( - loc=zero_weights, - scale=torch.full((N, self.D), self.scale, dtype=torch.float32), - ).sample() # [N, D] - - if self.positive_only: - slab_weights = torch.abs(slab_weights) + slab_weights = self.slab.sample(num_samples * self.D) + slab_weights = slab_weights.reshape((num_samples, self.D)) spike_over_slab = torch.rand(N, self.D, dtype=torch.float32) < self.p_spike @@ -89,14 +81,9 @@ def log_prob( # Add log-probability for slab. log_prob[slab_mask] = ( - torch.log(torch.tensor(1. - self.p_spike)) - - torch.log(torch.tensor(self.scale)) - - torch.abs(sample[slab_mask]) / self.scale + self.slab.log_prob(sample[slab_mask].reshape(-1, 1)).reshape(-1) + + torch.log(torch.tensor(1. - self.p_spike)) ) - if self.positive_only: - log_prob[sample < 0.] = -torch.inf - else: - log_prob[slab_mask] -= torch.log(torch.tensor(2.)) log_prob = torch.sum(log_prob, dim=1) # [N] diff --git a/tests/inference/common.py b/tests/inference/common.py index b090722..bdad71c 100644 --- a/tests/inference/common.py +++ b/tests/inference/common.py @@ -1,6 +1,7 @@ import torch from sparsecoding.priors.l0 import L0Prior +from sparsecoding.priors.laplace import LaplacePrior from sparsecoding.priors.lsm import LSMPrior from sparsecoding.priors.spike_slab import SpikeSlabPrior from sparsecoding.data.datasets.bars import BarsDataset @@ -14,8 +15,11 @@ SpikeSlabPrior( dim=2 * PATCH_SIZE, p_spike=0.8, - scale=1.0, - positive_only=True, + slab=LaplacePrior( + dim=1, + scale=1.0, + positive_only=True, + ), ), L0Prior( prob_distr=( diff --git a/tests/priors/test_laplace.py b/tests/priors/test_laplace.py new file mode 100644 index 0000000..91e5d96 --- /dev/null +++ b/tests/priors/test_laplace.py @@ -0,0 +1,76 @@ +import torch +import unittest + +from sparsecoding.priors.laplace import LaplacePrior + + +class TestLaplacePrior(unittest.TestCase): + def test_sample(self): + N = 10000 + D = 4 + scale = 1. + + torch.manual_seed(1997) + + for positive_only in [True, False]: + laplace_prior = LaplacePrior( + D, + scale, + positive_only, + ) + weights = laplace_prior.sample(N) + + assert weights.shape == (N, D) + + # Check Laplacian distribution. + if positive_only: + assert torch.sum(weights < 0.) == 0 + else: + assert torch.allclose( + torch.sum(weights < 0.) / (N * D), + torch.sum(weights > 0.) / (N * D), + atol=2e-2, + ) + weights = torch.abs(weights) + + laplace_weights = weights[weights > 0.] + for quantile in torch.arange(5) / 5.: + cutoff = -torch.log(1. - quantile) + assert torch.allclose( + torch.sum(laplace_weights < cutoff) / (N * D), + quantile, + atol=1e-2, + ) + + def test_log_prob(self): + D = 3 + scale = 1. + + samples = torch.Tensor([[-1., 0., 1.]]) + + pos_only_log_prob = torch.tensor(-2.) + + for positive_only in [True, False]: + laplace_prior = LaplacePrior( + D, + scale, + positive_only, + ) + + if positive_only: + assert laplace_prior.log_prob(samples)[0] == -torch.inf + + samples = torch.abs(samples) + assert torch.allclose( + laplace_prior.log_prob(samples)[0], + pos_only_log_prob, + ) + else: + assert torch.allclose( + laplace_prior.log_prob(samples)[0], + pos_only_log_prob - D * torch.log(torch.tensor(2.)), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/priors/test_spike_slab.py b/tests/priors/test_spike_slab.py index c893d5b..cd10bc6 100644 --- a/tests/priors/test_spike_slab.py +++ b/tests/priors/test_spike_slab.py @@ -1,6 +1,7 @@ import torch import unittest +from sparsecoding.priors.laplace import LaplacePrior from sparsecoding.priors.spike_slab import SpikeSlabPrior @@ -9,86 +10,74 @@ def test_sample(self): N = 10000 D = 4 p_spike = 0.5 - scale = 1. + slab = LaplacePrior( + dim=1, + scale=1.0, + positive_only=True, + ) torch.manual_seed(1997) p_slab = 1. - p_spike - for positive_only in [True, False]: - spike_slab_prior = SpikeSlabPrior( - D, - p_spike, - scale, - positive_only, - ) - weights = spike_slab_prior.sample(N) + spike_slab_prior = SpikeSlabPrior( + D, + p_spike, + slab, + ) + weights = spike_slab_prior.sample(N) + + assert weights.shape == (N, D) - assert weights.shape == (N, D) + # Check spike probability. + assert torch.allclose( + torch.sum(weights == 0.) / (N * D), + torch.tensor(p_spike), + atol=1e-2, + ) + + # Check Laplacian distribution. + N_slab = p_slab * N * D + assert torch.sum(weights < 0.) == 0 - # Check spike probability. + laplace_weights = weights[weights > 0.] + for quantile in torch.arange(5) / 5.: + cutoff = -torch.log(1. - quantile) assert torch.allclose( - torch.sum(weights == 0.) / (N * D), - torch.tensor(p_spike), + torch.sum(laplace_weights < cutoff) / N_slab, + quantile, atol=1e-2, ) - # Check Laplacian distribution. - N_slab = p_slab * N * D - if positive_only: - assert torch.sum(weights < 0.) == 0 - else: - assert torch.allclose( - torch.sum(weights < 0.) / N_slab, - torch.sum(weights > 0.) / N_slab, - atol=2e-2, - ) - weights = torch.abs(weights) - - laplace_weights = weights[weights > 0.] - for quantile in torch.arange(5) / 5.: - cutoff = -torch.log(1. - quantile) - assert torch.allclose( - torch.sum(laplace_weights < cutoff) / N_slab, - quantile, - atol=1e-2, - ) - def test_log_prob(self): D = 3 p_spike = 0.5 - scale = 1. + slab = LaplacePrior( + dim=1, + scale=1.0, + positive_only=True, + ) samples = torch.Tensor([[-1., 0., 1.]]) - pos_only_log_prob = ( - torch.log(torch.tensor(p_spike)) - + 2 * ( - -1. + torch.log(torch.tensor(1. - p_spike)) - ) + spike_slab_prior = SpikeSlabPrior( + D, + p_spike, + slab, ) - for positive_only in [True, False]: - spike_slab_prior = SpikeSlabPrior( - D, - p_spike, - scale, - positive_only, - ) + assert spike_slab_prior.log_prob(samples)[0] == -torch.inf - if positive_only: - assert spike_slab_prior.log_prob(samples)[0] == -torch.inf - - samples = torch.abs(samples) - assert torch.allclose( - spike_slab_prior.log_prob(samples)[0], - pos_only_log_prob, - ) - else: - assert torch.allclose( - spike_slab_prior.log_prob(samples)[0], - pos_only_log_prob - (D - 1) * torch.log(torch.tensor(2.)), + samples = torch.abs(samples) + assert torch.allclose( + spike_slab_prior.log_prob(samples)[0], + ( + torch.log(torch.tensor(p_spike)) + + 2 * ( + -1. + torch.log(torch.tensor(1. - p_spike)) ) + ), + ) if __name__ == "__main__":