Skip to content

Test an MNIST-trained model on SVHN and USPS #1

@Simon4Yan

Description

@Simon4Yan

Pretrained MNIST model

USPS

Dataset definition

class USPS(data.Dataset):
    """USPS handwritten-digit dataset read from a bz2-compressed LIBSVM file.

    Reads ``usps.bz2`` (``train=True``) or ``usps.t.bz2`` (``train=False``)
    from ``root``. Each non-blank line has the form
    ``<label> <index>:<value> ...``, where ``index`` is a 1-based pixel
    position in a flattened 16x16 image and ``value`` lies in [-1, 1].
    Features missing from a line (LIBSVM's sparse format omits zeros) are
    treated as 0.0.

    Attributes:
        data: ``uint8`` array of shape ``(N, 16, 16)`` with pixel values
            linearly mapped from [-1, 1] to [0, 255].
        targets: list of ``int`` class indices (file label minus one).
    """

    def __init__(self, root, train=True, transform=None, target_transform=None):
        super(USPS, self).__init__()
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        filename = 'usps.bz2' if train else 'usps.t.bz2'
        full_path = os.path.join(self.root, filename)

        import bz2
        side = 16
        n_features = side * side
        rows = []
        targets = []
        with bz2.open(full_path, 'rt', encoding='utf-8') as fp:
            for line in fp:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                # Place each value by its explicit 1-based index rather than by
                # position, so sparse (zero-omitting) rows stay aligned.
                pixels = [0.0] * n_features
                for field in fields[1:]:
                    index, value = field.split(':')
                    pixels[int(index) - 1] = float(value)
                rows.append(pixels)
                targets.append(int(fields[0]) - 1)

        imgs = np.asarray(rows, dtype=np.float32).reshape((-1, side, side))
        # Map [-1, 1] -> [0, 255]. Round instead of truncating (so 0.99999
        # becomes 255, not 254) and clip so float overshoot cannot wrap
        # around in the uint8 cast.
        imgs = np.clip(np.rint((imgs + 1) / 2 * 255), 0, 255).astype(np.uint8)

        self.data = imgs
        self.targets = targets

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image; converted to 3-channel RGB
        img = Image.fromarray(img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
if use_cuda:
    # Extra DataLoader options used only when CUDA is enabled.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# USPS test split, resized to 28x28 and normalized with mean/std 0.5.
test_loader = torch.utils.data.DataLoader(
    dataset=USPS(
        root='../raw_data',
        train=False,
        transform=transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
    ),
    batch_size=args.test_batch_size,
    shuffle=False,
    **kwargs
)

SVHN

import torchvision.datasets as dataset

# SVHN test split (download=True fetches it if absent), resized to 28x28 and
# normalized with mean/std 0.5; reuses the module-level ``kwargs`` options.
test_loader = torch.utils.data.DataLoader(
    dataset=dataset.SVHN(
        root='../raw_data',
        split='test',
        transform=transforms.Compose([
            transforms.Resize([28, 28]),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
        download=True,
    ),
    batch_size=args.test_batch_size,
    shuffle=False,
    **kwargs
)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions