From 51a99cbfbb1c598abb9b9eb2b3a1895af03dc55a Mon Sep 17 00:00:00 2001
From: Alexander Belsten
Date: Fri, 28 Feb 2025 14:29:51 -0800
Subject: [PATCH 1/6] sample random images

---
 sparsecoding/datasets.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/sparsecoding/datasets.py b/sparsecoding/datasets.py
index 13c82a8..c2489c9 100644
--- a/sparsecoding/datasets.py
+++ b/sparsecoding/datasets.py
@@ -1,7 +1,7 @@
 import torch
 import os
 from scipy.io import loadmat
-from sparsecoding.transforms import patchify
+from sparsecoding.transforms import sample_random_patches
 from torch.utils.data import Dataset
 
 from sparsecoding.priors import Prior
@@ -93,12 +93,10 @@ class FieldDataset(Dataset):
     def __init__(
         self,
         root: str,
+        num_patches: int,
         patch_size: int = 8,
-        stride: int = None,
     ):
         self.P = patch_size
-        if stride is None:
-            stride = patch_size
 
         root = os.path.expanduser(root)
         os.system(f"mkdir -p {root}")
@@ -112,8 +110,7 @@ def __init__(
         self.images = torch.permute(self.images, (2, 0, 1))  # [B, H, W]
         self.images = torch.reshape(self.images, (self.B, self.C, self.H, self.W))  # [B, C, H, W]
 
-        self.patches = patchify(patch_size, self.images, stride)  # [B, N, C, P, P]
-        self.patches = torch.reshape(self.patches, (-1, self.C, self.P, self.P))  # [B*N, C, P, P]
+        self.patches = sample_random_patches(patch_size, num_patches, self.images)  # [N, C, P, P]
 
     def __len__(self):
         return self.patches.shape[0]

From 99ddd541ef99593de352e46a199c00f8601c7750 Mon Sep 17 00:00:00 2001
From: Alexander Belsten
Date: Fri, 28 Feb 2025 14:31:17 -0800
Subject: [PATCH 2/6] Use Optional typing

---
 sparsecoding/transforms/images.py | 4 ++--
 sparsecoding/transforms/whiten.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py
index cd48994..fb86e1e 100644
--- a/sparsecoding/transforms/images.py
+++ b/sparsecoding/transforms/images.py
@@ -33,7 +33,7 @@ def check_images(images: torch.Tensor, algorithm: str = "zca"):
         )
 
 
-def whiten_images(images: torch.Tensor, algorithm: str, stats: Dict = None, **kwargs) -> torch.Tensor:
+def whiten_images(images: torch.Tensor, algorithm: str, stats: Optional[Dict] = None, **kwargs) -> torch.Tensor:
     """
     Wrapper for all whitening transformations
 
@@ -335,7 +335,7 @@ def sample_random_patches(
 def patchify(
     patch_size: int,
     image: torch.Tensor,
-    stride: int = None,
+    stride: Optional[int] = None,
 ):
     """Break an image into square patches.
 
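Usage sketch (illustrative, not part of the patch series): after PATCH 1, FieldDataset no longer tiles each image into a strided grid of patches; it draws num_patches patches at random via sample_random_patches. The root path and patch count below are made up for illustration, and the channel dimension of 1 is taken from the shape asserted by the test added in PATCH 4.

from sparsecoding.datasets import FieldDataset

# Hypothetical example: sample 10,000 random 8x8 patches from the Field
# natural-image set, downloading it to "~/data/field" if not already present.
dataset = FieldDataset(
    root="~/data/field",
    num_patches=10000,
    patch_size=8,
)
print(len(dataset))           # 10000
print(dataset.patches.shape)  # expected: [10000, 1, 8, 8], i.e. [N, C, P, P]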
diff --git a/sparsecoding/transforms/whiten.py b/sparsecoding/transforms/whiten.py
index c2ad175..e14744f 100644
--- a/sparsecoding/transforms/whiten.py
+++ b/sparsecoding/transforms/whiten.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Dict
+from typing import Dict, Optional
 
 
 def compute_whitening_stats(X: torch.Tensor):
@@ -33,8 +33,8 @@ def compute_whitening_stats(X: torch.Tensor):
 def whiten(
     X: torch.Tensor,
     algorithm: str = "zca",
-    stats: Dict = None,
-    n_components: float = None,
+    stats: Optional[Dict]= None,
+    n_components: Optional[float] = None,
     epsilon: float = 0.0,
     return_W: bool = False,
 ) -> torch.Tensor:

From 4a40d63c61e0b9269e65e006aebe1069a3e26130 Mon Sep 17 00:00:00 2001
From: Alexander Belsten
Date: Fri, 28 Feb 2025 14:32:33 -0800
Subject: [PATCH 3/6] lint

---
 sparsecoding/transforms/whiten.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sparsecoding/transforms/whiten.py b/sparsecoding/transforms/whiten.py
index e14744f..8028fce 100644
--- a/sparsecoding/transforms/whiten.py
+++ b/sparsecoding/transforms/whiten.py
@@ -33,7 +33,7 @@ def compute_whitening_stats(X: torch.Tensor):
 def whiten(
     X: torch.Tensor,
     algorithm: str = "zca",
-    stats: Optional[Dict]= None,
+    stats: Optional[Dict] = None,
     n_components: Optional[float] = None,
     epsilon: float = 0.0,
     return_W: bool = False,

From cdef9caed4f131fdd3a8f46b402502252ea9cce2 Mon Sep 17 00:00:00 2001
From: Alexander Belsten
Date: Fri, 28 Feb 2025 14:57:11 -0800
Subject: [PATCH 4/6] test FieldDataset

---
 sparsecoding/test_datasets.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 sparsecoding/test_datasets.py

diff --git a/sparsecoding/test_datasets.py b/sparsecoding/test_datasets.py
new file mode 100644
index 0000000..63fb7dd
--- /dev/null
+++ b/sparsecoding/test_datasets.py
@@ -0,0 +1,14 @@
+from sparsecoding.datasets import FieldDataset
+
+
+def test_FieldDataset(
+    patch_size_fixture: int,
+    dataset_size_fixture: int,
+):
+    fielddataset = FieldDataset(
+        root="data/",
+        num_patches=dataset_size_fixture,
+        patch_size=patch_size_fixture,
+    )
+    assert len(fielddataset) == dataset_size_fixture
+    assert fielddataset.patches.shape == (dataset_size_fixture, 1, patch_size_fixture, patch_size_fixture)

From b6416b51c83d823b43adf8577581faad18f0144a Mon Sep 17 00:00:00 2001
From: Alexander Belsten
Date: Fri, 28 Feb 2025 15:23:32 -0800
Subject: [PATCH 5/6] Functionality to download unwhitened images. Update inline documentation. Test whitened and unwhitened data

---
 sparsecoding/datasets.py      | 26 +++++++++++++++++---------
 sparsecoding/test_datasets.py |  8 ++++++--
 2 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/sparsecoding/datasets.py b/sparsecoding/datasets.py
index c2489c9..c2b58cf 100644
--- a/sparsecoding/datasets.py
+++ b/sparsecoding/datasets.py
@@ -79,10 +79,11 @@ class FieldDataset(Dataset):
     root : str
         Location to download the dataset to.
     patch_size : int
-        Side length of patches for sparse dictionary learning.
-    stride : int, optional
-        Stride for sampling patches. If not specified, set to `patch_size`
-        (non-overlapping patches).
+        Side length of patches to extract from images.
+    num_patches : int
+        Number of patches to extract from images.
+    whitened : bool, default=True
+        Download the whitened or unwhitened dataset.
""" B = 10 @@ -95,16 +96,23 @@ def __init__( root: str, num_patches: int, patch_size: int = 8, + whitened: bool = True ): self.P = patch_size root = os.path.expanduser(root) os.system(f"mkdir -p {root}") - if not os.path.exists(f"{root}/field.mat"): - os.system("wget https://rctn.org/bruno/sparsenet/IMAGES.mat") - os.system(f"mv IMAGES.mat {root}/field.mat") - - self.images = torch.tensor(loadmat(f"{root}/field.mat")["IMAGES"]) # [H, W, B] + if whitened: + filename = "IMAGES.mat" + key = "IMAGES" + else: + filename = "IMAGES_RAW.mat" + key = "IMAGESr" + path_to_dataset = os.path.join(root, filename) + if not os.path.exists(path_to_dataset): + os.system(f"wget https://rctn.org/bruno/sparsenet/{filename} -P {root}") + + self.images = torch.from_numpy(loadmat(path_to_dataset)[key].astype(float)) # [H, W, B] assert self.images.shape == (self.H, self.W, self.B) self.images = torch.permute(self.images, (2, 0, 1)) # [B, H, W] diff --git a/sparsecoding/test_datasets.py b/sparsecoding/test_datasets.py index 63fb7dd..ff2c67b 100644 --- a/sparsecoding/test_datasets.py +++ b/sparsecoding/test_datasets.py @@ -1,14 +1,18 @@ -from sparsecoding.datasets import FieldDataset +import pytest +from sparsecoding.datasets import FieldDataset +@pytest.mark.parametrize("whitened", [True, False]) def test_FieldDataset( patch_size_fixture: int, dataset_size_fixture: int, + whitened:bool, ): fielddataset = FieldDataset( - root="data/", + root="data", num_patches=dataset_size_fixture, patch_size=patch_size_fixture, + whitened=whitened ) assert len(fielddataset) == dataset_size_fixture assert fielddataset.patches.shape == (dataset_size_fixture, 1, patch_size_fixture, patch_size_fixture) From 29895033491fcedcc9c0bf9667d964fd43c97d88 Mon Sep 17 00:00:00 2001 From: Alexander Belsten Date: Fri, 28 Feb 2025 15:30:38 -0800 Subject: [PATCH 6/6] flake --- sparsecoding/test_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sparsecoding/test_datasets.py b/sparsecoding/test_datasets.py index ff2c67b..a38f231 100644 --- a/sparsecoding/test_datasets.py +++ b/sparsecoding/test_datasets.py @@ -1,12 +1,12 @@ import pytest - from sparsecoding.datasets import FieldDataset + @pytest.mark.parametrize("whitened", [True, False]) def test_FieldDataset( patch_size_fixture: int, dataset_size_fixture: int, - whitened:bool, + whitened: bool, ): fielddataset = FieldDataset( root="data",