diff --git a/sparsecoding/transforms/images.py b/sparsecoding/transforms/images.py
index eb1ecf3..26dfe43 100644
--- a/sparsecoding/transforms/images.py
+++ b/sparsecoding/transforms/images.py
@@ -410,10 +410,11 @@ def quilt(
     height: int,
     width: int,
     patches: torch.Tensor,
+    stride: int = None,
 ):
-    """Gather square patches into an image.
+    """Gather square patches into an image, averaging overlapping pixels.
 
     Inverse of `patchify()`.
 
     Parameters
     ----------
@@ -422,11 +423,14 @@ def quilt(
     width : int
         Width for the reconstructed image.
     patches : Tensor, shape [*, N, C, P, P]
-        Non-overlapping patches from an input image,
+        Potentially overlapping patches from an input image,
         where:
             P is the patch size,
             N is the number of patches,
             C is the number of channels in the image.
+    stride : int, optional
+        Stride used when creating the patches. If None, the patches are assumed
+        to be non-overlapping (stride = P).
 
     Returns
     -------
@@ -438,29 +442,53 @@ def quilt(
     H = height
     W = width
 
-    if int(H / P) * int(W / P) != N:
+    if stride is None:
+        stride = P
+
+    expected_N = (
+        ((H - P) // stride + 1)
+        * ((W - P) // stride + 1)
+    )
+    if expected_N != N:
         raise ValueError(
-            f"Expected {N} patches per image, "
-            f"got int(H/P) * int(W/P) = {int(H / P) * int(W / P)}."
+            f"Expected {expected_N} patches per image based on stride {stride}, "
+            f"got {N} patches."
+        )
+
+    if stride > P:
+        raise RuntimeError(
+            "Stride cannot be larger than the patch size when quilting."
         )
 
     if (
-        H % P != 0
-        or W % P != 0
+        (H - P) % stride != 0
+        or (W - P) % stride != 0
     ):
         warnings.warn(
-            f"Image size ({H, W}) not evenly divisible by `patch_size` ({P}),"
-            f"parts on the bottom and/or right will be zeroed.",
+            f"Patches with stride {stride} do not exactly cover the image ({H}, {W}); "
+            f"parts on the bottom and/or right will be zeroed.",
             UserWarning,
         )
 
     patches = patches.reshape(-1, N, C*P*P)  # [prod(*), N, C*P*P]
     patches = torch.permute(patches, (0, 2, 1))  # [prod(*), C*P*P, N]
+
     image = torch.nn.functional.fold(
         input=patches,
         output_size=(H, W),
         kernel_size=P,
-        stride=P,
+        stride=stride,
     )  # [prod(*), C, H, W]
 
+    # Count how many patches cover each pixel, then divide so that
+    # overlapping contributions are averaged rather than summed.
+    normalization = torch.nn.functional.fold(
+        torch.ones_like(patches),
+        output_size=(H, W),
+        kernel_size=P,
+        stride=stride,
+    )
+
+    image = image / (normalization + 1e-6)
+
     return image.reshape(*leading_dims, C, H, W)
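
Example round-trip with the new `stride` argument (a minimal sketch; it extracts overlapping patches with `torch.nn.functional.unfold` directly, since the overlapping-stride behavior of the library's `patchify()` is not shown in this diff):

    import torch
    import torch.nn.functional as F

    from sparsecoding.transforms.images import quilt

    C, H, W = 3, 16, 16   # image channels and size
    P, stride = 8, 4      # stride < patch size, so patches overlap

    image = torch.rand(1, C, H, W)

    # Extract overlapping patches: unfold gives [1, C*P*P, N],
    # which is rearranged to the [*, N, C, P, P] layout quilt() expects.
    cols = F.unfold(image, kernel_size=P, stride=stride)
    N = cols.shape[-1]
    patches = cols.permute(0, 2, 1).reshape(1, N, C, P, P)

    # Quilt them back together; overlapping pixels are averaged.
    reconstruction = quilt(H, W, patches, stride=stride)  # [1, C, H, W]
    assert torch.allclose(reconstruction, image, atol=1e-4)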