-
Notifications
You must be signed in to change notification settings - Fork 5
Open
Description
USPS
class USPS(data.Dataset):
    """USPS handwritten-digit dataset loaded from a bz2-compressed libsvm-style file.

    Each non-blank line of the file is ``label idx:value idx:value ...``, where
    ``idx`` is a 1-based feature index. The 256 features form a 16x16 image.
    Values are expected in [-1, 1] and are rescaled to 0-255. Features absent
    from a line are treated as 0 (the libsvm sparse convention), which maps to
    pixel value 127.

    Args:
        root: Directory containing ``usps.bz2`` (train) and/or ``usps.t.bz2`` (test).
        train: If True read ``usps.bz2``, otherwise ``usps.t.bz2``.
        transform: Optional callable applied to the PIL image built in
            ``__getitem__``.
        target_transform: Optional callable applied to the integer label.

    Attributes:
        data: ``np.uint8`` array of shape ``(N, 16, 16)``.
        targets: List of ``int`` labels, each equal to the file's label minus one.

    Raises:
        ValueError: If a feature token is not ``idx:value`` or its index lies
            outside 1..256.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None):
        super(USPS, self).__init__()
        import bz2

        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        filename = 'usps.bz2' if train else 'usps.t.bz2'
        full_path = os.path.join(self.root, filename)

        side = 16
        n_pixels = side * side
        labels = []
        pixels = []
        # Text mode with explicit UTF-8 matches the original bytes.decode() default.
        with bz2.open(full_path, 'rt', encoding='utf-8') as fp:
            for line in fp:
                tokens = line.split()
                if not tokens:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                row = [0.0] * n_pixels
                for token in tokens[1:]:
                    # Place each value by its index, not its position, so lines
                    # that omit zero-valued features stay aligned.
                    idx, value = token.split(':')
                    col = int(idx) - 1
                    if not 0 <= col < n_pixels:
                        raise ValueError(
                            '{}: feature index {} outside 1..{}'.format(
                                full_path, idx, n_pixels))
                    row[col] = float(value)
                labels.append(int(tokens[0]) - 1)
                pixels.append(row)

        imgs = np.asarray(pixels, dtype=np.float32).reshape((-1, side, side))
        # Map [-1, 1] -> [0, 255]; clip so out-of-range values cannot wrap
        # around when cast to uint8.
        imgs = np.clip((imgs + 1) / 2 * 255, 0, 255).astype(np.uint8)

        self.data = imgs
        self.targets = labels

    def __getitem__(self, index):
        """Return one sample.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target). ``image`` is the 16x16 sample as a PIL
            image converted to RGB, passed through ``transform`` if set.
            ``target`` is the class index, passed through
            ``target_transform`` if set.
        """
        img, target = self.data[index], int(self.targets[index])
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples loaded from the file."""
        return len(self.data)
# Extra DataLoader options, applied only when use_cuda is set.
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# USPS test split: resize the 16x16 digits to 28x28, convert to a tensor and
# normalize with mean 0.5 / std 0.5.
test_loader = torch.utils.data.DataLoader(
    dataset=USPS(
        '../raw_data',
        train=False,
        transform=transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
    ),
    batch_size=args.test_batch_size,
    shuffle=False,
    **kwargs
)
SVHN
import torchvision.datasets as dataset
# SVHN test split, preprocessed with the same resize / ToTensor / Normalize
# pipeline as the USPS loader; download=True lets torchvision fetch it.
test_loader = torch.utils.data.DataLoader(
    dataset=dataset.SVHN(
        '../raw_data',
        split='test',
        transform=transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
        download=True,
    ),
    batch_size=args.test_batch_size,
    shuffle=False,
    **kwargs
)
Metadata
Metadata
Assignees
Labels
No labels