From ebb5a87a97345aaa42a58b75563d983c614f7ce1 Mon Sep 17 00:00:00 2001
From: HarrisonSantiago
Date: Mon, 2 Dec 2024 19:10:39 -0500
Subject: [PATCH 1/4] trying to make quilt work w/ overlapping patches

---
 sparsecoding/transforms/images.py | 61 ++++++++++++++++++++++---------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py
index eb1ecf3..e056e32 100644
--- a/sparsecoding/transforms/images.py
+++ b/sparsecoding/transforms/images.py
@@ -410,11 +410,11 @@ def quilt(
     height: int,
     width: int,
     patches: torch.Tensor,
+    stride: int = None,
 ):
-    """Gather square patches into an image.
-
-    Inverse of `patchify()`.
-
+    """Gather square patches into an image, supporting overlapping patches.
+    Works with patches created by `patchify()` with a custom stride.
+
     Parameters
     ----------
     height : int
@@ -422,12 +422,15 @@ def quilt(
     width : int
         Width for the reconstructed image.
     patches : Tensor, shape [*, N, C, P, P]
-        Non-overlapping patches from an input image,
+        Potentially overlapping patches from an input image,
         where:
             P is the patch size,
             N is the number of patches,
            C is the number of channels in the image.
-
+    stride : int, optional
+        Stride used when creating patches. If None, assumes non-overlapping patches
+        (stride = patch_size).
+
     Returns
     -------
     image : Tensor, shape [*, C, height, width]
@@ -437,30 +440,52 @@ def quilt(
     N, C, P = patches.shape[-4:-1]
     H = height
     W = width
-
-    if int(H / P) * int(W / P) != N:
+
+    if stride is None:
+        stride = P
+
+    # Calculate expected number of patches based on stride
+    expected_N = (
+        int((H - P + 1 + stride) // stride)
+        * int((W - P + 1 + stride) // stride)
+    )
+    if expected_N != N:
         raise ValueError(
-            f"Expected {N} patches per image, "
-            f"got int(H/P) * int(W/P) = {int(H / P) * int(W / P)}."
+            f"Expected {expected_N} patches per image based on stride {stride}, "
+            f"got {N} patches."
         )
-
+
     if (
-        H % P != 0
-        or W % P != 0
+        H % stride != 0
+        or W % stride != 0
     ):
         warnings.warn(
-            f"Image size ({H, W}) not evenly divisible by `patch_size` ({P}),"
-            f"parts on the bottom and/or right will be zeroed.",
+            f"Image size ({H, W}) not evenly divisible by stride ({stride}),"
+            f"parts on the bottom and/or right may be affected.",
            UserWarning,
        )
-
+
+    # Reshape patches for folding operation
     patches = patches.reshape(-1, N, C*P*P)  # [prod(*), N, C*P*P]
     patches = torch.permute(patches, (0, 2, 1))  # [prod(*), C*P*P, N]
+
+    # Use fold operation with the specified stride
     image = torch.nn.functional.fold(
         input=patches,
         output_size=(H, W),
         kernel_size=P,
-        stride=P,
+        stride=stride,
     )  # [prod(*), C, H, W]
-
+
+    # Create a ones tensor of the same shape as patches to track overlapping regions
+    normalization = torch.nn.functional.fold(
+        torch.ones_like(patches),
+        output_size=(H, W),
+        kernel_size=P,
+        stride=stride,
+    )
+
+    # Normalize by the count of overlapping patches
+    image = image / (normalization + 1e-6)  # Add small epsilon to avoid division by zero
+
     return image.reshape(*leading_dims, C, H, W)
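The mechanism patch 1 introduces rests on an identity worth spelling out: torch.nn.functional.fold sums the patch values covering each output pixel, and folding an all-ones tensor of the same shape yields the per-pixel patch count, so their ratio is the average over overlapping patches. A minimal self-contained sketch of that identity (sizes and variable names here are illustrative, not from the patch):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 6, 6)  # toy single-channel image
P, stride = 4, 2

# unfold extracts overlapping PxP patches as columns: [1, P*P, N]
cols = F.unfold(x, kernel_size=P, stride=stride)

# fold sums the patch values covering each output pixel
summed = F.fold(cols, output_size=(6, 6), kernel_size=P, stride=stride)

# folding all-ones counts how many patches cover each pixel
counts = F.fold(torch.ones_like(cols), output_size=(6, 6), kernel_size=P, stride=stride)

# every overlap carries the same source pixel, so sum / count recovers x
assert torch.allclose(summed / counts, x, atol=1e-5)

The `+ 1e-6` in the patch is the same division with a guard against pixels that no patch covers.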
From de2d5adf1771cb9a7a9b568c31aad1c588ea5283 Mon Sep 17 00:00:00 2001
From: HarrisonSantiago
Date: Mon, 2 Dec 2024 21:49:46 -0500
Subject: [PATCH 2/4] linting

---
 sparsecoding/transforms/images.py | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py
index e056e32..9a196d7 100644
--- a/sparsecoding/transforms/images.py
+++ b/sparsecoding/transforms/images.py
@@ -412,9 +412,10 @@ def quilt(
     patches: torch.Tensor,
     stride: int = None,
 ):
-    """Gather square patches into an image, supporting overlapping patches.
-    Works with patches created by `patchify()` with a custom stride.
-
+    """Gather square patches into an image.
+
+    Inverse of `patchify()`
+
     Parameters
     ----------
     height : int
@@ -430,7 +431,7 @@ def quilt(
     stride : int, optional
         Stride used when creating patches. If None, assumes non-overlapping patches
         (stride = patch_size).
-
+
     Returns
     -------
     image : Tensor, shape [*, C, height, width]
@@ -440,10 +441,10 @@ def quilt(
     N, C, P = patches.shape[-4:-1]
     H = height
     W = width
-
+
     if stride is None:
         stride = P
-
+
     # Calculate expected number of patches based on stride
     expected_N = (
         int((H - P + 1 + stride) // stride)
         * int((W - P + 1 + stride) // stride)
     )
     if expected_N != N:
         raise ValueError(
             f"Expected {expected_N} patches per image based on stride {stride}, "
             f"got {N} patches."
         )
-
+
+    if stride > P:
+        raise RuntimeError(
+            "Stride cannot be larger than the size of a patch when quilting"
+        )
+
     if (
         H % stride != 0
         or W % stride != 0
     ):
         warnings.warn(
@@ -464,11 +470,11 @@ def quilt(
             f"parts on the bottom and/or right may be affected.",
             UserWarning,
         )
-
+
     # Reshape patches for folding operation
     patches = patches.reshape(-1, N, C*P*P)  # [prod(*), N, C*P*P]
     patches = torch.permute(patches, (0, 2, 1))  # [prod(*), C*P*P, N]
-
+
     # Use fold operation with the specified stride
     image = torch.nn.functional.fold(
         input=patches,
@@ -476,7 +482,7 @@ def quilt(
         kernel_size=P,
         stride=stride,
     )  # [prod(*), C, H, W]
-
+
     # Create a ones tensor of the same shape as patches to track overlapping regions
     normalization = torch.nn.functional.fold(
         torch.ones_like(patches),
@@ -484,8 +490,8 @@ def quilt(
         kernel_size=P,
         stride=stride,
     )
-
+
     # Normalize by the count of overlapping patches
     image = image / (normalization + 1e-6)  # Add small epsilon to avoid division by zero
-
+
     return image.reshape(*leading_dims, C, H, W)
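The expected_N arithmetic that patch 2 keeps (and guards with the new stride > P check) matches the number of patches torch.nn.functional.unfold emits when the stride evenly divides the image size. A quick sanity check under that assumption, with illustrative sizes:

import torch
import torch.nn.functional as F

H = W = 8
P, stride = 4, 2  # stride <= P, and H, W divisible by stride

N = F.unfold(torch.zeros(1, 1, H, W), kernel_size=P, stride=stride).shape[-1]
expected_N = (
    int((H - P + 1 + stride) // stride)
    * int((W - P + 1 + stride) // stride)
)
assert N == expected_N == 9  # 3 patch positions along each dimension

With stride = P and the image a multiple of P, the formula reduces to the old non-overlapping count, int(H / P) * int(W / P).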
From 89074a76f56ee1266d035c64202ac441b5b35f7a Mon Sep 17 00:00:00 2001
From: HarrisonSantiago
Date: Mon, 2 Dec 2024 21:56:21 -0500
Subject: [PATCH 3/4] doc string

---
 sparsecoding/transforms/images.py | 9 ++-------
 1 file changed, 2 insertions(+), 7 deletions(-)

diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py
index 9a196d7..dd71304 100644
--- a/sparsecoding/transforms/images.py
+++ b/sparsecoding/transforms/images.py
@@ -412,7 +412,7 @@ def quilt(
     patches: torch.Tensor,
     stride: int = None,
 ):
-    """Gather square patches into an image.
+    """Gather square patches into an image. When patches overlap, take the average of overlapping pixels

     Inverse of `patchify()`

     Parameters
     ----------
     height : int
@@ -445,7 +445,6 @@ def quilt(
     if stride is None:
         stride = P

-    # Calculate expected number of patches based on stride
     expected_N = (
         int((H - P + 1 + stride) // stride)
         * int((W - P + 1 + stride) // stride)
     )
     if expected_N != N:
@@ -471,11 +470,9 @@ def quilt(
             UserWarning,
         )

-    # Reshape patches for folding operation
     patches = patches.reshape(-1, N, C*P*P)  # [prod(*), N, C*P*P]
     patches = torch.permute(patches, (0, 2, 1))  # [prod(*), C*P*P, N]

-    # Use fold operation with the specified stride
     image = torch.nn.functional.fold(
         input=patches,
         output_size=(H, W),
@@ -483,7 +480,6 @@ def quilt(
         stride=stride,
     )  # [prod(*), C, H, W]

-    # Create a ones tensor of the same shape as patches to track overlapping regions
     normalization = torch.nn.functional.fold(
         torch.ones_like(patches),
         output_size=(H, W),
@@ -491,7 +487,6 @@ def quilt(
         stride=stride,
     )

-    # Normalize by the count of overlapping patches
-    image = image / (normalization + 1e-6)  # Add small epsilon to avoid division by zero
+    image = image / (normalization + 1e-6)

     return image.reshape(*leading_dims, C, H, W)

From 9dad7789771ac2b178f4b51867fa2cb9fa7e2baa Mon Sep 17 00:00:00 2001
From: HarrisonSantiago
Date: Mon, 2 Dec 2024 21:57:41 -0500
Subject: [PATCH 4/4] linting

---
 sparsecoding/transforms/images.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py
index dd71304..26dfe43 100644
--- a/sparsecoding/transforms/images.py
+++ b/sparsecoding/transforms/images.py
@@ -487,6 +487,6 @@ def quilt(
         stride=stride,
     )

-    image = image / (normalization + 1e-6) 
+    image = image / (normalization + 1e-6)

     return image.reshape(*leading_dims, C, H, W)
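After all four patches, a round trip through overlapping patches should reproduce the input image, with the overlaps averaged away. The sketch below exercises that behavior; it builds the patches with torch.nn.functional.unfold instead of `patchify()` so it relies only on the `quilt()` signature shown in the diffs, and the sizes are illustrative:

import torch
import torch.nn.functional as F

from sparsecoding.transforms.images import quilt

H = W = 8
C, P, stride = 1, 4, 2
image = torch.arange(H * W, dtype=torch.float32).reshape(1, C, H, W)

# Overlapping patches rearranged into quilt's [*, N, C, P, P] layout
cols = F.unfold(image, kernel_size=P, stride=stride)  # [1, C*P*P, N]
N = cols.shape[-1]  # 9 patches for these sizes
patches = cols.permute(0, 2, 1).reshape(1, N, C, P, P)

# quilt folds the patches, then divides by the per-pixel patch count,
# so each overlapping pixel comes back as an average of identical values
recon = quilt(H, W, patches, stride=stride)  # [1, C, H, W]

# tolerance absorbs the small bias from the 1e-6 epsilon in the divisor
assert torch.allclose(recon, image, atol=1e-3)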