-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
85 lines (69 loc) · 2.39 KB
/
dataset.py
File metadata and controls
85 lines (69 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
'''
Author: Zehui Lin
Date: 2021-01-22 17:20:36
LastEditTime: 2021-03-08 12:16:14
LastEditors: Zehui Lin
Description:
'''
import os
import cv2
import torch
import random
import numpy as np
from tqdm import tqdm
from imgaug import augmenters as iaa
from torch.utils.data import Dataset
class MySet(Dataset):
    """In-memory image dataset built from a ``<mode>.txt`` list file.

    Every sample is loaded once by ``read_buffer`` when the dataset is
    constructed. In ``"train"`` mode each access passes the stored image
    through a random imgaug pipeline; in every other mode the stored image
    is used unchanged.
    """

    def __init__(self, txt_path, mode="train", is_debug=False):
        """Load all samples of the requested split.

        Args:
            txt_path: Directory that contains the ``<mode>.txt`` list file.
            mode: Split name, forwarded to ``read_buffer``.
            is_debug: Forwarded to ``read_buffer`` to load only a small
                subset of the split.
        """
        super().__init__()
        self.mode = mode
        self.data_buffer = read_buffer(self.mode, txt_path, is_debug)
        # Built on first use in train mode; see _get_augmenter.
        self._augmenter = None

    def _get_augmenter(self):
        """Return the training augmentation pipeline, creating it once."""
        augmenter = getattr(self, "_augmenter", None)
        if augmenter is None:
            augmenter = iaa.Sequential(
                [
                    iaa.Fliplr(0.5),
                    iaa.Crop(percent=(0, 0.1)),
                    iaa.Multiply((0.8, 1.2)),
                ],
                random_order=True,
            )
            self._augmenter = augmenter
        return augmenter

    def __getitem__(self, item):
        """Return ``(data, label)`` for the sample at index *item*.

        ``data`` is a float32 tensor obtained by moving the last image axis
        to the front and dividing by 255; ``label`` is returned exactly as
        stored in the buffer.
        """
        sample = self.data_buffer[item]
        image = sample["image"]
        label = sample["label"]

        if self.mode == "train":
            # Only one image is augmented per call, so the pipeline is
            # applied directly; each call draws fresh random parameters.
            image = self._get_augmenter()(image=image)

        # Assumes a 3-axis image with channels last (as cv2.imread returns)
        # -- TODO confirm for data that does not come from read_buffer.
        data = torch.from_numpy(
            image.transpose((2, 0, 1)).astype(np.float32) / 255
        )
        return data, label

    def __len__(self):
        """Return the number of buffered samples."""
        return len(self.data_buffer)
def read_buffer(mode, txt_path, is_debug=False):
    """Load every image/label pair listed in ``<txt_path>/<mode>.txt``.

    Each non-blank line of the list file must have the form
    ``<image_path>,<label>``. The line is split on its last comma, so the
    image path itself may contain commas. Lines are shuffled before
    loading.

    Args:
        mode: One of ``"train"``, ``"val"``, ``"test"`` or ``"all"``;
            selects which list file is read.
        txt_path: Directory containing the list files.
        is_debug: When true, only about one tenth of the shuffled lines
            (at least one, if any exist) are loaded.

    Returns:
        A list of dicts with keys ``"image"`` (the array returned by
        ``cv2.imread``) and ``"label"`` (a 0-d ``int64`` tensor).

    Raises:
        ValueError: If *mode* is not supported or a line has no comma.
        OSError: If the list file cannot be opened or an image cannot be
            read.
    """
    valid_modes = ("train", "val", "test", "all")
    if mode not in valid_modes:
        raise ValueError(
            "Unsupported mode {!r}; expected one of {}".format(mode, valid_modes)
        )
    txt_read_path = os.path.join(txt_path, mode + ".txt")

    # The context manager closes the file even if reading fails.
    with open(txt_read_path, "r") as fid:
        stripped = [line.strip() for line in fid]
    # Blank lines (e.g. trailing newlines) carry no sample.
    lines = [line for line in stripped if line]

    random.shuffle(lines)  # avoid all same label in smoke test
    if is_debug and lines:
        # Keep at least one sample so a tiny list does not become empty.
        lines = lines[:max(1, len(lines) // 10)]

    data_buffer = []
    for line in tqdm(lines, desc="Loading: "+mode+" data"):
        data_path, sep, label = line.rpartition(",")
        if not sep:
            raise ValueError(
                "Malformed line in {}: {!r}".format(txt_read_path, line)
            )
        data_path = data_path.strip()
        raw_data = cv2.imread(data_path)
        if raw_data is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError("Could not read image: {!r}".format(data_path))
        label = torch.from_numpy(np.array(label.strip(), dtype=np.int64))
        data_buffer.append({
            "image": raw_data,
            "label": label,
        })
    return data_buffer