-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatasets.py
More file actions
115 lines (87 loc) · 3.39 KB
/
datasets.py
File metadata and controls
115 lines (87 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch.utils.data as data
from PIL import Image
import numpy as np
import torchvision
from torchvision.datasets import CIFAR10, ImageFolder, DatasetFolder
import os
import logging
# Module-level logging setup: install the default handler on the root logger
# (basicConfig is a no-op if the root logger already has handlers) and set the
# root logger's threshold to INFO. Because getLogger() is called without a
# name, this affects the process-wide root logger, not just this module.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# File suffixes treated as images. NOTE(review): not referenced by the code
# visible in this file — presumably consumed by other modules; confirm.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
def mkdirs(dirpath):
    """Create ``dirpath`` and any missing parent directories (best effort).

    An already-existing directory is not an error. Filesystem failures
    (permission denied, a regular file in the way, ...) are logged as a
    warning instead of being silently discarded, and are not re-raised, so
    callers that relied on this function never failing on filesystem errors
    keep working.

    Args:
        dirpath: Path of the directory to create.
    """
    try:
        # exist_ok=True covers the common "directory is already there" case
        # without going through the exception path at all.
        os.makedirs(dirpath, exist_ok=True)
    except OSError as err:
        logging.getLogger(__name__).warning(
            "Could not create directory %r: %s", dirpath, err
        )
class CIFAR10_truncated(data.Dataset):
    """CIFAR10 wrapper that holds only a selected subset of the samples.

    The underlying ``CIFAR10`` object is built once in the constructor; its
    image array and labels are copied out (optionally restricted to
    ``dataidxs``) and stored in ``self.data`` / ``self.target``.

    Args:
        root: Dataset root directory, forwarded to ``CIFAR10``.
        dataidxs: Optional index (anything NumPy accepts for first-axis
            indexing) selecting the samples to keep. ``None`` keeps all.
        train: Forwarded to ``CIFAR10`` to pick the train or test split.
        transform: Optional callable applied to each image in
            ``__getitem__``.
        target_transform: Optional callable applied to each label in
            ``__getitem__``.
        download: Forwarded to ``CIFAR10``.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download

        self.data, self.target = self.__build_truncated_dataset__()

    def __build_truncated_dataset__(self):
        """Load CIFAR10 and return ``(data, target)`` restricted to ``dataidxs``."""
        cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download)

        # Detect the attribute layout instead of comparing the torchvision
        # version string: objects exposing ``data``/``targets`` use them,
        # anything else falls back to the legacy per-split attributes.
        if hasattr(cifar_dataobj, "data"):
            images = cifar_dataobj.data
            labels = np.array(cifar_dataobj.targets)
        elif self.train:
            images = cifar_dataobj.train_data
            labels = np.array(cifar_dataobj.train_labels)
        else:
            images = cifar_dataobj.test_data
            labels = np.array(cifar_dataobj.test_labels)

        if self.dataidxs is not None:
            images = images[self.dataidxs]
            labels = labels[self.dataidxs]

        return images, labels

    def truncate_channel(self, index):
        """Zero out positions 1 and 2 of the last axis for the given samples.

        Args:
            index: Integer sample indices (array or sequence) into
                ``self.data``. ``self.data`` is modified in place.
        """
        rows = np.asarray(index, dtype=np.intp)
        # One vectorized assignment per channel instead of a Python loop
        # over every sample index.
        self.data[rows, :, :, 1] = 0.0
        self.data[rows, :, :, 2] = 0.0

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.target[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples kept after truncation."""
        return len(self.data)
class ImageFolder_custom(DatasetFolder):
    """ImageFolder-backed dataset restricted to an optional subset of samples.

    A temporary ``ImageFolder`` is built over ``root`` to discover the
    ``(path, class_index)`` pairs and the image loader; only those two
    pieces are kept on this object.

    Args:
        root: Root directory, forwarded to ``ImageFolder``.
        dataidxs: Optional index (anything NumPy accepts for first-axis
            indexing) selecting the samples to keep. ``None`` keeps all.
        train: Stored on the instance; not otherwise used by this class.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each target.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        imagefolder_obj = ImageFolder(self.root, self.transform, self.target_transform)
        self.loader = imagefolder_obj.loader

        # Converting the (path, class_index) pairs to an ndarray makes every
        # field a string, which is why __getitem__ casts the target to int.
        samples = np.array(imagefolder_obj.samples)
        if self.dataidxs is not None:
            samples = samples[self.dataidxs]
        self.samples = samples

    def __getitem__(self, index):
        """Load the sample at ``index`` and return ``(sample, target)``."""
        path = self.samples[index][0]
        target = int(self.samples[index][1])

        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples actually held by this dataset."""
        # Measure the stored samples rather than ``dataidxs``: the two only
        # agree for integer index arrays, and ``len(dataidxs)`` is wrong for
        # boolean masks and undefined for slices.
        return len(self.samples)