diff --git a/requirements.txt b/requirements.txt index 7c6c9f2..4ae815a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ numpy matplotlib torch torchvision +tqdm diff --git a/sparsecoding/data/datasets/bars.py b/sparsecoding/data/datasets/bars.py index 16f877d..61e6e5c 100644 --- a/sparsecoding/data/datasets/bars.py +++ b/sparsecoding/data/datasets/bars.py @@ -1,7 +1,12 @@ +import itertools +from typing import List + +import numpy as np import torch from torch.utils.data import Dataset from sparsecoding.priors.common import Prior +from sparsecoding.models import Hierarchical class BarsDataset(Dataset): @@ -47,6 +52,7 @@ def __init__( h_bars = h_bars.expand(self.P, self.P, self.P) v_bars = v_bars.expand(self.P, self.P, self.P) self.basis = torch.cat((h_bars, v_bars), dim=0) # [2*P, P, P] + # self.basis /= np.sqrt(self.P) # Normalize basis. self.weights = prior.sample(self.N) # [N, 2*P] @@ -61,3 +67,84 @@ def __len__(self): def __getitem__(self, idx: int): return self.data[idx] + + +class HierarchicalBarsDataset(Dataset): + """Toy hierarchical dataset of horizontal and vertical bars. + + The L=1 basis functions are horizontal and vertical bars. + + The L=0 basis functions are equal mixtures of two horizontal and vertical bars + on the image border. + + Parameters + ---------- + patch_size : int + Side length for elements of the dataset. + dataset_size : int + Number of dataset elements to generate. + priors : List[Prior] + Prior distributions on the weights, starting from the top-level basis + and going down. + + Attributes + ---------- + bases : List[Tensor], + shapes: + - [6, 2 * patch_size] + - [2 * patch_size, patch_size * patch_size] + Dictionary elements (combinations of horizontal and vertical bars). + weights : List[Tensor], + shapes: + - [dataset_size, 6], + - [dataset_size, 2 * patch_size], + - [dataset_size, patch_size * patch_size]. + Weights for each level of the hierarchy. + data : Tensor, shape [dataset_size, patch_size * patch_size] + Weighted linear combinations of the basis elements. + """ + + def __init__( + self, + patch_size: int, + dataset_size: int, + priors: List[Prior], + ): + self.P = patch_size + self.N = dataset_size + self.priors = priors + + # Specify l1_basis: bars. + one_hots = torch.nn.functional.one_hot(torch.arange(self.P)) # [P, P] + one_hots = one_hots.type(torch.float32) # [P, P] + + h_bars = one_hots.reshape(self.P, self.P, 1) + v_bars = one_hots.reshape(self.P, 1, self.P) + + h_bars = h_bars.expand(self.P, self.P, self.P) + v_bars = v_bars.expand(self.P, self.P, self.P) + l2_basis = torch.cat((h_bars, v_bars), dim=0) # [2*P, P, P] + l2_basis /= np.sqrt(self.P) # Normalize basis. + l2_basis = l2_basis.reshape((2 * self.P, self.P * self.P)) + + # Specify l0_basis: combinations of two bars on the border. + border_bar_idxs = [0, self.P - 1, self.P, 2 * self.P - 1] + l1_basis_idxs = torch.tensor(list(itertools.combinations(border_bar_idxs, 2))) + l1_basis = torch.zeros((6, 2 * self.P), dtype=torch.float32) + l1_basis[torch.arange(6), l1_basis_idxs[:, 0]] = 1. / np.sqrt(2.) + l1_basis[torch.arange(6), l1_basis_idxs[:, 1]] = 1. / np.sqrt(2.) 
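+        # `itertools.combinations(border_bar_idxs, 2)` yields the 6 unordered pairs of the
+        # 4 border bars (first/last rows and first/last columns), so each of these basis
+        # functions is an equal-weight, unit-norm (1/sqrt(2)) mixture of two border bars.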
+ + self.bases = [l1_basis, l2_basis] + + self.weights = list(map( + lambda prior: prior.sample(self.N), + self.priors, + )) + + self.data = Hierarchical.generate(self.bases, self.weights) + + def __len__(self): + return self.N + + def __getitem__(self, idx: int): + return self.data[idx] diff --git a/sparsecoding/models.py b/sparsecoding/models.py index 38bfb19..c030dfe 100644 --- a/sparsecoding/models.py +++ b/sparsecoding/models.py @@ -1,7 +1,12 @@ +from typing import List, Optional + import numpy as np +import pickle as pkl import torch from torch.utils.data import DataLoader -import pickle as pkl +from tqdm import tqdm + +from sparsecoding.priors.common import Prior class SparseCoding(torch.nn.Module): @@ -213,6 +218,601 @@ def save_dictionary(self, filename): filehandler.close() +class Hierarchical(torch.nn.Module): + """Class for hierarchical sparse coding. + + Layer x_{n+1} is recursively defined as: + x_{n+1} := Phi_n x_n + a_n, + where: + Phi_n is a basis set (with unit norm), + a_n has a sparse prior. + + The `a_n`s can be thought of the errors or residuals in a predictive coding + model, or the sparse weights at each layer in a generative model. + + Parameters + ---------- + priors : List[Prior] + Prior on weights for each layer. + """ + + def __init__( + self, + priors: List[Prior], + ): + self.priors = priors + + self.dims = [prior.D for prior in priors] + + self.bases = [ + torch.normal( + mean=torch.zeros((self.dims[n], self.dims[n + 1]), dtype=torch.float32), + std=torch.ones((self.dims[n], self.dims[n + 1]), dtype=torch.float32), + ) + for n in range(self.L - 1) + ] + # Normalize / project onto unit sphere + self.bases = list(map( + lambda basis: basis / torch.norm(basis, dim=1, keepdim=True), + self.bases + )) + for basis in self.bases: + basis.requires_grad = True + + @property + def L(self): + """Number of layers in the generative model. + """ + return len(self.dims) + + def generate( + bases: List[torch.Tensor], + weights: List[torch.Tensor], + ): + """Run the generative model forward. + + Parameters + ---------- + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + weights : List[Tensor], length L, shape [N, D_i] + Weights at each layer. + + Returns + ------- + data : Tensor, shape [N, D_L] + Generated data from the given weights and bases. + """ + Hierarchical._check_bases_weights(bases, weights) + + x_i = weights[0] + for (basis, weight) in zip(bases, weights[1:]): + x_i = torch.einsum( + "ni,ij->nj", + x_i, + basis, + ) + weight + + return x_i + + def sample( + n_samples: int, + priors: List[Prior], + bases: List[torch.Tensor], + ): + """Sample from the generative model. + + Parameters + ---------- + n_samples : int + Number of samples to generate. + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + priors : List[Prior], length L + Priors for the weights at each layer. + + Returns + ------- + data : Tensor, shape [N, D_L] + Sampled data using the given priors and bases. + """ + if n_samples < 0: + raise ValueError(f"`n_samples` must be non-negative, got {n_samples}.") + Hierarchical._check_bases_priors(bases, priors) + + weights = list(map( + lambda prior: prior.sample(n_samples), + priors, + )) + return Hierarchical.generate(bases, weights) + + def log_prob( + data: torch.Tensor, + bases: List[torch.Tensor], + priors: List[Prior], + weights: List[torch.Tensor], + ): + """Compute the log-probability of the `data` under the generative model. 
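+
+        The returned value is the sum of the per-layer prior log-probabilities,
+        `sum_n priors[n].log_prob(weights_n)`, where the bottom-layer weights are
+        not passed in explicitly but are taken to be the residual between `data`
+        and the prediction generated from the layers above.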
+ + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to get the log-probability of. + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + priors: List[Prior], length L + Priors on the weights at each layer. + weights : List[Tensor], length L - 1, shape [N, D_i] + Weights for the basis functions at each layer, + EXCEPT for the bottom layer, where the weights are + implicitly defined as the difference between the data and the + generated predictions from the previous layers. + + Returns + ------- + log_prob : Tensor, shape [N] + Log-probabilities of the data under the generative model. + """ + Hierarchical._check_bases_priors(bases, priors) + Hierarchical._check_bases_weights( + bases, + # Need to add dummy weights since last layer weights + # are not specified in the input. + weights + [torch.zeros((weights[0].shape[0], bases[-1].shape[1]))] + ) + + # First layer, no basis + x_i = weights[0] + log_prob = priors[0].log_prob(weights[0]) + + # Middle layers + for (prior, basis, weight) in zip(priors[1:-1], bases[:-1], weights[1:]): + x_i = torch.einsum( + "ni,ij->nj", + x_i, # [N, D_i] + basis, # [D_i, D_{i+1}] + ) + weight + log_prob = log_prob + prior.log_prob(weight) + + # Last layer, implicit weights calculated from the data + x_l = torch.einsum( + "ni,ij->nj", + x_i, # [N, D_i] + bases[-1], # [D_i, D_{i+1}] + ) + weight_l = data - x_l + log_prob = log_prob + priors[-1].log_prob(weight_l) + + return log_prob + + def infer_weights( + self, + data: torch.Tensor, + n_iter: int = 1000, + learning_rate: float = 0.1, + return_history: bool = False, + initial_weights: Optional[List[torch.Tensor]] = None, + ): + """Infer weights for the input `data` to maximize the log-likelihood. + + Performs gradient descent with Adam. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + n_iter : int + Number of iterations of gradient descent to perform. + learning_rate : float + Learning rate for the optimizer. + return_history : bool + Flag to return the history of the inferred weights during inference. + initial_weights : optional, List[Tensor], length L - 1, shape [N, D_i] + If provided, the initial weights to start inference from. + Otherwise, weights are set to 0. + + Returns + ------- + weights : List[Tensor], length L, shape [N, D_i] + Inferred weights for each layer. + weights_history : optional, List[Tensor], length L, shape [n_iter + 1, N, D_i] + Returned if `return_history`. The inferred weights for each layer + throughout inference. 
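+
+        Examples
+        --------
+        A minimal usage sketch (assumes a constructed `Hierarchical` model and
+        a `data` Tensor shaped [N, D_L] to match its priors):
+
+        >>> weights = model.infer_weights(data, n_iter=1000, learning_rate=0.1)
+        >>> weights, history = model.infer_weights(data, return_history=True)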
+        """
+        N = data.shape[0]
+
+        if initial_weights is None:
+            top_weights = [
+                torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True)
+                for i in range(self.L - 1)
+            ]
+        else:
+            top_weights = initial_weights
+            for weight in top_weights:
+                weight.requires_grad = True
+
+        bases = list(map(lambda basis: basis.detach(), self.bases))
+
+        if return_history:
+            with torch.no_grad():
+                bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights)
+            weights = top_weights + [bottom_weights]
+            weights_history = list(map(
+                lambda weight: [weight.detach().clone()],
+                weights,
+            ))
+
+        optimizer = torch.optim.Adam(top_weights, lr=learning_rate)
+        for _ in range(n_iter):
+            log_prob = Hierarchical.log_prob(data, bases, self.priors, top_weights)
+            optimizer.zero_grad()
+            (-torch.mean(log_prob)).backward()
+            optimizer.step()
+
+            if return_history:
+                with torch.no_grad():
+                    bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights)
+                weights = top_weights + [bottom_weights]
+                for (weight_history, weight) in zip(weights_history, weights):
+                    weight_history.append(weight.detach().clone())
+
+        top_weights = list(map(lambda weight: weight.detach(), top_weights))
+        bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights)
+        weights = top_weights + [bottom_weights]
+
+        if not return_history:
+            return weights
+        else:
+            for layer in range(self.L):
+                weights_history[layer] = torch.stack(weights_history[layer], dim=0)
+            return weights, weights_history
+
+    def infer_weights_local(
+        self,
+        data: torch.Tensor,
+        n_iter: int = 100000,
+        learning_rate: float = 0.0001,
+        return_history_interval: Optional[int] = None,
+        initial_weights: Optional[List[torch.Tensor]] = None,
+    ):
+        """Infer weights for the input `data` to maximize the log-likelihood.
+        Unlike `infer_weights`, information flow is constrained to be between
+        adjacent layers.
+
+        Performs gradient descent with Adam.
+
+        Parameters
+        ----------
+        data : Tensor, shape [N, D_L]
+            Data to be generated.
+        n_iter : int
+            Number of iterations of gradient descent to perform.
+        learning_rate : float
+            Learning rate for the optimizer.
+        return_history_interval : optional, int
+            If set, the inferred weights are saved every
+            `return_history_interval` iterations during inference.
+        initial_weights : optional, List[Tensor], length L - 1, shape [N, D_i]
+            If provided, the initial weights to start inference from.
+            Otherwise, weights are set to 0.
+
+        Returns
+        -------
+        weights : List[Tensor], length L, shape [N, D_i]
+            Inferred weights for each layer.
+        weights_history : optional, List[Tensor], length L, shape [n_saved, N, D_i]
+            Returned if `return_history_interval` is set. The inferred weights
+            for each layer throughout inference, saved every
+            `return_history_interval` iterations.
+        """
+        N = data.shape[0]
+
+        if initial_weights is None:
+            top_weights = [
+                torch.zeros((N, self.dims[i]), dtype=torch.float32, requires_grad=True)
+                for i in range(self.L - 1)
+            ]
+        else:
+            top_weights = initial_weights
+            for weight in top_weights:
+                weight.requires_grad = True
+
+        bases = list(map(lambda basis: basis.detach(), self.bases))
+
+        if return_history_interval:
+            with torch.no_grad():
+                bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights)
+            weights = top_weights + [bottom_weights]
+            weights_history = [[weight.detach().clone()] for weight in weights]
+
+        optimizer = torch.optim.Adam(top_weights, lr=learning_rate)
+        for it in range(n_iter):
+            # Generate data under current weights (with no gradient) to get targets.
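+            # xs_ng[i] holds the layer-i activity computed without gradients, so each
+            # entry of `top_weights` only receives gradient through its own prior term
+            # and the prior of the layer directly below it.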
+ xs_ng = [] + with torch.no_grad(): + xs_ng.append(top_weights[0].detach()) + for (basis, weight) in zip(bases[:-1], top_weights[1:]): + xs_ng.append(xs_ng[-1] @ basis + weight.detach()) + xs_ng.append(data) + + # Get log-probability for Layer 1. + weight_1 = top_weights[0] + x_ng_below = xs_ng[1] + basis_to_below = bases[0] + prior = self.priors[0] + prior_below = self.priors[1] + log_prob = ( + prior.log_prob(weight_1) + + prior_below.log_prob(x_ng_below - weight_1 @ basis_to_below) + ) + + # Get log-probabilities for Layers 2 through L. + for layer in range(2, self.L): + weight = top_weights[layer - 1] + + x_ng_above = xs_ng[layer - 2] + basis_from_above = bases[layer - 2] + + x_ng_below = xs_ng[layer] + basis_to_below = bases[layer - 1] + + prior = self.priors[layer - 1] + prior_below = self.priors[layer] + log_prob += ( + prior.log_prob(weight) + + prior_below.log_prob(x_ng_below - (x_ng_above @ basis_from_above + weight) @ basis_to_below) + ) + + optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + optimizer.step() + + if return_history_interval and it % return_history_interval == 0: + with torch.no_grad(): + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + for (weight_history, weight) in zip(weights_history, weights): + weight_history.append(weight.detach().clone()) + + top_weights = list(map(lambda weight: weight.detach(), top_weights)) + bottom_weights = Hierarchical._compute_bottom_weights(data, bases, top_weights) + weights = top_weights + [bottom_weights] + + if not return_history_interval: + return weights + else: + for layer in range(self.L): + weights_history[layer] = torch.stack(weights_history[layer], dim=0) + return weights, weights_history + + def learn_bases( + self, + data: torch.Tensor, + n_iter: int = 125, + batch_size: Optional[int] = None, + learning_rate: float = 0.01, + inference_n_iter: int = 25, + inference_learning_rate: float = 0.01, + return_history: bool = False, + ): + """Update the bases to maximize the log-likelihood of `data`. + + In each iteration, we first infer the weights under the current basis functions, + and then we update the bases with those weights fixed. + + Uses (stochastic) gradient descent with Adam. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + n_iter : int + Number of iterations of (stochastic) gradient descent to perform. + batch_size : optional, int, default=None + If provided, the batch size used during Stochastic Gradient Descent. + Otherwise, performs Gradient Descent using the entire dataset. + learning_rate : float + Step-size for learning the bases. + inference_n_iter : int + Number of iterations of gradient descent + to perform during weight inference. + inference_learning_rate : float + Step-size for inferring the weights. + return_history : bool + Flag to return the history of the learned bases during inference. + + Returns + ------- + bases_history : optional, List[Tensor], length L - 1, shape [n_iter + 1, D_i, D_{i+1}] + Returned if `return_history`. The learned bases throughout training. 
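+
+        Examples
+        --------
+        A minimal usage sketch (assumes a constructed `Hierarchical` model and
+        a `data` Tensor shaped [N, D_L]; the batch size is illustrative):
+
+        >>> model.learn_bases(data, n_iter=125, batch_size=100)
+        >>> bases_history = model.learn_bases(data, return_history=True)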
+ """ + N = data.shape[0] + + do_sgd = (batch_size is None) + if batch_size is None: + batch_size = N + + if return_history: + bases_history = list(map( + lambda basis: [basis.detach().clone()], + self.bases, + )) + + bases_optimizer = torch.optim.Adam(self.bases, lr=learning_rate) + + top_weights = [ + torch.zeros((N, self.dims[i]), dtype=torch.float32) + for i in range(self.L - 1) + ] + + for _ in tqdm(range(n_iter)): + if do_sgd: + epoch_idxs = torch.randperm(N) + epoch_data = data.clone() + epoch_data = epoch_data[epoch_idxs] + epoch_top_weights = [weight[epoch_idxs] for weight in top_weights] + else: + epoch_data = data + epoch_top_weights = top_weights + + for batch in range(N // batch_size): + batch_start_idx = batch * batch_size + batch_end_idx = (batch + 1) * batch_size + batch_data = epoch_data[batch_start_idx:batch_end_idx] + + # Infer weights under the current bases. + batch_top_weights = [ + weight[batch_start_idx:batch_end_idx] + for weight + in epoch_top_weights + ] + for weight in batch_top_weights: + weight.requires_grad = True + batch_weights_optimizer = torch.optim.Adam(batch_top_weights, lr=inference_learning_rate) + bases = list(map(lambda basis: basis.detach(), self.bases)) + for _ in range(inference_n_iter): + log_prob = Hierarchical.log_prob(batch_data, bases, self.priors, batch_top_weights) + batch_weights_optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + batch_weights_optimizer.step() + + # Update bases from the current weights. + batch_top_weights = list(map(lambda weight: weight.detach(), batch_top_weights)) + log_prob = Hierarchical.log_prob(batch_data, self.bases, self.priors, batch_top_weights) + bases_optimizer.zero_grad() + (-torch.mean(log_prob)).backward() + bases_optimizer.step() + + # Normalize basis elements (project them back onto the unit sphere). + with torch.no_grad(): + for basis in self.bases: + basis /= torch.norm(basis, dim=1, keepdim=True) + + if do_sgd: + for (top_weight, epoch_top_weight) in zip(top_weights, epoch_top_weights): + top_weight[epoch_idxs] = epoch_top_weight.detach() + + if return_history: + for (basis_history, basis) in zip(bases_history, self.bases): + basis_history.append(basis.detach().clone()) + + if return_history: + for layer in range(self.L - 1): + bases_history[layer] = torch.stack(bases_history[layer], dim=0) + return bases_history + + def inspect_bases(self): + """Runs the generative model forward for each basis element. + + This allows visual inspection of what each basis function represents + at the final (bottom) layer. + + Returns + ------- + bases_viz : List[List[Tensor]], shape [D_L] + Visualizations of the basis functions at each layer. + """ + bases_viz = [] + for layer in range(self.L - 1): + layer_bases_viz = [] + for basis_fn in range(self.bases[layer].shape[0]): + weights = [torch.zeros((1, dim)) for dim in self.dims] + weights[layer][0, basis_fn] = 1. + layer_bases_viz.append(Hierarchical.generate(self.bases, weights)[0]) + bases_viz.append(layer_bases_viz) + return bases_viz + + def _check_bases_weights(bases, weights): + """Check bases and weights for shape compatibility. + """ + if len(weights) != len(bases) + 1: + raise ValueError( + f"Must have exactly one more weight than basis " + f"(`L` layers and `L-1` bases to transform between them), " + f"got {len(weights)} weights and {len(bases)} bases." + ) + if not all([ + weights[i].shape[0] == weights[0].shape[0] + for i in range(1, len(weights)) + ]): + raise ValueError( + "Weight tensors must all have the same size in the 0-th dimension." 
+ "This is the size of the data to generate." + ) + for (layer, (basis_i, basis_j)) in enumerate(zip(bases[:-1], bases[1:])): + if basis_i.shape[1] != basis_j.shape[0]: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"produces weights of dimension {basis_i.shape[1]} " + f"for layer {layer+1} but " + f"basis between layer {layer+1} and layer {layer+2} " + f"expects {basis_j.shape[0]} weights for layer {layer+1}." + ) + for (layer, (basis, weight)) in enumerate(zip(bases, weights[:-1])): + if basis.shape[0] != weight.shape[1]: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"expects {basis.shape[0]} weights for {layer}, " + f"but {weight.shape[1]} weights " + f"are provided for layer {layer}." + ) + if bases[-1].shape[1] != weights[-1].shape[1]: + raise ValueError( + f"The final basis outputs data with dimension {bases[-1].shape[1]}, " + f"but final layer weights have dimension {weights[-1].shape[1]}." + ) + + def _check_bases_priors(bases, priors): + """Check bases and priors for shape compatibility. + """ + if len(priors) != len(bases) + 1: + raise ValueError( + f"Must have exactly one more prior than basis " + f"(`L` layers and `L-1` bases to transform between them), " + f"got {len(priors)} priors and {len(bases)} bases." + ) + for (layer, (basis_i, basis_j)) in enumerate(zip(bases[:-1], bases[1:])): + if basis_i.shape[1] != basis_j.shape[0]: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"produces priors of dimension {basis_i.shape[1]} " + f"for layer {layer+1} but " + f"basis between layer {layer+1} and layer {layer+2} " + f"expects {basis_j.shape[0]} priors for layer {layer+1}." + ) + for (layer, (basis, prior)) in enumerate(zip(bases, priors[:-1])): + if basis.shape[0] != prior.D: + raise ValueError( + f"Basis between layer {layer} and layer {layer+1} " + f"expects {basis.shape[0]} weights for {layer}, " + f"but the prior for layer {layer} is over {prior.D} weights." + ) + if bases[-1].shape[1] != priors[-1].D: + raise ValueError( + f"The final basis outputs data with dimension {bases[-1].shape[1]}, " + f"but final layer prior is over {priors[-1].D} weights." + ) + + def _compute_bottom_weights( + data: torch.Tensor, + bases: List[torch.Tensor], + top_weights: List[torch.Tensor], + ): + """Compute the bottom-layer weights for `data`, given weights for all the other layers. + + Parameters + ---------- + data : Tensor, shape [N, D_L] + Data to be generated. + bases : List[Tensor], length L - 1, shape [D_i, D_{i+1}] + Basis functions to transform between layers. + weights : List[Tensor], length L - 1, shape [N, D_i] + Weights at the top `L - 1` layers. + """ + weights = top_weights + [torch.zeros_like(data)] + Hierarchical._check_bases_weights(bases, weights) + bottom_weights = data - Hierarchical.generate(bases, weights) + return bottom_weights + + class SimulSparseCoding(SparseCoding): def __init__(self, inference_method, n_basis, n_features, sparsity_penalty, inf_rate=1, learn_rate=1, time_step=1, t_max=1000, diff --git a/sparsecoding/priors/l0.py b/sparsecoding/priors/l0.py index 18bf698..83a770c 100644 --- a/sparsecoding/priors/l0.py +++ b/sparsecoding/priors/l0.py @@ -77,4 +77,94 @@ def log_prob( # TODO: Add L0ExpPrior, where the number of active units is distributed exponentially. -# TODO: Add L0IidPrior, where the magnitude of an active unit is distributed according to an i.i.d. Prior. +class L0IidPrior(Prior): + """L0-sparse Prior with non-binary weights. 
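+
+    To draw a sample, the number of active weights is first drawn from
+    `prob_distr` (which places mass on 1 through D active weights), the active
+    positions are then chosen uniformly at random, and each active value is
+    drawn i.i.d. from `active_weight_prior`.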
+ + If a weight is active, its value is drawn from an i.i.d. Prior. + + Parameters + ---------- + prob_distr : Tensor, shape [D], dtype float32 + Probability distribution over the l0-norm of the weights. + active_weight_prior : Prior + The distribution for active weights. + Since weights drawn from this distribution must be i.i.d., + we enforce `active_weight_prior.D` to be 1. + """ + + def __init__( + self, + prob_distr: torch.Tensor, + active_weight_prior: Prior, + ): + if prob_distr.dim() != 1: + raise ValueError(f"`prob_distr` shape must be (D,), got {prob_distr.shape}.") + if prob_distr.dtype != torch.float32: + raise ValueError(f"`prob_distr` dtype must be torch.float32, got {prob_distr.dtype}.") + if not torch.allclose(torch.sum(prob_distr), torch.ones(1, dtype=torch.float32)): + raise ValueError(f"`torch.sum(prob_distr)` must be 1., got {torch.sum(prob_distr)}.") + if active_weight_prior.D != 1: + raise ValueError( + f"`active_weight_prior.D` must be 1 (got {active_weight_prior.D}). " + f"This enforces that can sample i.i.d. weights." + ) + + self.prob_distr = prob_distr + self.active_weight_prior = active_weight_prior + + @property + def D(self): + return self.prob_distr.shape[0] + + def sample( + self, + num_samples: int, + ): + N = num_samples + + num_active_weights = 1 + torch.multinomial( + input=self.prob_distr, + num_samples=num_samples, + replacement=True, + ) # [N] + + d_idxs = torch.arange(self.D) + active_idx_mask = ( + d_idxs.reshape(1, self.D) + < num_active_weights.reshape(N, 1) + ) # [N, self.D] + + n_idxs = torch.arange(N).reshape(N, 1).expand(N, self.D) # [N, D] + # Need to shuffle here so that it's not always the first weights that are active. + shuffled_d_idxs = [torch.randperm(self.D) for _ in range(N)] + shuffled_d_idxs = torch.stack(shuffled_d_idxs, dim=0) # [N, D] + + # [num_active_weights], [num_active_weights] + active_weight_idxs = n_idxs[active_idx_mask], shuffled_d_idxs[active_idx_mask] + + weights = torch.zeros((N, self.D), dtype=torch.float32) + n_active_idxs = int(torch.sum(active_idx_mask).cpu().numpy()) + active_weight_values = self.active_weight_prior.sample(n_active_idxs) + weights[active_weight_idxs] += active_weight_values.reshape(-1) + + return weights + + def log_prob( + self, + sample: torch.Tensor, + ): + super().check_sample_input(sample) + + active_weight_mask = (sample != 0.) + + l0_norm = torch.sum(active_weight_mask, dim=1).type(torch.long) # [num_samples] + log_prob = torch.log(self.prob_distr[l0_norm - 1]) + log_prob[l0_norm == 0] = -torch.inf + + active_log_prob = ( + self.active_weight_prior.log_prob(sample.reshape(-1, 1)).reshape(sample.shape) + ) # [num_samples, D] + active_log_prob[~active_weight_mask] = 0. 
# [num_samples, D] + log_prob += torch.sum(active_log_prob, dim=1) # [num_samples] + + return log_prob diff --git a/tests/models/test_Hierarchical.py b/tests/models/test_Hierarchical.py new file mode 100644 index 0000000..80a919a --- /dev/null +++ b/tests/models/test_Hierarchical.py @@ -0,0 +1,89 @@ +import torch +import unittest + +from sparsecoding.data.datasets.bars import HierarchicalBarsDataset +from sparsecoding.models import Hierarchical +from sparsecoding.priors.laplace import LaplacePrior +from tests.testing_utilities import TestCase + +torch.manual_seed(1997) + +PATCH_SIZE = 8 +DATASET_SIZE = 1000 + +PRIORS = [ + LaplacePrior( + dim=6, + scale=1.0, + positive_only=False, + ), + LaplacePrior( + dim=2 * PATCH_SIZE, + scale=0.1, + positive_only=False, + ), + LaplacePrior( + dim=PATCH_SIZE * PATCH_SIZE, + scale=0.01, + positive_only=False, + ), +] + +DATASET = HierarchicalBarsDataset( + patch_size=PATCH_SIZE, + dataset_size=DATASET_SIZE, + priors=PRIORS, +) + + +class TestHierarchical(TestCase): + def test_infer_weights(self): + """ + Test that Hierarchical inference recovers the correct weights. + """ + model = Hierarchical(priors=PRIORS) + model.bases = DATASET.bases + + weights = model.infer_weights(DATASET.data) + + inferred_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + DATASET.bases, + PRIORS, + weights[:-1], + )) + dataset_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + DATASET.bases, + PRIORS, + DATASET.weights[:-1], + )) + self.assertAllClose(inferred_log_probs, dataset_log_probs, atol=5e-2) + + def test_learn_bases(self): + """ + Test that Hierarchical inference recovers the correct bases. + """ + model = Hierarchical(priors=PRIORS) + + model.learn_bases(DATASET.data) + + weights = model.infer_weights(DATASET.data) + + inferred_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + list(map(lambda basis: basis.detach(), model.bases)), + PRIORS, + weights[:-1], + )) + dataset_log_probs = torch.mean(Hierarchical.log_prob( + DATASET.data, + DATASET.bases, + PRIORS, + DATASET.weights[:-1], + )) + self.assertAllClose(inferred_log_probs, dataset_log_probs, atol=3e0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/priors/test_l0.py b/tests/priors/test_l0.py index 323d12e..be9e68c 100644 --- a/tests/priors/test_l0.py +++ b/tests/priors/test_l0.py @@ -3,7 +3,8 @@ import torch import unittest -from sparsecoding.priors.l0 import L0Prior +from sparsecoding.priors.l0 import L0Prior, L0IidPrior +from sparsecoding.priors.laplace import LaplacePrior class TestL0Prior(unittest.TestCase): @@ -63,5 +64,74 @@ def test_log_prob(self): assert log_probs[7] == -torch.inf +class TestL0IidPrior(unittest.TestCase): + def test_sample(self): + N = 10000 + prob_distr = torch.tensor([0.5, 0.25, 0, 0.25]) + iid_prior = LaplacePrior(dim=1, scale=1, positive_only=True) + + torch.manual_seed(1997) + + D = prob_distr.shape[0] + + l0_iid_prior = L0IidPrior(prob_distr, iid_prior) + weights = l0_iid_prior.sample(N) + + assert weights.shape == (N, D) + + # Check uniform distribution over which weights are active. + per_weight_hist = torch.sum(weights, dim=0) # [D] + normalized_per_weight_hist = per_weight_hist / torch.sum(per_weight_hist) # [D] + assert torch.allclose( + normalized_per_weight_hist, + torch.full((D,), 0.25, dtype=torch.float32), + atol=1e-2, + ) + + # Check the distribution over the l0-norm of the weights. 
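+        # Since `num_active = 1 + multinomial(prob_distr)`, the empirical frequency
+        # of k active weights should match prob_distr[k - 1] up to sampling error.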
+ num_active_per_sample = torch.sum(weights != 0., dim=1) # [N] + for num_active in range(1, 5): + assert torch.allclose( + torch.sum(num_active_per_sample == num_active) / N, + prob_distr[num_active - 1], + atol=1e-2, + ) + + # Check Laplacian distribution. + laplace_weights = weights[weights > 0.] + for quantile in torch.arange(5) / 5.: + cutoff = -torch.log(1. - quantile) + assert torch.allclose( + torch.sum(laplace_weights < cutoff) / (N * torch.sum(prob_distr * (1 + torch.arange(D)))), + quantile, + atol=1e-2, + ) + + def test_log_prob(self): + prob_distr = torch.tensor([0.75, 0.25, 0.]) + iid_prior = LaplacePrior(dim=1, scale=1, positive_only=True) + + l0_iid_prior = L0IidPrior(prob_distr, iid_prior) + + samples = list(product([0, 1], repeat=3)) # [2**D, D] + samples = torch.tensor(samples, dtype=torch.float32) # [2**D, D] + + log_probs = l0_iid_prior.log_prob(samples) + + # The l0-norm at index `i` + # is the number of ones + # in the binary representation of `i`. + assert log_probs[0] == -torch.inf + assert torch.allclose( + log_probs[[1, 2, 4]], + torch.log(torch.tensor(0.75)) - 1., + ) + assert torch.allclose( + log_probs[[3, 5, 6]], + torch.log(torch.tensor(0.25)) - 2., + ) + assert log_probs[7] == -torch.inf + + if __name__ == "__main__": unittest.main()