From 93eb16f4b07eb30184b2601f6b4456148f52bf03 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Thu, 2 Jun 2022 10:45:15 -0700 Subject: [PATCH 01/16] add tqdm to requrements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 7c6c9f2..4ae815a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ numpy matplotlib torch torchvision +tqdm From 17f87ba421357a702e2b1010032d58fb973722f6 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Thu, 2 Jun 2022 10:45:48 -0700 Subject: [PATCH 02/16] add hierarchical model --- sparsecoding/models.py | 357 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 356 insertions(+), 1 deletion(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 38bfb19..c38bbe8 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -1,7 +1,12 @@ +from typing import List + import numpy as np +import pickle as pkl import torch from torch.utils.data import DataLoader -import pickle as pkl +from tqdm import tqdm + +from sparsecoding.priors.common import Prior class SparseCoding(torch.nn.Module): @@ -213,6 +218,356 @@ def save_dictionary(self, filename): filehandler.close() +class Hierarchical(torch.nn.Module): + """Class for hierarchical sparse coding. + + Layer x_{n+1} is recursively defined as: + x_{n+1} := Phi_n x_n + a_n, + where: + Phi_n is a basis set (with unit norm), + a_n has a sparse prior. + + The `a_n`s can be thought of the errors or residuals in a predictive coding + model, or the sparse weights at each layer in a generative model. + + Parameters + ---------- + priors : List[Prior] + Prior on weights for each layer. + """ + + def __init__( + self, + priors: List[Prior], + ): + self.priors = priors + + self.dims = [prior.D for prior in priors] + + self.bases = [ + torch.normal( + mean=torch.zeros((self.dims[n], self.dims[n + 1]), dtype=torch.float32), + std=torch.ones((self.dims[n], self.dims[n + 1]), dtype=torch.float32), + ) + for n in range(self.L - 1) + ] + # Normalize / project onto unit sphere + self.bases = list(map( + lambda basis: basis / torch.norm(basis, dim=1, keepdim=True), + self.bases + )) + for basis in self.bases: + basis.requires_grad = True + + @property + def L(self): + """Number of layers in the generative model. + """ + return len(self.dims) + + def generate( + bases: List[torch.Tensor], + weights: List[torch.Tensor], + ): + """Run the generative model forward. + + Parameters + ---------- + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + weights : List[Tensor], length L, shape [N, D_i] + Weights at each layer. + + Returns + ------- + data : Tensor, shape [N, D_L] + Generated data from the given weights and bases. + """ + Hierarchical._check_bases_weights(bases, weights) + + x_i = weights[0] + for (basis, weight) in zip(bases, weights[1:]): + x_i = torch.einsum( + "ni,ij->nj", + x_i, + basis, + ) + weight + + return x_i + + def sample( + n_samples: int, + priors: List[Prior], + bases: List[torch.Tensor], + ): + """Sample from the generative model. + + Parameters + ---------- + n_samples : int + Number of samples to generate. + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + priors : List[Prior], length L + Priors for the weights at each layer. + + Returns + ------- + data : Tensor, shape [N, D_L] + Sampled data using the given priors and bases. 
+ """ + if n_samples < 0: + raise ValueError(f"`n_samples` must be non-negative, got {n_samples}.") + Hierarchical._check_bases_priors(bases, priors) + + weights = list(map( + lambda prior: prior.sample(n_samples), + priors, + )) + return Hierarchical.generate(bases, weights) + + def log_prob( + data: torch.Tensor, + bases: List[torch.Tensor], + priors: List[Prior], + weights: List[torch.Tensor], + ): + """Compute the log-probability of the `data` under the generative model. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to get the log-probability of. + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + priors: List[Prior], length L + Priors on the weights at each layer. + weights : List[Tensor], length L - 1, shape [N, D_i] + Weights for the basis functions at each layer, + EXCEPT for the bottom layer, where the weights are + implicitly defined as the difference between the data and the + generated predictions from the previous layers. + + Returns + ------- + log_prob : Tensor, shape [N] + Log-probabilities of the data under the generative model. + """ + Hierarchical._check_bases_priors(bases, priors) + Hierarchical._check_bases_weights( + bases, + # Need to add dummy weights since last layer weights + # are not specified in the input. + weights + [torch.zeros((weights[0].shape[0], bases[-1].shape[1]))] + ) + + # First layer, no basis + x_i = weights[0] + log_prob = priors[0].log_prob(weights[0]) + + # Middle layers + for (prior, basis, weight) in zip(priors[1:-1], bases[:-1], weights[1:]): + x_i = torch.einsum( + "ni,ij->nj", + x_i, # [N, D_i] + basis, # [D_i, D_{i+1}] + ) + weight + log_prob = log_prob + prior.log_prob(weight) + + # Last layer, implicit weights calculated from the data + x_l = torch.einsum( + "ni,ij->nj", + x_i, # [N, D_i] + bases[-1], # [D_i, D_{i+1}] + ) + weight_l = data - x_l + log_prob = log_prob + priors[-1].log_prob(weight_l) + + return log_prob + + def infer_weights( + self, + data: torch.Tensor, + n_iter: int = 1000, + learning_rate: float = 0.1, + ): + """Infer weights for the input `data` to maximize the log-likelihood. + + Performs gradient descent with Adam. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + n_iter : int + Number of iterations of gradient descent to perform. + learning_rate : float + Learning rate for the optimizer. + + Returns + ------- + weights : List[Tensor], length L, shape [N, D_i] + Inferred weights for each layer. + """ + N = data.shape[0] + + weights = [ + torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True) + for i in range(self.L - 1) + ] + bases = list(map(lambda basis: basis.detach(), self.bases)) + + optimizer = torch.optim.Adam(weights, lr=learning_rate) + for _ in range(n_iter): + log_prob = Hierarchical.log_prob(data, bases, self.priors, weights) + optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + optimizer.step() + + weights = list(map(lambda weight: weight.detach(), weights)) + noise = ( + data + - Hierarchical.generate( + self.bases, + weights + [torch.zeros_like(data)], + ) + ) + return weights + [noise] + + def learn_bases( + self, + data: torch.Tensor, + n_iter: int = 125, + learning_rate: float = 0.01, + inference_n_iter: int = 1000, + inference_learning_rate: float = 0.1, + ): + """Update the bases to maximize the log-likelihood of `data`. + + In each iteration, we first infer the weights under the current basis functions, + and then we update the bases with those weights fixed. 
+ + Uses gradient descent with Adam. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + n_iter : int + Number of iterations of gradient descent to perform. + + Returns + ------- + weights : List[Tensor], length L, shape [N, D_i] + Inferred weights for each layer. + """ + optimizer = torch.optim.Adam(self.bases, lr=learning_rate) + for _ in tqdm(range(n_iter)): + weights = self.infer_weights(data, inference_n_iter, inference_learning_rate) + + log_prob = Hierarchical.log_prob(data, self.bases, self.priors, weights[:-1]) + optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + optimizer.step() + + # Normalize basis elements (project them back onto the unit sphere). + with torch.no_grad(): + for basis in self.bases: + basis /= torch.norm(basis, dim=1, keepdim=True) + + def inspect_bases(self): + """Runs the generative model forward for each basis element. + + This allows visual inspection of what each basis function represents + at the final (bottom) layer. + + Returns + ------- + bases : List[List[Tensor]], shape [D_L] + Visualizations of the basis functions at each layer. + """ + bases = [] + for layer in range(self.L - 1): + layer_bases = [] + for basis_fn in range(self.bases[layer].shape[0]): + weights = [torch.zeros((1, dim)) for dim in self.dims] + weights[layer][0, basis_fn] = 1. + layer_bases.append(Hierarchical.generate(self.bases, weights)) + bases.append(layer_bases) + return bases + + def _check_bases_weights(bases, weights): + """Check bases and weights for shape compatibility. + """ + if len(weights) != len(bases) + 1: + raise ValueError( + f"Must have exactly one more weight than basis " + f"(`L` layers and `L-1` bases to transform between them), " + f"got {len(weights)} weights and {len(bases)} bases." + ) + if not all([ + weights[i].shape[0] == weights[0].shape[0] + for i in range(1, len(weights)) + ]): + raise ValueError( + "Weight tensors must all have the same size in the 0-th dimension." + "This is the size of the data to generate." + ) + for (layer, (basis_i, basis_j)) in enumerate(zip(bases[:-1], bases[1:])): + if basis_i.shape[1] != basis_j.shape[0]: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"produces weights of dimension {basis_i.shape[1]} " + f"for layer {layer+1} but " + f"basis between layer {layer+1} and layer {layer+2} " + f"expects {basis_j.shape[0]} weights for layer {layer+1}." + ) + for (layer, (basis, weight)) in enumerate(zip(bases, weights[:-1])): + if basis.shape[0] != weight.shape[1]: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"expects {basis.shape[0]} weights for {layer}, " + f"but {weight.shape[1]} weights " + f"are provided for layer {layer}." + ) + if bases[-1].shape[1] != weights[-1].shape[1]: + raise ValueError( + f"The final basis outputs data with dimension {bases[-1].shape[1]}, " + f"but final layer weights have dimension {weights[-1].shape[1]}." + ) + + def _check_bases_priors(bases, priors): + """Check bases and priors for shape compatibility. + """ + if len(priors) != len(bases) + 1: + raise ValueError( + f"Must have exactly one more prior than basis " + f"(`L` layers and `L-1` bases to transform between them), " + f"got {len(priors)} priors and {len(bases)} bases." 
+ ) + for (layer, (basis_i, basis_j)) in enumerate(zip(bases[:-1], bases[1:])): + if basis_i.shape[1] != basis_j.shape[0]: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"produces priors of dimension {basis_i.shape[1]} " + f"for layer {layer+1} but " + f"basis between layer {layer+1} and layer {layer+2} " + f"expects {basis_j.shape[0]} priors for layer {layer+1}." + ) + for (layer, (basis, prior)) in enumerate(zip(bases, priors[:-1])): + if basis.shape[0] != prior.D: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"expects {basis.shape[0]} weights for {layer}, " + f"but the prior for layer {layer} is over {prior.D} weights." + ) + if bases[-1].shape[1] != priors[-1].D: + raise ValueError( + f"The final basis outputs data with dimension {bases[-1].shape[1]}, " + f"but final layer prior is over {priors[-1].D} weights." + ) + + class SimulSparseCoding(SparseCoding): def __init__(self, inference_method, n_basis, n_features, sparsity_penalty, inf_rate=1, learn_rate=1, time_step=1, t_max=1000, From 4ae18bd498f4e1ca39c125e67df67e5dc7ca4ce1 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Thu, 2 Jun 2022 10:46:08 -0700 Subject: [PATCH 03/16] add hierarchical bars dataset --- sparsecoding/data/datasets/bars.py | 87 ++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/data/datasets/bars.py index 16f877d..92c8657 100644 --- a/sparsecoding/data/datasets/bars.py +++ b/sparsecoding/data/datasets/bars.py @@ -1,7 +1,12 @@ +import itertools +from typing import List + +import numpy as np import torch from torch.utils.data import Dataset from sparsecoding.priors.common import Prior +from sparsecoding.models import Hierarchical class BarsDataset(Dataset): @@ -47,6 +52,7 @@ def __init__( h_bars = h_bars.expand(self.P, self.P, self.P) v_bars = v_bars.expand(self.P, self.P, self.P) self.basis = torch.cat((h_bars, v_bars), dim=0) # [2*P, P, P] + self.basis /= np.sqrt(self.P) # Normalize basis. self.weights = prior.sample(self.N) # [N, 2*P] @@ -61,3 +67,84 @@ def __len__(self): def __getitem__(self, idx: int): return self.data[idx] + + +class HierarchicalBarsDataset(Dataset): + """Toy hierarchical dataset of horizontal and vertical bars. + + The L=1 basis functions are horizontal and vertical bars. + + The L=0 basis functions are equal mixtures of two horizontal and vertical bars + on the image border. + + Parameters + ---------- + patch_size : int + Side length for elements of the dataset. + dataset_size : int + Number of dataset elements to generate. + priors : List[Prior] + Prior distributions on the weights, starting from the top-level basis + and going down. + + Attributes + ---------- + bases : List[Tensor], + shapes: + - [6, 2 * patch_size] + - [2 * patch_size, patch_size * patch_size] + Dictionary elements (combinations of horizontal and vertical bars). + weights : List[Tensor], + shapes: + - [dataset_size, 6], + - [dataset_size, 2 * patch_size], + - [dataset_size, patch_size * patch_size]. + Weights for each level of the hierarchy. + data : Tensor, shape [dataset_size, patch_size * patch_size] + Weighted linear combinations of the basis elements. + """ + + def __init__( + self, + patch_size: int, + dataset_size: int, + priors: List[Prior], + ): + self.P = patch_size + self.N = dataset_size + self.priors = priors + + # Specify l1_basis: bars. 
+ one_hots = torch.nn.functional.one_hot(torch.arange(self.P)) # [P, P] + one_hots = one_hots.type(torch.float32) # [P, P] + + h_bars = one_hots.reshape(self.P, self.P, 1) + v_bars = one_hots.reshape(self.P, 1, self.P) + + h_bars = h_bars.expand(self.P, self.P, self.P) + v_bars = v_bars.expand(self.P, self.P, self.P) + l1_basis = torch.cat((h_bars, v_bars), dim=0) # [2*P, P, P] + l1_basis /= np.sqrt(self.P) # Normalize basis. + l1_basis = l1_basis.reshape((2 * self.P, self.P * self.P)) + + # Specify l0_basis: combinations of two bars on the border. + border_bar_idxs = [0, self.P - 1, self.P, 2 * self.P - 1] + l0_basis_idxs = torch.tensor(list(itertools.combinations(border_bar_idxs, 2))) + l0_basis = torch.zeros((6, 2 * self.P), dtype=torch.float32) + l0_basis[torch.arange(6), l0_basis_idxs[:, 0]] = 1. / np.sqrt(2.) + l0_basis[torch.arange(6), l0_basis_idxs[:, 1]] = 1. / np.sqrt(2.) + + self.bases = [l0_basis, l1_basis] + + self.weights = list(map( + lambda prior: prior.sample(self.N), + self.priors, + )) + + self.data = Hierarchical.generate(self.bases, self.weights) + + def __len__(self): + return self.N + + def __getitem__(self, idx: int): + return self.data[idx] From c17c65ffd5fd52d1a2d71de4b82140cc0e6c7e18 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Thu, 2 Jun 2022 10:46:45 -0700 Subject: [PATCH 04/16] add unittests --- tests/models/test_Hierarchical.py | 89 +++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 tests/models/test_Hierarchical.py diff --git a/tests/models/test_Hierarchical.py b/tests/models/test_Hierarchical.py new file mode 100644 index 0000000..80a919a --- /dev/null +++ b/tests/models/test_Hierarchical.py @@ -0,0 +1,89 @@ +import torch +import unittest + +from sparsecoding.data.datasets.bars import HierarchicalBarsDataset +from sparsecoding.models import Hierarchical +from sparsecoding.priors.laplace import LaplacePrior +from tests.testing_utilities import TestCase + +torch.manual_seed(1997) + +PATCH_SIZE = 8 +DATASET_SIZE = 1000 + +PRIORS = [ + LaplacePrior( + dim=6, + scale=1.0, + positive_only=False, + ), + LaplacePrior( + dim=2 * PATCH_SIZE, + scale=0.1, + positive_only=False, + ), + LaplacePrior( + dim=PATCH_SIZE * PATCH_SIZE, + scale=0.01, + positive_only=False, + ), +] + +DATASET = HierarchicalBarsDataset( + patch_size=PATCH_SIZE, + dataset_size=DATASET_SIZE, + priors=PRIORS, +) + + +class TestHierarchical(TestCase): + def test_infer_weights(self): + """ + Test that Hierarchical inference recovers the correct weights. + """ + model = Hierarchical(priors=PRIORS) + model.bases = DATASET.bases + + weights = model.infer_weights(DATASET.data) + + inferred_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + DATASET.bases, + PRIORS, + weights[:-1], + )) + dataset_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + DATASET.bases, + PRIORS, + DATASET.weights[:-1], + )) + self.assertAllClose(inferred_log_probs, dataset_log_probs, atol=5e-2) + + def test_learn_bases(self): + """ + Test that Hierarchical inference recovers the correct bases. 
+ """ + model = Hierarchical(priors=PRIORS) + + model.learn_bases(DATASET.data) + + weights = model.infer_weights(DATASET.data) + + inferred_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + list(map(lambda basis: basis.detach(), model.bases)), + PRIORS, + weights[:-1], + )) + dataset_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + DATASET.bases, + PRIORS, + DATASET.weights[:-1], + )) + self.assertAllClose(inferred_log_probs, dataset_log_probs, atol=3e0) + + +if __name__ == "__main__": + unittest.main() From ec7d945e90c00a28da1559ceefee1f16d26386a8 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Thu, 2 Jun 2022 10:50:42 -0700 Subject: [PATCH 05/16] lint --- sparsecoding/models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index c38bbe8..96b0380 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -347,7 +347,7 @@ def log_prob( EXCEPT for the bottom layer, where the weights are implicitly defined as the difference between the data and the generated predictions from the previous layers. - + Returns ------- log_prob : Tensor, shape [N] @@ -392,7 +392,7 @@ def infer_weights( learning_rate: float = 0.1, ): """Infer weights for the input `data` to maximize the log-likelihood. - + Performs gradient descent with Adam. Parameters @@ -403,7 +403,7 @@ def infer_weights( Number of iterations of gradient descent to perform. learning_rate : float Learning rate for the optimizer. - + Returns ------- weights : List[Tensor], length L, shape [N, D_i] @@ -442,8 +442,8 @@ def learn_bases( inference_n_iter: int = 1000, inference_learning_rate: float = 0.1, ): - """Update the bases to maximize the log-likelihood of `data`. - + """Update the bases to maximize the log-likelihood of `data`. + In each iteration, we first infer the weights under the current basis functions, and then we update the bases with those weights fixed. @@ -455,7 +455,7 @@ def learn_bases( Data to be generated. n_iter : int Number of iterations of gradient descent to perform. - + Returns ------- weights : List[Tensor], length L, shape [N, D_i] From 8e4183e48a4329079223341f0c22205af40c147e Mon Sep 17 00:00:00 2001 From: alvinzz Date: Thu, 2 Jun 2022 11:04:50 -0700 Subject: [PATCH 06/16] un-break tests --- sparsecoding/data/datasets/bars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/data/datasets/bars.py index 92c8657..f0dc8ca 100644 --- a/sparsecoding/data/datasets/bars.py +++ b/sparsecoding/data/datasets/bars.py @@ -52,7 +52,7 @@ def __init__( h_bars = h_bars.expand(self.P, self.P, self.P) v_bars = v_bars.expand(self.P, self.P, self.P) self.basis = torch.cat((h_bars, v_bars), dim=0) # [2*P, P, P] - self.basis /= np.sqrt(self.P) # Normalize basis. + # self.basis /= np.sqrt(self.P) # Normalize basis. 
self.weights = prior.sample(self.N) # [N, 2*P] From 77a3d5641cf35e391b9ec55a72ea92b66a162005 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Fri, 3 Jun 2022 16:28:41 -0700 Subject: [PATCH 07/16] 1-indexing for layers --- sparsecoding/data/datasets/bars.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/data/datasets/bars.py index f0dc8ca..61e6e5c 100644 --- a/sparsecoding/data/datasets/bars.py +++ b/sparsecoding/data/datasets/bars.py @@ -123,18 +123,18 @@ def __init__( h_bars = h_bars.expand(self.P, self.P, self.P) v_bars = v_bars.expand(self.P, self.P, self.P) - l1_basis = torch.cat((h_bars, v_bars), dim=0) # [2*P, P, P] - l1_basis /= np.sqrt(self.P) # Normalize basis. - l1_basis = l1_basis.reshape((2 * self.P, self.P * self.P)) + l2_basis = torch.cat((h_bars, v_bars), dim=0) # [2*P, P, P] + l2_basis /= np.sqrt(self.P) # Normalize basis. + l2_basis = l2_basis.reshape((2 * self.P, self.P * self.P)) # Specify l0_basis: combinations of two bars on the border. border_bar_idxs = [0, self.P - 1, self.P, 2 * self.P - 1] - l0_basis_idxs = torch.tensor(list(itertools.combinations(border_bar_idxs, 2))) - l0_basis = torch.zeros((6, 2 * self.P), dtype=torch.float32) - l0_basis[torch.arange(6), l0_basis_idxs[:, 0]] = 1. / np.sqrt(2.) - l0_basis[torch.arange(6), l0_basis_idxs[:, 1]] = 1. / np.sqrt(2.) + l1_basis_idxs = torch.tensor(list(itertools.combinations(border_bar_idxs, 2))) + l1_basis = torch.zeros((6, 2 * self.P), dtype=torch.float32) + l1_basis[torch.arange(6), l1_basis_idxs[:, 0]] = 1. / np.sqrt(2.) + l1_basis[torch.arange(6), l1_basis_idxs[:, 1]] = 1. / np.sqrt(2.) - self.bases = [l0_basis, l1_basis] + self.bases = [l1_basis, l2_basis] self.weights = list(map( lambda prior: prior.sample(self.N), From d6974215de7b5b421b3d4605213b993087ba1e66 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Fri, 3 Jun 2022 16:29:36 -0700 Subject: [PATCH 08/16] add return_history options - for Hierarchical inference and learning - some variable renaming - make _compute_bottom_weights() function --- sparsecoding/models.py | 127 +++++++++++++++++++++++++++++++++-------- 1 file changed, 102 insertions(+), 25 deletions(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 96b0380..9c2e9b2 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -390,6 +390,7 @@ def infer_weights( data: torch.Tensor, n_iter: int = 1000, learning_rate: float = 0.1, + return_history: bool = False, ): """Infer weights for the input `data` to maximize the log-likelihood. @@ -403,36 +404,70 @@ def infer_weights( Number of iterations of gradient descent to perform. learning_rate : float Learning rate for the optimizer. + return_history : bool + Flag to return the history of the inferred weights during inference. Returns ------- weights : List[Tensor], length L, shape [N, D_i] Inferred weights for each layer. + weights_history : optional, List[Tensor], length L, shape [n_iter + 1, N, D_i] + Returned if `return_history`. The inferred weights for each layer + throughout inference. 
""" N = data.shape[0] - weights = [ + top_weights = [ torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True) for i in range(self.L - 1) ] bases = list(map(lambda basis: basis.detach(), self.bases)) - optimizer = torch.optim.Adam(weights, lr=learning_rate) + if return_history: + with torch.no_grad(): + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + weights_history = list(map( + lambda weight: torch.unsqueeze(weight, dim=0), + weights, + )) + + def add_weights_to_history(x): + weights_history, weights = x + weights_history = torch.cat( + (weights_history, weights.unsqueeze(0)), + dim=0, + ) + return weights_history + + optimizer = torch.optim.Adam(top_weights, lr=learning_rate) for _ in range(n_iter): - log_prob = Hierarchical.log_prob(data, bases, self.priors, weights) + log_prob = Hierarchical.log_prob(data, bases, self.priors, top_weights) optimizer.zero_grad() (-torch.mean(log_prob)).backward() optimizer.step() - weights = list(map(lambda weight: weight.detach(), weights)) - noise = ( - data - - Hierarchical.generate( - self.bases, - weights + [torch.zeros_like(data)], - ) - ) - return weights + [noise] + if return_history: + with torch.no_grad(): + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + weights_history = list(map( + add_weights_to_history, + zip(weights_history, weights), + )) + + top_weights = list(map(lambda weight: weight.detach(), top_weights)) + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + + if not return_history: + return weights + else: + weights_history = list(map( + add_weights_to_history, + zip(weights_history, weights), + )) + return weights, weights_history def learn_bases( self, @@ -441,6 +476,7 @@ def learn_bases( learning_rate: float = 0.01, inference_n_iter: int = 1000, inference_learning_rate: float = 0.1, + return_history: bool = False, ): """Update the bases to maximize the log-likelihood of `data`. @@ -450,17 +486,28 @@ def learn_bases( Uses gradient descent with Adam. Parameters - ---------- - data : Tensor, shape [N, D_L] - Data to be generated. - n_iter : int - Number of iterations of gradient descent to perform. + ----------training + Flag to return the history of the learned bases during inference. Returns ------- - weights : List[Tensor], length L, shape [N, D_i] - Inferred weights for each layer. + bases_history : optional, List[Tensor], length L - 1, shape [n_iter + 1, D_i, D_{i+1}] + Returned if `return_history`. The learned bases throughout training. """ + if return_history: + bases_history = list(map( + lambda basis: basis.detach().unsqueeze(0), + self.bases, + )) + + def add_basis_to_history(x): + basis_history, basis = x + basis_history = torch.cat( + (basis_history, basis.unsqueeze(0)), + dim=0, + ) + return basis_history + optimizer = torch.optim.Adam(self.bases, lr=learning_rate) for _ in tqdm(range(n_iter)): weights = self.infer_weights(data, inference_n_iter, inference_learning_rate) @@ -475,6 +522,15 @@ def learn_bases( for basis in self.bases: basis /= torch.norm(basis, dim=1, keepdim=True) + if return_history: + bases_history = list(map( + add_basis_to_history, + zip(bases_history, self.bases), + )) + + if return_history: + return bases_history + def inspect_bases(self): """Runs the generative model forward for each basis element. 
@@ -483,18 +539,18 @@ def inspect_bases(self): Returns ------- - bases : List[List[Tensor]], shape [D_L] + bases_viz : List[List[Tensor]], shape [D_L] Visualizations of the basis functions at each layer. """ - bases = [] + bases_viz = [] for layer in range(self.L - 1): - layer_bases = [] + layer_bases_viz = [] for basis_fn in range(self.bases[layer].shape[0]): weights = [torch.zeros((1, dim)) for dim in self.dims] weights[layer][0, basis_fn] = 1. - layer_bases.append(Hierarchical.generate(self.bases, weights)) - bases.append(layer_bases) - return bases + layer_bases_viz.append(Hierarchical.generate(self.bases, weights)[0]) + bases_viz.append(layer_bases_viz) + return bases_viz def _check_bases_weights(bases, weights): """Check bases and weights for shape compatibility. @@ -567,6 +623,27 @@ def _check_bases_priors(bases, priors): f"but final layer prior is over {priors[-1].D} weights." ) + def _compute_bottom_weights( + data: torch.Tensor, + bases: List[torch.Tensor], + top_weights: List[torch.Tensor], + ): + """Compute the bottom-layer weights for `data`, given weights for all the other layers. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + weights : List[Tensor], length L - 1, shape [N, D_i] + Weights at the top `L - 1` layers. + """ + weights = top_weights + [torch.zeros_like(data)] + Hierarchical._check_bases_weights(bases, weights) + bottom_weights = data - Hierarchical.generate(bases, weights) + return bottom_weights + class SimulSparseCoding(SparseCoding): def __init__(self, inference_method, n_basis, n_features, sparsity_penalty, From 1858991051b823da2863c1d0f4c6ff500be3082f Mon Sep 17 00:00:00 2001 From: alvinzz Date: Mon, 6 Jun 2022 10:20:28 -0700 Subject: [PATCH 09/16] weight inference from initial weights --- sparsecoding/models.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 9c2e9b2..fa31af2 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import numpy as np import pickle as pkl @@ -391,6 +391,7 @@ def infer_weights( n_iter: int = 1000, learning_rate: float = 0.1, return_history: bool = False, + initial_weights: Optional[List[torch.Tensor]] = None, ): """Infer weights for the input `data` to maximize the log-likelihood. @@ -406,6 +407,9 @@ def infer_weights( Learning rate for the optimizer. return_history : bool Flag to return the history of the inferred weights during inference. + initial_weights : optional, List[Tensor], length L - 1, shape [N, D_i] + If provided, the initial weights to start inference from. + Otherwise, weights are set to 0. 
Returns ------- @@ -417,10 +421,16 @@ def infer_weights( """ N = data.shape[0] - top_weights = [ - torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True) - for i in range(self.L - 1) - ] + if initial_weights is None: + top_weights = [ + torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True) + for i in range(self.L - 1) + ] + else: + top_weights = initial_weights + for weight in top_weights: + weight.requires_grad = True + bases = list(map(lambda basis: basis.detach(), self.bases)) if return_history: From f8467c6f9a05d6ef163d793da29a333464a06132 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Mon, 6 Jun 2022 10:26:46 -0700 Subject: [PATCH 10/16] fix docstring --- sparsecoding/models.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index fa31af2..387d3c2 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -496,7 +496,19 @@ def learn_bases( Uses gradient descent with Adam. Parameters - ----------training + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + n_iter : int + Number of iterations of gradient descent to perform. + learning_rate : float + Step-size for learning the bases. + inference_n_iter : int + Number of iterations of gradient descent + to perform during weight inference. + inference_learning_rate : float + Step-size for inferring the weights. + return_history : bool Flag to return the history of the learned bases during inference. Returns From 910939ae925a61ec7fffc3358f6c7a4d1f607fa4 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Mon, 6 Jun 2022 11:50:02 -0700 Subject: [PATCH 11/16] warm-start weights during bases learning --- sparsecoding/models.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 387d3c2..e1b3fb2 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -484,8 +484,8 @@ def learn_bases( data: torch.Tensor, n_iter: int = 125, learning_rate: float = 0.01, - inference_n_iter: int = 1000, - inference_learning_rate: float = 0.1, + inference_n_iter: int = 25, + inference_learning_rate: float = 0.01, return_history: bool = False, ): """Update the bases to maximize the log-likelihood of `data`. @@ -515,7 +515,9 @@ def learn_bases( ------- bases_history : optional, List[Tensor], length L - 1, shape [n_iter + 1, D_i, D_{i+1}] Returned if `return_history`. The learned bases throughout training. - """ + """ + N = data.shape[0] + if return_history: bases_history = list(map( lambda basis: basis.detach().unsqueeze(0), @@ -530,14 +532,29 @@ def add_basis_to_history(x): ) return basis_history - optimizer = torch.optim.Adam(self.bases, lr=learning_rate) - for _ in tqdm(range(n_iter)): - weights = self.infer_weights(data, inference_n_iter, inference_learning_rate) + bases_optimizer = torch.optim.Adam(self.bases, lr=learning_rate) - log_prob = Hierarchical.log_prob(data, self.bases, self.priors, weights[:-1]) - optimizer.zero_grad() + top_weights = [ + torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True) + for i in range(self.L - 1) + ] + weights_optimizer = torch.optim.Adam(top_weights, lr=inference_learning_rate) + + for _ in tqdm(range(n_iter)): + # Infer weights under the current bases. 
+ bases = list(map(lambda basis: basis.detach(), self.bases)) + for _ in range(inference_n_iter): + log_prob = Hierarchical.log_prob(data, bases, self.priors, top_weights) + weights_optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + weights_optimizer.step() + + # Update bases from the current weights. + weights = list(map(lambda weight: weight.detach(), top_weights)) + log_prob = Hierarchical.log_prob(data, self.bases, self.priors, weights) + bases_optimizer.zero_grad() (-torch.mean(log_prob)).backward() - optimizer.step() + bases_optimizer.step() # Normalize basis elements (project them back onto the unit sphere). with torch.no_grad(): From d70d7bb680a6244524a18bc5786783bd4635b419 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Mon, 6 Jun 2022 17:16:48 -0700 Subject: [PATCH 12/16] speed up `weight_history`, `basis_history` --- sparsecoding/models.py | 40 ++++++++++------------------------------ 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index e1b3fb2..8653c73 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -438,18 +438,10 @@ def infer_weights( bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) weights = top_weights + [bottom_weights] weights_history = list(map( - lambda weight: torch.unsqueeze(weight, dim=0), + lambda weight: [weight.detach().clone()], weights, )) - def add_weights_to_history(x): - weights_history, weights = x - weights_history = torch.cat( - (weights_history, weights.unsqueeze(0)), - dim=0, - ) - return weights_history - optimizer = torch.optim.Adam(top_weights, lr=learning_rate) for _ in range(n_iter): log_prob = Hierarchical.log_prob(data, bases, self.priors, top_weights) @@ -461,10 +453,8 @@ def add_weights_to_history(x): with torch.no_grad(): bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) weights = top_weights + [bottom_weights] - weights_history = list(map( - add_weights_to_history, - zip(weights_history, weights), - )) + for (weight_history, weight) in zip(weights_history, weights): + weight_history.append(weight.detach().clone()) top_weights = list(map(lambda weight: weight.detach(), top_weights)) bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) @@ -473,10 +463,8 @@ def add_weights_to_history(x): if not return_history: return weights else: - weights_history = list(map( - add_weights_to_history, - zip(weights_history, weights), - )) + for layer in range(self.L): + weight_history[layer] = torch.stack(weight_history[layer], dim=0) return weights, weights_history def learn_bases( @@ -520,18 +508,10 @@ def learn_bases( if return_history: bases_history = list(map( - lambda basis: basis.detach().unsqueeze(0), + lambda basis: [basis.detach().clone()], self.bases, )) - def add_basis_to_history(x): - basis_history, basis = x - basis_history = torch.cat( - (basis_history, basis.unsqueeze(0)), - dim=0, - ) - return basis_history - bases_optimizer = torch.optim.Adam(self.bases, lr=learning_rate) top_weights = [ @@ -562,12 +542,12 @@ def add_basis_to_history(x): basis /= torch.norm(basis, dim=1, keepdim=True) if return_history: - bases_history = list(map( - add_basis_to_history, - zip(bases_history, self.bases), - )) + for (basis_history, basis) in zip(bases_history, self.bases): + basis_history.append(basis.detach().clone()) if return_history: + for layer in range(self.L): + bases_history[layer] = torch.stack(bases_history[layer], dim=0) return bases_history def inspect_bases(self): From 
5763931b3a51dfd57ec19abc30070b0953e4cae5 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Mon, 6 Jun 2022 17:26:25 -0700 Subject: [PATCH 13/16] add infer_weights_local --- sparsecoding/models.py | 116 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 8653c73..02b7434 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -467,6 +467,122 @@ def infer_weights( weight_history[layer] = torch.stack(weight_history[layer], dim=0) return weights, weights_history + def infer_weights_local( + self, + data: torch.Tensor, + n_iter: int = 100000, + learning_rate: float = 0.0001, + return_history_interval: Optional[int] = None, + initial_weights: Optional[List[torch.Tensor]] = None, + ): + """Infer weights for the input `data` to maximize the log-likelihood. + However, information flow is constrained to be between adjacent layers. + + Performs gradient descent with Adam. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + n_iter : int + Number of iterations of gradient descent to perform. + learning_rate : float + Learning rate for the optimizer. + return_history_interval : optional, int + If set, inferred weights during inference will be saved + at this frequency. + initial_weights : optional, List[Tensor], length L - 1, shape [N, D_i] + If provided, the initial weights to start inference from. + Otherwise, weights are set to 0. + + Returns + ------- + weights : List[Tensor], length L, shape [N, D_i] + Inferred weights for each layer. + weights_history : optional, List[Tensor], length L, shape [n_iter + 1, N, D_i] + Returned if `return_history`. The inferred weights for each layer + throughout inference. + """ + N = data.shape[0] + + if initial_weights is None: + top_weights = [ + torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True) + for i in range(self.L - 1) + ] + else: + top_weights = initial_weights + for weight in top_weights: + weight.requires_grad = True + + bases = list(map(lambda basis: basis.detach(), self.bases)) + + if return_history_interval: + with torch.no_grad(): + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + weights_history = [[weight.detach()] for weight in weights] + + optimizer = torch.optim.Adam(top_weights, lr=learning_rate) + for it in range(n_iter): + # Generate data under current weights (with no gradient) to get targets. + xs_ng = [] + with torch.no_grad(): + xs_ng.append(top_weights[0].detach()) + for (basis, weight) in zip(bases[:-1], top_weights[1:]): + xs_ng.append(xs_ng[-1] @ basis + weight.detach()) + xs_ng.append(data) + + # Get log-probability for Layer 1. + weight_1 = top_weights[0] + x_ng_below = xs_ng[1] + basis_to_below = bases[0] + prior = self.priors[0] + prior_below = self.priors[1] + log_prob = ( + prior.log_prob(weight_1) + + prior_below.log_prob(x_ng_below - weight_1 @ basis_to_below) + ) + + # Get log-probabilities for Layers 2 through L. 
+ for layer in range(2, self.L): + weight = top_weights[layer - 1] + + x_ng_above = xs_ng[layer - 2] + basis_from_above = bases[layer - 2] + + x_ng_below = xs_ng[layer] + basis_to_below = bases[layer - 1] + + prior = self.priors[layer - 1] + prior_below = self.priors[layer] + log_prob += ( + prior.log_prob(weight) + + prior_below.log_prob(x_ng_below - (x_ng_above @ basis_from_above + weight) @ basis_to_below) + ) + + optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + optimizer.step() + + if return_history_interval and it % return_history_interval == 0: + with torch.no_grad(): + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + for (weight_history, weight) in zip(weights_history, weights): + weight_history.append(weight.detach().clone()) + + top_weights = list(map(lambda weight: weight.detach(), top_weights)) + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + + if not return_history_interval: + return weights + else: + for layer in range(self.L): + weights_history[layer] = torch.stack(weights_history[layer], dim=0) + return weights, weights_history + def learn_bases( self, data: torch.Tensor, From 2bf8832a7b362dbb83ff2dcf8d3e5eaa528d6c75 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Mon, 6 Jun 2022 17:36:54 -0700 Subject: [PATCH 14/16] fix typo --- sparsecoding/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 02b7434..230631d 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -662,7 +662,7 @@ def learn_bases( basis_history.append(basis.detach().clone()) if return_history: - for layer in range(self.L): + for layer in range(self.L - 1): bases_history[layer] = torch.stack(bases_history[layer], dim=0) return bases_history From a690dde0599283044aa8437402af2b552217a29b Mon Sep 17 00:00:00 2001 From: alvinzz Date: Tue, 7 Jun 2022 21:55:03 -0700 Subject: [PATCH 15/16] add L0IidPrior --- sparsecoding/priors/l0.py | 92 ++++++++++++++++++++++++++++++++++++++- tests/priors/test_l0.py | 72 +++++++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 2 deletions(-) diff --git a/sparsecoding/priors/l0.py b/sparsecoding/priors/l0.py index 18bf698..83a770c 100644 --- a/sparsecoding/priors/l0.py +++ b/sparsecoding/priors/l0.py @@ -77,4 +77,94 @@ def log_prob( # TODO: Add L0ExpPrior, where the number of active units is distributed exponentially. -# TODO: Add L0IidPrior, where the magnitude of an active unit is distributed according to an i.i.d. Prior. +class L0IidPrior(Prior): + """L0-sparse Prior with non-binary weights. + + If a weight is active, its value is drawn from an i.i.d. Prior. + + Parameters + ---------- + prob_distr : Tensor, shape [D], dtype float32 + Probability distribution over the l0-norm of the weights. + active_weight_prior : Prior + The distribution for active weights. + Since weights drawn from this distribution must be i.i.d., + we enforce `active_weight_prior.D` to be 1. 
+ """ + + def __init__( + self, + prob_distr: torch.Tensor, + active_weight_prior: Prior, + ): + if prob_distr.dim() != 1: + raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") + if prob_distr.dtype != torch.float32: + raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") + if not torch.allclose(torch.sum(prob_distr), torch.ones(1, dtype=torch.float32)): + raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") + if active_weight_prior.D != 1: + raise ValueError( + f"`active_weight_prior.D` must be 1 (got {active_weight_prior.D}). " + f"This enforces that can sample i.i.d. weights." + ) + + self.prob_distr = prob_distr + self.active_weight_prior = active_weight_prior + + @property + def D(self): + return self.prob_distr.shape[0] + + def sample( + self, + num_samples: int, + ): + N = num_samples + + num_active_weights = 1 + torch.multinomial( + input=self.prob_distr, + num_samples=num_samples, + replacement=True, + ) # [N] + + d_idxs = torch.arange(self.D) + active_idx_mask = ( + d_idxs.reshape(1, self.D) + < num_active_weights.reshape(N, 1) + ) # [N, self.D] + + n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] + # Need to shuffle here so that it's not always the first weights that are active. + shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] + shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] + + # [num_active_weights], [num_active_weights] + active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] + + weights = torch.zeros((N, self.D), dtype=torch.float32) + n_active_idxs = int(torch.sum(active_idx_mask).cpu().numpy()) + active_weight_values = self.active_weight_prior.sample(n_active_idxs) + weights[active_weight_idxs] += active_weight_values.reshape(-1) + + return weights + + def log_prob( + self, + sample: torch.Tensor, + ): + super().check_sample_input(sample) + + active_weight_mask = (sample != 0.) + + l0_norm = torch.sum(active_weight_mask, dim=1).type(torch.long) # [num_samples] + log_prob = torch.log(self.prob_distr[l0_norm - 1]) + log_prob[l0_norm == 0] = -torch.inf + + active_log_prob = ( + self.active_weight_prior.log_prob(sample.reshape(-1, 1)).reshape(sample.shape) + ) # [num_samples, D] + active_log_prob[~active_weight_mask] = 0. # [num_samples, D] + log_prob += torch.sum(active_log_prob, dim=1) # [num_samples] + + return log_prob diff --git a/tests/priors/test_l0.py b/tests/priors/test_l0.py index 323d12e..be9e68c 100644 --- a/tests/priors/test_l0.py +++ b/tests/priors/test_l0.py @@ -3,7 +3,8 @@ import torch import unittest -from sparsecoding.priors.l0 import L0Prior +from sparsecoding.priors.l0 import L0Prior, L0IidPrior +from sparsecoding.priors.laplace import LaplacePrior class TestL0Prior(unittest.TestCase): @@ -63,5 +64,74 @@ def test_log_prob(self): assert log_probs[7] == -torch.inf +class TestL0IidPrior(unittest.TestCase): + def test_sample(self): + N = 10000 + prob_distr = torch.tensor([0.5, 0.25, 0, 0.25]) + iid_prior = LaplacePrior(dim=1, scale=1, positive_only=True) + + torch.manual_seed(1997) + + D = prob_distr.shape[0] + + l0_iid_prior = L0IidPrior(prob_distr, iid_prior) + weights = l0_iid_prior.sample(N) + + assert weights.shape == (N, D) + + # Check uniform distribution over which weights are active. 
+ per_weight_hist = torch.sum(weights, dim=0) # [D] + normalized_per_weight_hist = per_weight_hist / torch.sum(per_weight_hist) # [D] + assert torch.allclose( + normalized_per_weight_hist, + torch.full((D,), 0.25, dtype=torch.float32), + atol=1e-2, + ) + + # Check the distribution over the l0-norm of the weights. + num_active_per_sample = torch.sum(weights != 0., dim=1) # [N] + for num_active in range(1, 5): + assert torch.allclose( + torch.sum(num_active_per_sample == num_active) / N, + prob_distr[num_active - 1], + atol=1e-2, + ) + + # Check Laplacian distribution. + laplace_weights = weights[weights > 0.] + for quantile in torch.arange(5) / 5.: + cutoff = -torch.log(1. - quantile) + assert torch.allclose( + torch.sum(laplace_weights < cutoff) / (N * torch.sum(prob_distr * (1 + torch.arange(D)))), + quantile, + atol=1e-2, + ) + + def test_log_prob(self): + prob_distr = torch.tensor([0.75, 0.25, 0.]) + iid_prior = LaplacePrior(dim=1, scale=1, positive_only=True) + + l0_iid_prior = L0IidPrior(prob_distr, iid_prior) + + samples = list(product([0, 1], repeat=3)) # [2**D, D] + samples = torch.tensor(samples, dtype=torch.float32) # [2**D, D] + + log_probs = l0_iid_prior.log_prob(samples) + + # The l0-norm at index `i` + # is the number of ones + # in the binary representation of `i`. + assert log_probs[0] == -torch.inf + assert torch.allclose( + log_probs[[1, 2, 4]], + torch.log(torch.tensor(0.75)) - 1., + ) + assert torch.allclose( + log_probs[[3, 5, 6]], + torch.log(torch.tensor(0.25)) - 2., + ) + assert log_probs[7] == -torch.inf + + if __name__ == "__main__": unittest.main() From fc0461de74a610b05fbd2833f9cf940d018a2275 Mon Sep 17 00:00:00 2001 From: alvinzz Date: Tue, 7 Jun 2022 22:23:33 -0700 Subject: [PATCH 16/16] support SGD --- sparsecoding/models.py | 83 +++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 25 deletions(-) diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 230631d..c030dfe 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -587,6 +587,7 @@ def learn_bases( self, data: torch.Tensor, n_iter: int = 125, + batch_size: Optional[int] = None, learning_rate: float = 0.01, inference_n_iter: int = 25, inference_learning_rate: float = 0.01, @@ -597,14 +598,17 @@ def learn_bases( In each iteration, we first infer the weights under the current basis functions, and then we update the bases with those weights fixed. - Uses gradient descent with Adam. + Uses (stochastic) gradient descent with Adam. Parameters ---------- data : Tensor, shape [N, D_L] Data to be generated. n_iter : int - Number of iterations of gradient descent to perform. + Number of iterations of (stochastic) gradient descent to perform. + batch_size : optional, int, default=None + If provided, the batch size used during Stochastic Gradient Descent. + Otherwise, performs Gradient Descent using the entire dataset. learning_rate : float Step-size for learning the bases. 
inference_n_iter : int @@ -622,6 +626,10 @@ def learn_bases( """ N = data.shape[0] + do_sgd = (batch_size is None) + if batch_size is None: + batch_size = N + if return_history: bases_history = list(map( lambda basis: [basis.detach().clone()], @@ -631,35 +639,60 @@ def learn_bases( bases_optimizer = torch.optim.Adam(self.bases, lr=learning_rate) top_weights = [ - torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True) + torch.zeros((N, self.dims[i]), dtype=torch.float32) for i in range(self.L - 1) ] - weights_optimizer = torch.optim.Adam(top_weights, lr=inference_learning_rate) - + for _ in tqdm(range(n_iter)): - # Infer weights under the current bases. - bases = list(map(lambda basis: basis.detach(), self.bases)) - for _ in range(inference_n_iter): - log_prob = Hierarchical.log_prob(data, bases, self.priors, top_weights) - weights_optimizer.zero_grad() + if do_sgd: + epoch_idxs = torch.randperm(N) + epoch_data = data.clone() + epoch_data = epoch_data[epoch_idxs] + epoch_top_weights = [weight[epoch_idxs] for weight in top_weights] + else: + epoch_data = data + epoch_top_weights = top_weights + + for batch in range(N // batch_size): + batch_start_idx = batch * batch_size + batch_end_idx = (batch + 1) * batch_size + batch_data = epoch_data[batch_start_idx:batch_end_idx] + + # Infer weights under the current bases. + batch_top_weights = [ + weight[batch_start_idx:batch_end_idx] + for weight + in epoch_top_weights + ] + for weight in batch_top_weights: + weight.requires_grad = True + batch_weights_optimizer = torch.optim.Adam(batch_top_weights, lr=inference_learning_rate) + bases = list(map(lambda basis: basis.detach(), self.bases)) + for _ in range(inference_n_iter): + log_prob = Hierarchical.log_prob(batch_data, bases, self.priors, batch_top_weights) + batch_weights_optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + batch_weights_optimizer.step() + + # Update bases from the current weights. + batch_top_weights = list(map(lambda weight: weight.detach(), batch_top_weights)) + log_prob = Hierarchical.log_prob(batch_data, self.bases, self.priors, batch_top_weights) + bases_optimizer.zero_grad() (-torch.mean(log_prob)).backward() - weights_optimizer.step() - - # Update bases from the current weights. - weights = list(map(lambda weight: weight.detach(), top_weights)) - log_prob = Hierarchical.log_prob(data, self.bases, self.priors, weights) - bases_optimizer.zero_grad() - (-torch.mean(log_prob)).backward() - bases_optimizer.step() + bases_optimizer.step() - # Normalize basis elements (project them back onto the unit sphere). - with torch.no_grad(): - for basis in self.bases: - basis /= torch.norm(basis, dim=1, keepdim=True) + # Normalize basis elements (project them back onto the unit sphere). + with torch.no_grad(): + for basis in self.bases: + basis /= torch.norm(basis, dim=1, keepdim=True) - if return_history: - for (basis_history, basis) in zip(bases_history, self.bases): - basis_history.append(basis.detach().clone()) + if do_sgd: + for (top_weight, epoch_top_weight) in zip(top_weights, epoch_top_weights): + top_weight[epoch_idxs] = epoch_top_weight.detach() + + if return_history: + for (basis_history, basis) in zip(bases_history, self.bases): + basis_history.append(basis.detach().clone()) if return_history: for layer in range(self.L - 1):
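
Usage sketch (not itself part of the patch series): the snippet below ties the new pieces together in the same way the added unit tests do. Import paths follow tests/models/test_Hierarchical.py; the seed, dataset size, batch size, and prior scales are illustrative choices rather than values fixed by these patches.

import torch

from sparsecoding.data.datasets.bars import HierarchicalBarsDataset
from sparsecoding.models import Hierarchical
from sparsecoding.priors.laplace import LaplacePrior

torch.manual_seed(0)

PATCH_SIZE = 8

# One prior per layer (L = 3), mirroring tests/models/test_Hierarchical.py.
priors = [
    LaplacePrior(dim=6, scale=1.0, positive_only=False),
    LaplacePrior(dim=2 * PATCH_SIZE, scale=0.1, positive_only=False),
    LaplacePrior(dim=PATCH_SIZE * PATCH_SIZE, scale=0.01, positive_only=False),
]

# Toy data generated from known bases and weights (bars on an 8x8 patch).
dataset = HierarchicalBarsDataset(
    patch_size=PATCH_SIZE,
    dataset_size=1000,
    priors=priors,
)

model = Hierarchical(priors=priors)

# Learn the bases with mini-batch updates (batch_size=None would recover
# full-batch gradient descent); return_history keeps a snapshot of the
# bases after every iteration.
bases_history = model.learn_bases(
    dataset.data,
    n_iter=125,
    batch_size=100,
    return_history=True,
)
# bases_history[0] has shape [n_iter + 1, 6, 2 * PATCH_SIZE].

# Infer weights for the whole dataset under the learned bases and compare
# their log-likelihood to that of the ground-truth generating weights.
weights = model.infer_weights(dataset.data)
inferred_log_prob = torch.mean(Hierarchical.log_prob(
    dataset.data,
    [basis.detach() for basis in model.bases],
    priors,
    weights[:-1],
))
true_log_prob = torch.mean(Hierarchical.log_prob(
    dataset.data,
    dataset.bases,
    priors,
    dataset.weights[:-1],
))
print(f"inferred: {inferred_log_prob.item():.3f}, ground truth: {true_log_prob.item():.3f}")

# Warm-start a second round of inference from the previous solution.
weights = model.infer_weights(
    dataset.data,
    n_iter=100,
    initial_weights=[w.clone() for w in weights[:-1]],
)

# Render each basis function in data space for visual inspection.
bases_viz = model.inspect_bases()

With these arguments, learn_bases alternates a short warm-started weight-inference phase (inference_n_iter steps) with one basis update per mini-batch, then re-normalizes each basis row onto the unit sphere, as in PATCH 16.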
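
A similar sketch for the L0IidPrior added in PATCH 15. The probability table over the l0-norm and the Laplace scale below are illustrative values, not defaults taken from the code.

import torch

from sparsecoding.priors.l0 import L0IidPrior
from sparsecoding.priors.laplace import LaplacePrior

# Distribution over how many of the D = 4 weights are active:
# P(1 active) = 0.5, P(2) = 0.25, P(3) = 0, P(4) = 0.25.
prob_distr = torch.tensor([0.5, 0.25, 0., 0.25])

# Each active weight's magnitude is drawn i.i.d. from a
# positive-only Laplace distribution (so `dim` must be 1).
prior = L0IidPrior(
    prob_distr=prob_distr,
    active_weight_prior=LaplacePrior(dim=1, scale=1., positive_only=True),
)

weights = prior.sample(8)            # [8, 4], zero except for the active units
log_probs = prior.log_prob(weights)  # [8]
print(weights)
print(log_probs)

Because the active-weight magnitudes are drawn i.i.d., the constructor requires active_weight_prior.D == 1; any one-dimensional Prior can be substituted for the Laplace prior here.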