Add a transform that makes each sample zero-mean with unit standard deviation,
i.e., reintroduce the function below into our current transform module.
def preprocess_images(images, patch_size_x, patch_size_y, std_scale=10.0, eps=1e-8):
    """Standardize each sample, then flatten into rows of patch_size_x * patch_size_y.

    For every index along dim 0, the mean over dims 1, 2 and 3 is subtracted.
    The result is then divided by ``std_scale`` times the standard deviation
    over the same dims.

    Args:
        images: Tensor with at least 4 dims. Statistics are computed over dims
            1, 2 and 3 (presumably (batch, channel, height, width) -- confirm
            with callers).
        patch_size_x: First patch dimension; sets the output row length
            together with ``patch_size_y``.
        patch_size_y: Second patch dimension.
        std_scale: Positive multiplier applied to the standard deviation before
            dividing. The default of 10.0 preserves the original snippet's
            behavior, which leaves each sample with standard deviation
            1 / std_scale = 0.1. Pass 1.0 for true unit standard deviation.
        eps: Lower bound on the divisor, so constant samples (std == 0) come
            out as zeros instead of NaN.

    Returns:
        A new tensor, detached from autograd, of shape
        (-1, patch_size_x * patch_size_y).
    """
    # detach() drops autograd history. The subtraction below allocates a fresh
    # tensor, so the caller's ``images`` is never modified and no clone is needed.
    data = images.detach()
    dims = (1, 2, 3)
    means = torch.mean(data, dim=dims, keepdim=True)
    data = data - means
    # torch.std uses the Bessel-corrected (unbiased) estimator by default.
    stds = std_scale * torch.std(data, dim=dims, keepdim=True)
    # Clamp so a zero-variance sample yields 0 / eps = 0 rather than 0 / 0 = NaN.
    data = data / stds.clamp_min(eps)
    return data.reshape(-1, patch_size_x * patch_size_y)