-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
80 lines (53 loc) · 1.92 KB
/
utils.py
File metadata and controls
80 lines (53 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
''' Handling the data io '''
import glob
import os
import string
import zipfile
from io import open

import numpy as np
import torch
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms, datasets
# Read data
def extractZipFiles(zip_file, extract_to):
    """Unpack every member of the archive *zip_file* into the directory *extract_to*."""
    with zipfile.ZipFile(zip_file, 'r') as archive:
        archive.extractall(extract_to)
    print('done')
# Glob pattern locating the captcha image files.
data_dir = 'data/captcha_images_v2/*.png'


def findFiles(path):
    """Return the list of filesystem paths matching the glob pattern *path*."""
    return glob.glob(path)
# Alphabet used to encode label characters as integer indices.
# NOTE(review): `all_letters` was referenced below but never defined anywhere in
# this file, so every call raised NameError. The captcha_images_v2 labels consist
# of lowercase letters and digits, so define the alphabet that way — confirm
# against the actual dataset filenames.
all_letters = string.ascii_lowercase + string.digits


def letterToIndex(letter):
    """Return the index of *letter* within `all_letters`, or -1 if absent."""
    return all_letters.find(letter)
# print(letterToIndex('l'))
# index -> letter: the inverse mapping of letterToIndex over the alphabet.
indexToLetter = {letterToIndex(c): c for c in all_letters}
# All captcha image paths. findFiles already returns a list, so the previous
# wrapping comprehension was a redundant copy.
data = findFiles(data_dir)
# Label = filename minus the 4-character '.png' suffix; reuse `data` rather
# than globbing the directory a second time.
targets = [os.path.basename(x)[:-4] for x in data]
# 'abcde' -> ['a', 'b', 'c', 'd', 'e']
pre_targets_flat = [list(word) for word in targets]
# NOTE(review): np.array assumes every label has the same length — TODO confirm.
encoded_targets = np.array([[letterToIndex(c) for c in word] for word in pre_targets_flat])
# All label characters flattened into one sequence.
targets_flat = [char for word in pre_targets_flat for char in word]
# Distinct characters actually present in the labels.
unique_letters = set(char for word in targets for char in word)
class CaptchaDataset(Dataset):
    """Dataset pairing captcha image files with integer-encoded labels.

    Args:
        data (list[str]): Paths to the image files.
        target (sequence, optional): Integer-encoded label per image, aligned
            with ``data``. May be None for unlabeled data.
        transform (callable, optional): Optional transform applied to each
            image (e.g. resize / ToTensor).
    """

    def __init__(self, data, target=None, transform=None):
        self.data = data
        self.target = target
        self.transform = transform

    def __getitem__(self, index):
        # Convert to RGB so every sample has a consistent channel count.
        x = Image.open(self.data[index]).convert('RGB')
        # Resize / scale to [0, 1] etc. when a transform is supplied.
        if self.transform:
            x = self.transform(x)
        # Bugfix: the old code indexed self.target unconditionally, crashing
        # when target=None (the documented default); return just the image then.
        if self.target is None:
            return x
        # Bugfix: the label is now always returned as a long tensor — previously
        # it was a tensor only when a transform was supplied, so the item type
        # depended on an unrelated argument (and a trailing return was dead code).
        return x, torch.tensor(self.target[index], dtype=torch.long)

    def __len__(self):
        # Number of samples equals the number of image paths.
        return len(self.data)