diff --git a/sparsecoding/datasets.py b/sparsecoding/datasets.py index 13c82a8..c2b58cf 100644 --- a/sparsecoding/datasets.py +++ b/sparsecoding/datasets.py @@ -1,7 +1,7 @@ import torch import os from scipy.io import loadmat -from sparsecoding.transforms import patchify +from sparsecoding.transforms import sample_random_patches from torch.utils.data import Dataset from sparsecoding.priors import Prior @@ -79,10 +79,11 @@ class FieldDataset(Dataset): root : str Location to download the dataset to. patch_size : int - Side length of patches for sparse dictionary learning. - stride : int, optional - Stride for sampling patches. If not specified, set to `patch_size` - (non-overlapping patches). + Side length of patches to extract from images. + num_patches : int + Number of patches to extract from images. + whitened : bool, default=True + Download the whitened or unwhitened dataset. """ B = 10 @@ -93,27 +94,31 @@ class FieldDataset(Dataset): def __init__( self, root: str, + num_patches: int, patch_size: int = 8, - stride: int = None, + whitened: bool = True ): self.P = patch_size - if stride is None: - stride = patch_size root = os.path.expanduser(root) os.system(f"mkdir -p {root}") - if not os.path.exists(f"{root}/field.mat"): - os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat") - os.system(f"mv IMAGES.mat {root}/field.mat") - - self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B] + if whitened: + filename = "IMAGES.mat" + key = "IMAGES" + else: + filename = "IMAGES_RAW.mat" + key = "IMAGESr" + path_to_dataset = os.path.join(root, filename) + if not os.path.exists(path_to_dataset): + os.system(f"wget https://rctn.org/bruno/sparsenet/{filename} -P {root}") + + self.images = torch.from_numpy(loadmat(path_to_dataset)[key].astype(float)) # [H, W, B] assert self.images.shape == (self.H, self.W, self.B) self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W] self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W)) # [B, C, H, W] - self.patches = patchify(patch_size, self.images, stride) # [B, N, C, P, P] - self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P)) # [B*N, C, P, P] + self.patches = sample_random_patches(patch_size, num_patches, self.images) # [N, C, P, P] def __len__(self): return self.patches.shape[0] diff --git a/sparsecoding/test_datasets.py b/sparsecoding/test_datasets.py new file mode 100644 index 0000000..a38f231 --- /dev/null +++ b/sparsecoding/test_datasets.py @@ -0,0 +1,18 @@ +import pytest +from sparsecoding.datasets import FieldDataset + + +@pytest.mark.parametrize("whitened", [True, False]) +def test_FieldDataset( + patch_size_fixture: int, + dataset_size_fixture: int, + whitened: bool, +): + fielddataset = FieldDataset( + root="data", + num_patches=dataset_size_fixture, + patch_size=patch_size_fixture, + whitened=whitened + ) + assert len(fielddataset) == dataset_size_fixture + assert fielddataset.patches.shape == (dataset_size_fixture, 1, patch_size_fixture, patch_size_fixture) diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py index cd48994..fb86e1e 100644 --- a/sparsecoding/transforms/images.py +++ b/sparsecoding/transforms/images.py @@ -33,7 +33,7 @@ def check_images(images: torch.Tensor, algorithm: str = "zca"): ) -def whiten_images(images: torch.Tensor, algorithm: str, stats: Dict = None, **kwargs) -> torch.Tensor: +def whiten_images(images: torch.Tensor, algorithm: str, stats: Optional[Dict] = None, **kwargs) -> torch.Tensor: """ Wrapper for all whitening transformations @@ -335,7 +335,7 @@ def sample_random_patches( def patchify( patch_size: int, image: torch.Tensor, - stride: int = None, + stride: Optional[int] = None, ): """Break an image into square patches. diff --git a/sparsecoding/transforms/whiten.py b/sparsecoding/transforms/whiten.py index c2ad175..8028fce 100644 --- a/sparsecoding/transforms/whiten.py +++ b/sparsecoding/transforms/whiten.py @@ -1,5 +1,5 @@ import torch -from typing import Dict +from typing import Dict, Optional def compute_whitening_stats(X: torch.Tensor): @@ -33,8 +33,8 @@ def compute_whitening_stats(X: torch.Tensor): def whiten( X: torch.Tensor, algorithm: str = "zca", - stats: Dict = None, - n_components: float = None, + stats: Optional[Dict] = None, + n_components: Optional[float] = None, epsilon: float = 0.0, return_W: bool = False, ) -> torch.Tensor: