forked from rwightman/pytorch-nips2017-attack-example
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
91 lines (74 loc) · 2.65 KB
/
dataset.py
File metadata and controls
91 lines (74 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import os.path
import torch
import pandas as pd
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
IMG_EXTENSIONS = ['.png', '.jpg']
class LeNormalize(object):
    """In-place rescale of a [0, 1] tensor to [-1, 1], Inception-style.

    Applies 2 * (x - 0.5) elementwise and returns the same tensor object.
    """

    def __call__(self, tensor):
        # Single whole-tensor in-place op pair; equivalent to normalizing
        # each channel separately.
        tensor.sub_(0.5).mul_(2.0)
        return tensor
def default_inception_transform(img_size):
    """Build the default eval transform for Inception-style models.

    Resizes the shorter image side to `img_size`, center-crops to a square
    `img_size` x `img_size`, converts to a tensor, and normalizes to [-1, 1].

    Args:
        img_size: target side length in pixels.

    Returns:
        A torchvision `Compose` transform.
    """
    # transforms.Scale was deprecated and later removed from torchvision;
    # transforms.Resize is the drop-in replacement with identical semantics.
    tf = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        LeNormalize(),
    ])
    return tf
def find_inputs(folder, filename_to_target=None, types=IMG_EXTENSIONS):
    """Recursively gather (absolute_path, target) pairs for images in folder.

    Files are matched by extension (case-insensitive). Targets come from
    `filename_to_target` keyed on the bare filename; without a mapping every
    file gets target 0.
    """
    found = []
    for dirpath, _, filenames in os.walk(folder, topdown=False):
        for fname in filenames:
            ext = os.path.splitext(fname)[1]
            if ext.lower() not in types:
                continue
            label = filename_to_target[fname] if filename_to_target else 0
            found.append((os.path.join(dirpath, fname), label))
    return found
class Dataset(data.Dataset):
    """Image dataset pairing files found under `root` with class targets.

    When `target_file` is given it is read as a headerless CSV of
    (filename, 1-based class id) rows; ids are shifted to 0-based. Without a
    target file every image receives target 0 (see `find_inputs`).

    Args:
        root: directory scanned recursively for supported image files.
        target_file: CSV filename inside `root`, or a falsy value to skip
            target loading.
        transform: optional callable applied to each loaded PIL image.

    Raises:
        RuntimeError: if no supported images are found under `root`.
    """

    def __init__(
            self,
            root,
            target_file='target_class.csv',
            transform=None):
        if target_file:
            target_df = pd.read_csv(os.path.join(root, target_file), header=None)
            # CSV class ids are 1-based; shift to the 0-999 range models expect.
            f_to_t = dict(zip(target_df[0], target_df[1] - 1))
        else:
            f_to_t = dict()
        imgs = find_inputs(root, filename_to_target=f_to_t)
        if not imgs:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))
        self.root = root
        self.imgs = imgs
        self.transform = transform

    def __getitem__(self, index):
        """Return the (image, target) pair at `index`."""
        path, target = self.imgs[index]
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            # Placeholder label so default collation still works untargeted.
            target = torch.zeros(1).long()
        return img, target

    def __len__(self):
        return len(self.imgs)

    def set_transform(self, transform):
        """Replace the per-image transform after construction."""
        self.transform = transform

    def filenames(self, indices=None, basename=False):
        """Return image file paths.

        Args:
            indices: optional iterable of sample indices; falsy (None or
                empty) selects every sample. (Default changed from a mutable
                `[]` to None — backward compatible, since both are falsy.)
            basename: if True, strip directories from each returned path.
        """
        entries = [self.imgs[i] for i in indices] if indices else self.imgs
        paths = [path for path, _ in entries]
        if basename:
            paths = [os.path.basename(p) for p in paths]
        return paths